simple_prompt_transform.py

import enum
import json
import os
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode

if TYPE_CHECKING:
    from core.file.models import File

class ModelMode(enum.StrEnum):
    COMPLETION = "completion"
    CHAT = "chat"

    @classmethod
    def value_of(cls, value: str) -> "ModelMode":
        """
        Get the mode for the given string value.

        :param value: mode value
        :return: matching mode
        """
        for mode in cls:
            if mode.value == value:
                return mode

        raise ValueError(f"invalid mode value {value}")


# Module-level cache of loaded prompt-rule files, keyed by prompt file name.
prompt_file_contents: dict[str, Any] = {}

class SimplePromptTransform(PromptTransform):
    """
    Simple Prompt Transform for Chatbot App Basic Mode.
    """

    def get_prompt(
        self,
        app_mode: AppMode,
        prompt_template_entity: PromptTemplateEntity,
        inputs: Mapping[str, str],
        query: str,
        files: Sequence["File"],
        context: Optional[str],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        inputs = {key: str(value) for key, value in inputs.items()}

        model_mode = ModelMode.value_of(model_config.mode)
        if model_mode == ModelMode.CHAT:
            prompt_messages, stops = self._get_chat_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                image_detail_config=image_detail_config,
            )
        else:
            prompt_messages, stops = self._get_completion_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                image_detail_config=image_detail_config,
            )

        return prompt_messages, stops

    def _get_prompt_str_and_rules(
        self,
        app_mode: AppMode,
        model_config: ModelConfigWithCredentialsEntity,
        pre_prompt: str,
        inputs: dict,
        query: Optional[str] = None,
        context: Optional[str] = None,
        histories: Optional[str] = None,
    ) -> tuple[str, dict]:
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model,
            pre_prompt=pre_prompt,
            has_context=context is not None,
            query_in_prompt=query is not None,
            with_memory_prompt=histories is not None,
        )

        variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}

        for v in prompt_template_config["special_variable_keys"]:
            # support #context#, #query# and #histories#
            if v == "#context#":
                variables["#context#"] = context or ""
            elif v == "#query#":
                variables["#query#"] = query or ""
            elif v == "#histories#":
                variables["#histories#"] = histories or ""

        prompt_template = prompt_template_config["prompt_template"]
        prompt = prompt_template.format(variables)

        return prompt, prompt_template_config["prompt_rules"]

    def get_prompt_template(
        self,
        app_mode: AppMode,
        provider: str,
        model: str,
        pre_prompt: str,
        has_context: bool,
        query_in_prompt: bool,
        with_memory_prompt: bool = False,
    ) -> dict:
        prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

        custom_variable_keys = []
        special_variable_keys = []

        prompt = ""
        for order in prompt_rules["system_prompt_orders"]:
            if order == "context_prompt" and has_context:
                prompt += prompt_rules["context_prompt"]
                special_variable_keys.append("#context#")
            elif order == "pre_prompt" and pre_prompt:
                prompt += pre_prompt + "\n"
                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
                custom_variable_keys = pre_prompt_template.variable_keys
            elif order == "histories_prompt" and with_memory_prompt:
                prompt += prompt_rules["histories_prompt"]
                special_variable_keys.append("#histories#")

        if query_in_prompt:
            prompt += prompt_rules.get("query_prompt", "{{#query#}}")
            special_variable_keys.append("#query#")

        return {
            "prompt_template": PromptTemplateParser(template=prompt),
            "custom_variable_keys": custom_variable_keys,
            "special_variable_keys": special_variable_keys,
            "prompt_rules": prompt_rules,
        }
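
    # Illustrative example (not from the source): for a pre_prompt of
    # "You are {{role}}." called with has_context=True, query_in_prompt=True,
    # and with_memory_prompt=False, and assuming the rule file lists
    # "context_prompt" in its system_prompt_orders, the dict above would
    # carry custom_variable_keys == ["role"] and
    # special_variable_keys == ["#context#", "#query#"], with prompt_template
    # wrapping the concatenated rule segments, ready for .format().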

    def _get_chat_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: Sequence["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        prompt_messages: list[PromptMessage] = []

        # get prompt
        prompt, _ = self._get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=None,
            context=context,
        )

        if prompt and query:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        if memory:
            prompt_messages = self._append_chat_histories(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                prompt_messages=prompt_messages,
                model_config=model_config,
            )

        if query:
            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
        else:
            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))

        return prompt_messages, None

    def _get_completion_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: Sequence["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        # get prompt without histories first, to measure the tokens it consumes
        prompt, prompt_rules = self._get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            context=context,
        )

        if memory:
            tmp_human_message = UserPromptMessage(content=prompt)

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get("human_prefix", "Human"),
                ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
            )

            # rebuild the prompt, this time with the trimmed histories included
            prompt, prompt_rules = self._get_prompt_str_and_rules(
                app_mode=app_mode,
                model_config=model_config,
                pre_prompt=pre_prompt,
                inputs=inputs,
                query=query,
                context=context,
                histories=histories,
            )

        stops = prompt_rules.get("stops")
        if stops is not None and len(stops) == 0:
            stops = None

        return [self._get_last_user_message(prompt, files, image_detail_config)], stops

    def _get_last_user_message(
        self,
        prompt: str,
        files: Sequence["File"],
        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
    ) -> UserPromptMessage:
        if files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
            for file in files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                )

            prompt_message = UserPromptMessage(content=prompt_message_contents)
        else:
            prompt_message = UserPromptMessage(content=prompt)

        return prompt_message

    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
        """
        Get simple prompt rule.
        :param app_mode: app mode
        :param provider: model provider
        :param model: model name
        :return:
        """
        prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)

        # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
            return cast(dict, prompt_file_contents[prompt_file_name])

        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
        json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")

        # Open the JSON file and read its content
        with open(json_file_path, encoding="utf-8") as json_file:
            content = json.load(json_file)

            # Store the content of the prompt file
            prompt_file_contents[prompt_file_name] = content

        return cast(dict, content)
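
    # A minimal sketch of the rule-file shape this method loads from
    # ./prompt_templates/<name>.json (keys inferred from how they are read
    # elsewhere in this class; the real files may contain more or differ):
    #
    # {
    #     "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #     "context_prompt": "Use the following context:\n{{#context#}}\n",
    #     "histories_prompt": "Here are the chat histories:\n{{#histories#}}\n",
    #     "query_prompt": "{{#query#}}",
    #     "stops": ["Human:"],
    #     "human_prefix": "Human",
    #     "assistant_prefix": "Assistant"
    # }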

    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        # baichuan
        is_baichuan = False
        if provider == "baichuan":
            is_baichuan = True
        else:
            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
            if provider in baichuan_supported_providers and "baichuan" in model.lower():
                is_baichuan = True

        if is_baichuan:
            if app_mode == AppMode.COMPLETION:
                return "baichuan_completion"
            else:
                return "baichuan_chat"

        # common
        if app_mode == AppMode.COMPLETION:
            return "common_completion"
        else:
            return "common_chat"
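

# Usage sketch (illustrative only): `prompt_template_entity` and `model_config`
# are hypothetical objects assumed to come from the surrounding app
# configuration; they are not defined in this file.
#
#   transform = SimplePromptTransform()
#   prompt_messages, stops = transform.get_prompt(
#       app_mode=AppMode.CHAT,
#       prompt_template_entity=prompt_template_entity,
#       inputs={"role": "a helpful assistant"},
#       query="Hello!",
#       files=[],
#       context=None,
#       memory=None,
#       model_config=model_config,
#   )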