simple_prompt_transform.py

import enum
import json
import os
from typing import TYPE_CHECKING, Optional

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode

if TYPE_CHECKING:
    from core.file.file_obj import FileVar


class ModelMode(enum.Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'

    @classmethod
    def value_of(cls, value: str) -> 'ModelMode':
        """
        Get value of given mode.

        :param value: mode value
        :return: mode
        """
        for mode in cls:
            if mode.value == value:
                return mode

        raise ValueError(f'invalid mode value {value}')
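

# Lookup sketch: ModelMode.value_of('chat') returns ModelMode.CHAT, while an
# unrecognized string such as 'embedding' raises ValueError.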


# Module-level cache of prompt rule files, keyed by prompt file name, so each
# JSON template is read from disk at most once per process.
prompt_file_contents = {}


class SimplePromptTransform(PromptTransform):
    """
    Simple Prompt Transform for Chatbot App Basic Mode.
    """

    def get_prompt(self,
                   app_mode: AppMode,
                   prompt_template_entity: PromptTemplateEntity,
                   inputs: dict,
                   query: str,
                   files: list["FileVar"],
                   context: Optional[str],
                   memory: Optional[TokenBufferMemory],
                   model_config: ModelConfigWithCredentialsEntity) -> \
            tuple[list[PromptMessage], Optional[list[str]]]:
        # Stringify all input values so template substitution never sees
        # non-string types.
        inputs = {key: str(value) for key, value in inputs.items()}

        # Dispatch on the model mode: chat models get a structured message
        # list, completion models get a single flattened prompt string.
        model_mode = ModelMode.value_of(model_config.mode)
        if model_mode == ModelMode.CHAT:
            prompt_messages, stops = self._get_chat_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
            prompt_messages, stops = self._get_completion_model_prompt_messages(
                app_mode=app_mode,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )

        return prompt_messages, stops
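
    # Usage sketch (hypothetical objects; assumes the app runner has already
    # built `template_entity` and a credentialed `model_config`):
    #
    #     transform = SimplePromptTransform()
    #     prompt_messages, stops = transform.get_prompt(
    #         app_mode=AppMode.CHAT,
    #         prompt_template_entity=template_entity,
    #         inputs={'name': 'Alice'},
    #         query='What can you do?',
    #         files=[],
    #         context=None,
    #         memory=None,
    #         model_config=model_config,
    #     )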

    def get_prompt_str_and_rules(self, app_mode: AppMode,
                                 model_config: ModelConfigWithCredentialsEntity,
                                 pre_prompt: str,
                                 inputs: dict,
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 histories: Optional[str] = None,
                                 ) -> tuple[str, dict]:
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model,
            pre_prompt=pre_prompt,
            has_context=context is not None,
            query_in_prompt=query is not None,
            with_memory_prompt=histories is not None
        )

        variables = {k: inputs[k] for k in prompt_template_config['custom_variable_keys'] if k in inputs}

        for v in prompt_template_config['special_variable_keys']:
            # support #context#, #query# and #histories#
            if v == '#context#':
                variables['#context#'] = context if context else ''
            elif v == '#query#':
                variables['#query#'] = query if query else ''
            elif v == '#histories#':
                variables['#histories#'] = histories if histories else ''

        prompt_template = prompt_template_config['prompt_template']
        prompt = prompt_template.format(variables)

        return prompt, prompt_template_config['prompt_rules']
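
    # Substitution sketch (illustrative template text; PromptTemplateParser
    # recognizes {{key}} placeholders, including the special
    # {{#context#}} / {{#query#}} / {{#histories#}} keys filled above):
    #
    #     template:  "Context: {{#context#}}\nYou are {{role}}.\n{{#query#}}"
    #     variables: {'#context#': 'passage', 'role': 'a tutor', '#query#': 'Hi'}
    #     result:    "Context: passage\nYou are a tutor.\nHi"
    #
    # Absent special values fall back to '' above, so rendering never fails on
    # a missing context, query, or histories.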

    def get_prompt_template(self, app_mode: AppMode,
                            provider: str,
                            model: str,
                            pre_prompt: str,
                            has_context: bool,
                            query_in_prompt: bool,
                            with_memory_prompt: bool = False) -> dict:
        prompt_rules = self._get_prompt_rule(
            app_mode=app_mode,
            provider=provider,
            model=model
        )

        custom_variable_keys = []
        special_variable_keys = []

        # Assemble the prompt by walking the rule file's ordering of the
        # context, pre-prompt and histories sections.
        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt' and has_context:
                prompt += prompt_rules['context_prompt']
                special_variable_keys.append('#context#')
            elif order == 'pre_prompt' and pre_prompt:
                prompt += pre_prompt + '\n'
                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
                custom_variable_keys = pre_prompt_template.variable_keys
            elif order == 'histories_prompt' and with_memory_prompt:
                prompt += prompt_rules['histories_prompt']
                special_variable_keys.append('#histories#')

        if query_in_prompt:
            prompt += prompt_rules.get('query_prompt', '{{#query#}}')
            special_variable_keys.append('#query#')

        return {
            "prompt_template": PromptTemplateParser(template=prompt),
            "custom_variable_keys": custom_variable_keys,
            "special_variable_keys": special_variable_keys,
            "prompt_rules": prompt_rules
        }
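
    # Shape of the returned config (illustrative values):
    #
    #     {
    #         "prompt_template": PromptTemplateParser(template='<assembled>'),
    #         "custom_variable_keys": ['role'],            # from the pre-prompt
    #         "special_variable_keys": ['#context#', '#query#'],
    #         "prompt_rules": {...},                       # raw rule-file JSON
    #     }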

    def _get_chat_model_prompt_messages(self, app_mode: AppMode,
                                        pre_prompt: str,
                                        inputs: dict,
                                        query: str,
                                        context: Optional[str],
                                        files: list["FileVar"],
                                        memory: Optional[TokenBufferMemory],
                                        model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        prompt_messages = []

        # get prompt
        prompt, _ = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=None,
            context=context
        )

        # Only emit a system message when there is also a user query;
        # otherwise the assembled prompt itself becomes the user message below.
        if prompt and query:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        if memory:
            prompt_messages = self._append_chat_histories(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                prompt_messages=prompt_messages,
                model_config=model_config
            )

        if query:
            prompt_messages.append(self.get_last_user_message(query, files))
        else:
            prompt_messages.append(self.get_last_user_message(prompt, files))

        return prompt_messages, None
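
    # Resulting message order in chat mode (when prompt and query are both
    # non-empty):
    #
    #     [SystemPromptMessage(prompt), *history_messages, UserPromptMessage(query)]
    #
    # Chat mode never returns stop words, hence the trailing None.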

    def _get_completion_model_prompt_messages(self, app_mode: AppMode,
                                              pre_prompt: str,
                                              inputs: dict,
                                              query: str,
                                              context: Optional[str],
                                              files: list["FileVar"],
                                              memory: Optional[TokenBufferMemory],
                                              model_config: ModelConfigWithCredentialsEntity) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        # get prompt without histories first, to measure how many tokens
        # remain for the conversation history
        prompt, prompt_rules = self.get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            context=context
        )

        if memory:
            tmp_human_message = UserPromptMessage(
                content=prompt
            )

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get('human_prefix', 'Human'),
                ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant')
            )

            # rebuild the prompt, this time with the token-trimmed histories
            prompt, prompt_rules = self.get_prompt_str_and_rules(
                app_mode=app_mode,
                model_config=model_config,
                pre_prompt=pre_prompt,
                inputs=inputs,
                query=query,
                context=context,
                histories=histories
            )

        # Treat an empty stop-word list as "no stop words".
        stops = prompt_rules.get('stops')
        if stops is not None and len(stops) == 0:
            stops = None

        return [self.get_last_user_message(prompt, files)], stops

    def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
        if files:
            prompt_message_contents = [TextPromptMessageContent(data=prompt)]
            for file in files:
                prompt_message_contents.append(file.prompt_message_content)

            prompt_message = UserPromptMessage(content=prompt_message_contents)
        else:
            prompt_message = UserPromptMessage(content=prompt)

        return prompt_message
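
    # Content shape sketch: with files attached, `content` is a multi-part list
    #     [TextPromptMessageContent(data='...'), <file prompt content>, ...]
    # and without files it is the plain prompt string.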

    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
        """
        Get simple prompt rule.

        :param app_mode: app mode
        :param provider: model provider
        :param model: model name
        :return: prompt rules dict loaded from the matching JSON template
        """
        prompt_file_name = self._prompt_file_name(
            app_mode=app_mode,
            provider=provider,
            model=model
        )

        # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
            return prompt_file_contents[prompt_file_name]

        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prompt_templates')
        json_file_path = os.path.join(prompt_path, f'{prompt_file_name}.json')

        # Open the JSON file and read its content
        with open(json_file_path, encoding='utf-8') as json_file:
            content = json.load(json_file)

            # Store the content of the prompt file
            prompt_file_contents[prompt_file_name] = content

        return content
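
    # Sketch of a rules file such as prompt_templates/common_chat.json.  The
    # field names are exactly those this class reads; the values are
    # illustrative, not the shipped template text:
    #
    #     {
    #         "human_prefix": "Human",
    #         "assistant_prefix": "Assistant",
    #         "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #         "context_prompt": "Use the following context as your knowledge:\n{{#context#}}\n",
    #         "histories_prompt": "Conversation so far:\n{{#histories#}}\n",
    #         "query_prompt": "\nHuman: {{#query#}}\nAssistant: ",
    #         "stops": ["\nHuman:"]
    #     }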

    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        # baichuan
        is_baichuan = False
        if provider == 'baichuan':
            is_baichuan = True
        else:
            baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
            if provider in baichuan_supported_providers and 'baichuan' in model.lower():
                is_baichuan = True

        if is_baichuan:
            if app_mode == AppMode.COMPLETION:
                return 'baichuan_completion'
            else:
                return 'baichuan_chat'

        # common
        if app_mode == AppMode.COMPLETION:
            return 'common_completion'
        else:
            return 'common_chat'
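
    # Resolution examples (derived from the branches above):
    #
    #     provider='baichuan',   any model,             AppMode.CHAT       -> 'baichuan_chat'
    #     provider='xinference', model='baichuan-13b',  AppMode.COMPLETION -> 'baichuan_completion'
    #     provider='openai',     model='gpt-4',         AppMode.CHAT       -> 'common_chat'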