simple_prompt_transform.py

import enum
import json
import os
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageContent,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode

if TYPE_CHECKING:
    from core.file.models import File


class ModelMode(enum.StrEnum):
    COMPLETION = "completion"
    CHAT = "chat"

    @classmethod
    def value_of(cls, value: str) -> "ModelMode":
        """
        Get the mode for the given value.

        :param value: mode value
        :return: mode
        """
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid mode value {value}")
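
# Usage sketch (illustrative): value_of maps the raw mode string carried on
# the model config onto the enum, e.g.
#
#   ModelMode.value_of("chat")        # -> ModelMode.CHAT
#   ModelMode.value_of("completion")  # -> ModelMode.COMPLETION
#   ModelMode.value_of("embedding")   # raises ValueError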


# Module-level cache of parsed prompt rule files, keyed by prompt file name.
prompt_file_contents: dict[str, Any] = {}


class SimplePromptTransform(PromptTransform):
    """
    Simple Prompt Transform for Chatbot App Basic Mode.
    """

    def get_prompt(
        self,
        app_mode: AppMode,
        prompt_template_entity: PromptTemplateEntity,
        inputs: Mapping[str, str],
        query: str,
        files: Sequence["File"],
        context: Optional[str],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        # Normalize all input values to strings before templating.
        inputs = {key: str(value) for key, value in inputs.items()}

        # Dispatch on the model's mode: chat models get a message list,
        # completion models get a single flattened prompt.
        model_mode = ModelMode.value_of(model_config.mode)
        if model_mode == ModelMode.CHAT:
            prompt_messages, stops = self._get_chat_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
            )
        else:
            prompt_messages, stops = self._get_completion_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template or "",
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config,
            )

        return prompt_messages, stops

    def get_prompt_str_and_rules(
        self,
        app_mode: AppMode,
        model_config: ModelConfigWithCredentialsEntity,
        pre_prompt: str,
        inputs: dict,
        query: Optional[str] = None,
        context: Optional[str] = None,
        histories: Optional[str] = None,
    ) -> tuple[str, dict]:
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model,
            pre_prompt=pre_prompt,
            has_context=context is not None,
            query_in_prompt=query is not None,
            with_memory_prompt=histories is not None,
        )

        variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}

        for v in prompt_template_config["special_variable_keys"]:
            # support #context#, #query# and #histories#
            if v == "#context#":
                variables["#context#"] = context or ""
            elif v == "#query#":
                variables["#query#"] = query or ""
            elif v == "#histories#":
                variables["#histories#"] = histories or ""

        prompt_template = prompt_template_config["prompt_template"]
        prompt = prompt_template.format(variables)

        return prompt, prompt_template_config["prompt_rules"]
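
    # Illustrative sketch: with an assembled template like
    #   "{{#context#}}\n{{greeting}}\n{{#query#}}"
    # and inputs={"greeting": "Hi"}, format() substitutes both the
    # user-defined keys and the special #context#/#query#/#histories#
    # slots; the special slots default to "" in the loop above.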

    def get_prompt_template(
        self,
        app_mode: AppMode,
        provider: str,
        model: str,
        pre_prompt: str,
        has_context: bool,
        query_in_prompt: bool,
        with_memory_prompt: bool = False,
    ) -> dict:
        prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

        custom_variable_keys = []
        special_variable_keys = []

        # Assemble the system prompt in the order dictated by the rule file.
        prompt = ""
        for order in prompt_rules["system_prompt_orders"]:
            if order == "context_prompt" and has_context:
                prompt += prompt_rules["context_prompt"]
                special_variable_keys.append("#context#")
            elif order == "pre_prompt" and pre_prompt:
                prompt += pre_prompt + "\n"
                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
                custom_variable_keys = pre_prompt_template.variable_keys
            elif order == "histories_prompt" and with_memory_prompt:
                prompt += prompt_rules["histories_prompt"]
                special_variable_keys.append("#histories#")

        if query_in_prompt:
            prompt += prompt_rules.get("query_prompt", "{{#query#}}")
            special_variable_keys.append("#query#")

        return {
            "prompt_template": PromptTemplateParser(template=prompt),
            "custom_variable_keys": custom_variable_keys,
            "special_variable_keys": special_variable_keys,
            "prompt_rules": prompt_rules,
        }
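
    # Return-shape sketch (illustrative): for a chat app with context enabled,
    # query_in_prompt=True, and a pre-prompt of "You are {{name}}.", the result
    # would look roughly like:
    #
    #   {
    #       "prompt_template": PromptTemplateParser(template=...),
    #       "custom_variable_keys": ["name"],
    #       "special_variable_keys": ["#context#", "#query#"],
    #       "prompt_rules": {...},  # parsed from prompt_templates/common_chat.json
    #   }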

    def _get_chat_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: Sequence["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        prompt_messages: list[PromptMessage] = []

        # get prompt
        prompt, _ = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=None,
            context=context,
        )

        if prompt and query:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        if memory:
            prompt_messages = self._append_chat_histories(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                prompt_messages=prompt_messages,
                model_config=model_config,
            )

        # With a query, the rendered prompt serves as the system message and
        # the query becomes the final user message; without one, the prompt
        # itself is sent as the user message.
        if query:
            prompt_messages.append(self.get_last_user_message(query, files))
        else:
            prompt_messages.append(self.get_last_user_message(prompt, files))

        return prompt_messages, None

    def _get_completion_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: Sequence["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        # get prompt
        prompt, prompt_rules = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            context=context,
        )

        if memory:
            tmp_human_message = UserPromptMessage(content=prompt)

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get("human_prefix", "Human"),
                ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
            )

            # get prompt
            prompt, prompt_rules = self.get_prompt_str_and_rules(
                app_mode=app_mode,
                model_config=model_config,
                pre_prompt=pre_prompt,
                inputs=inputs,
                query=query,
                context=context,
                histories=histories,
            )

        stops = prompt_rules.get("stops")
        if stops is not None and len(stops) == 0:
            stops = None

        return [self.get_last_user_message(prompt, files)], stops
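
    # Note on the two-pass rendering above: the first get_prompt_str_and_rules
    # call renders the prompt without histories so _calculate_rest_token can
    # estimate how many tokens remain for memory; the second call re-renders
    # with the truncated histories filled into the {{#histories#}} slot.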

    def get_last_user_message(self, prompt: str, files: Sequence["File"]) -> UserPromptMessage:
        if files:
            # Multimodal: wrap the text and each file as typed content parts.
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
            for file in files:
                prompt_message_contents.append(file_manager.to_prompt_message_content(file))

            prompt_message = UserPromptMessage(content=prompt_message_contents)
        else:
            prompt_message = UserPromptMessage(content=prompt)

        return prompt_message

    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
        """
        Get the simple prompt rule.

        :param app_mode: app mode
        :param provider: model provider
        :param model: model name
        :return: prompt rule dict
        """
        prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)

        # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
            return cast(dict, prompt_file_contents[prompt_file_name])

        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
        json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")

        # Open the JSON file and read its content
        with open(json_file_path, encoding="utf-8") as json_file:
            content = json.load(json_file)

            # Store the content of the prompt file
            prompt_file_contents[prompt_file_name] = content

        return cast(dict, content)
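
    # Shape sketch of a rule file (illustrative; keys inferred from their
    # usage in this module, actual files live under prompt_templates/):
    #
    #   {
    #       "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #       "context_prompt": "Use the following context ...\n{{#context#}}\n",
    #       "histories_prompt": "{{#histories#}}\n",
    #       "query_prompt": "{{#query#}}",
    #       "stops": ["Human:"],
    #       "human_prefix": "Human",
    #       "assistant_prefix": "Assistant"
    #   }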

    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        # Baichuan models get dedicated prompt templates, whether served
        # natively or through a hosting provider.
        is_baichuan = False
        if provider == "baichuan":
            is_baichuan = True
        else:
            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
            if provider in baichuan_supported_providers and "baichuan" in model.lower():
                is_baichuan = True

        if is_baichuan:
            if app_mode == AppMode.COMPLETION:
                return "baichuan_completion"
            else:
                return "baichuan_chat"

        # common
        if app_mode == AppMode.COMPLETION:
            return "common_completion"
        else:
            return "common_chat"
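
    # Resolution sketch (illustrative; model names are examples only):
    #
    #   provider="baichuan",   model="baichuan2-13b", AppMode.CHAT       -> "baichuan_chat"
    #   provider="xinference", model="Baichuan-7B",   AppMode.COMPLETION -> "baichuan_completion"
    #   provider="openai",     model="gpt-4o",        AppMode.CHAT       -> "common_chat"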