prompt_transform.py

import json
import os
import re
import enum
from typing import List, Optional, Tuple

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage

from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, \
    PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from models.model import AppModelConfig


class AppMode(enum.Enum):
    COMPLETION = 'completion'
    CHAT = 'chat'

class PromptTransform:
    def get_prompt(self,
                   app_mode: str,
                   pre_prompt: str,
                   inputs: dict,
                   query: str,
                   files: List[PromptMessageFile],
                   context: Optional[str],
                   memory: Optional[BaseChatMemory],
                   model_instance: BaseLLM) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
        """Build prompt messages for the simple prompt mode from the bundled prompt
        rule files, returning the messages plus optional stop sequences."""
        app_mode_enum = AppMode(app_mode)
        model_mode_enum = model_instance.model_mode

        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(app_mode, model_instance))

        if app_mode_enum == AppMode.CHAT and model_mode_enum == ModelMode.CHAT:
            stops = None

            prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(prompt_rules, pre_prompt, inputs,
                                                                                   query, context, memory,
                                                                                   model_instance, files)
        else:
            stops = prompt_rules.get('stops')
            if stops is not None and len(stops) == 0:
                stops = None

            prompt_messages = self._get_simple_others_prompt_messages(prompt_rules, pre_prompt, inputs, query, context,
                                                                      memory,
                                                                      model_instance, files)

        return prompt_messages, stops
    def get_advanced_prompt(self,
                            app_mode: str,
                            app_model_config: AppModelConfig,
                            inputs: dict,
                            query: str,
                            files: List[PromptMessageFile],
                            context: Optional[str],
                            memory: Optional[BaseChatMemory],
                            model_instance: BaseLLM) -> List[PromptMessage]:
        """Build prompt messages from the user-defined advanced prompt templates
        stored in the app model config."""
        model_mode = app_model_config.model_dict['mode']

        app_mode_enum = AppMode(app_mode)
        model_mode_enum = ModelMode(model_mode)

        prompt_messages = []

        if app_mode_enum == AppMode.CHAT:
            if model_mode_enum == ModelMode.COMPLETION:
                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query,
                                                                                      files, context, memory,
                                                                                      model_instance)
            elif model_mode_enum == ModelMode.CHAT:
                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files,
                                                                                context, memory, model_instance)
        elif app_mode_enum == AppMode.COMPLETION:
            if model_mode_enum == ModelMode.CHAT:
                prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs,
                                                                                      files, context)
            elif model_mode_enum == ModelMode.COMPLETION:
                prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs,
                                                                                            files, context)

        return prompt_messages
    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                          max_token_limit: int) -> str:
        """Get memory messages rendered as a single string."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
                                               max_token_limit: int) -> List[PromptMessage]:
        """Get memory messages as a list of PromptMessage."""
        memory.max_token_limit = max_token_limit
        memory.return_messages = True
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        memory.return_messages = False
        return to_prompt_messages(external_context[memory_key])
    def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
        # baichuan
        if isinstance(model_instance, BaichuanModel):
            return self._prompt_file_name_for_baichuan(mode)

        baichuan_model_hosted_platforms = (HuggingfaceHubModel, OpenLLMModel, XinferenceModel)
        if isinstance(model_instance, baichuan_model_hosted_platforms) and 'baichuan' in model_instance.name.lower():
            return self._prompt_file_name_for_baichuan(mode)

        # common
        if mode == 'completion':
            return 'common_completion'
        else:
            return 'common_chat'

    def _prompt_file_name_for_baichuan(self, mode: str) -> str:
        if mode == 'completion':
            return 'baichuan_completion'
        else:
            return 'baichuan_chat'
    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            'generate_prompts')

        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
        # Open the JSON file and read its content
        with open(json_file_path, 'r') as json_file:
            return json.load(json_file)
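
    # The rule files are plain JSON shipped in the generate_prompts/ directory next to this
    # module. A minimal, hypothetical `common_chat.json` covering every key this class reads
    # might look like the following; only the key names are taken from the code, the prompt
    # wording and stop sequences are illustrative assumptions, not the shipped files:
    #
    #   {
    #     "human_prefix": "Human",
    #     "assistant_prefix": "Assistant",
    #     "context_prompt": "Use the following context as your learned knowledge:\n{{context}}\n",
    #     "histories_prompt": "Here is the chat history:\n{{histories}}\n",
    #     "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #     "query_prompt": "{{query}}",
    #     "stops": ["Human:"]
    #   }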
    def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                                                        query: str,
                                                        context: Optional[str],
                                                        memory: Optional[BaseChatMemory],
                                                        model_instance: BaseLLM,
                                                        files: List[PromptMessageFile]) -> List[PromptMessage]:
        prompt_messages = []

        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                {'context': context}
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = PromptTemplateParser(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
            pre_prompt_content = prompt_template.format(
                prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        prompt_messages.append(PromptMessage(type=MessageType.SYSTEM, content=prompt))

        self._append_chat_histories(memory, prompt_messages, model_instance)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))

        return prompt_messages
    def _get_simple_others_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                                           query: str,
                                           context: Optional[str],
                                           memory: Optional[BaseChatMemory],
                                           model_instance: BaseLLM,
                                           files: List[PromptMessageFile]) -> List[PromptMessage]:
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                {'context': context}
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = PromptTemplateParser(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
            pre_prompt_content = prompt_template.format(
                prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

        if memory and 'histories_prompt' in prompt_rules:
            # append chat histories
            tmp_human_message = PromptBuilder.to_human_message(
                prompt_content=prompt + query_prompt,
                inputs={
                    'query': query
                }
            )

            rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)

            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

            histories = self._get_history_messages_from_memory(memory, rest_tokens)
            prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
            histories_prompt_content = prompt_template.format({'histories': histories})

            prompt = ''
            for order in prompt_rules['system_prompt_orders']:
                if order == 'context_prompt':
                    prompt += context_prompt_content
                elif order == 'pre_prompt':
                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
                elif order == 'histories_prompt':
                    prompt += histories_prompt_content

        prompt_template = PromptTemplateParser(template=query_prompt)
        query_prompt_content = prompt_template.format({'query': query})

        prompt += query_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        return [PromptMessage(content=prompt, files=files)]
    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
        if '#context#' in prompt_template.variable_keys:
            if context:
                prompt_inputs['#context#'] = context
            else:
                prompt_inputs['#context#'] = ''

    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
        if '#query#' in prompt_template.variable_keys:
            if query:
                prompt_inputs['#query#'] = query
            else:
                prompt_inputs['#query#'] = ''
    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
                                prompt_template: PromptTemplateParser, prompt_inputs: dict,
                                model_instance: BaseLLM) -> None:
        if '#histories#' in prompt_template.variable_keys:
            if memory:
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=raw_prompt,
                    inputs={'#histories#': '', **prompt_inputs}
                )

                rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)

                memory.human_prefix = conversation_histories_role['user_prefix']
                memory.ai_prefix = conversation_histories_role['assistant_prefix']
                histories = self._get_history_messages_from_memory(memory, rest_tokens)
                prompt_inputs['#histories#'] = histories
            else:
                prompt_inputs['#histories#'] = ''
    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: List[PromptMessage],
                               model_instance: BaseLLM) -> None:
        if memory:
            rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)

            memory.human_prefix = MessageType.USER.value
            memory.ai_prefix = MessageType.ASSISTANT.value

            histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
            prompt_messages.extend(histories)

    def _calculate_rest_token(self, prompt_messages: BaseMessage, model_instance: BaseLLM) -> int:
        rest_tokens = 2000

        if model_instance.model_rules.max_tokens.max:
            curr_message_tokens = model_instance.get_num_tokens(to_prompt_messages(prompt_messages))

            max_tokens = model_instance.model_kwargs.max_tokens
            rest_tokens = model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens
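
    # Worked example of the token budget above (hypothetical numbers): with a model context
    # limit of 4096 tokens (model_rules.max_tokens.max), 512 tokens reserved for the
    # completion (model_kwargs.max_tokens) and a current prompt of 700 tokens, the budget
    # left for chat histories is 4096 - 512 - 700 = 2884 tokens. When the model exposes no
    # maximum, the default budget of 2000 tokens is used.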
    def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str:
        prompt = prompt_template.format(
            prompt_inputs
        )

        prompt = re.sub(r'<\|.*?\|>', '', prompt)
        return prompt
    def _get_chat_app_completion_model_prompt_messages(self,
                                                       app_model_config: AppModelConfig,
                                                       inputs: dict,
                                                       query: str,
                                                       files: List[PromptMessageFile],
                                                       context: Optional[str],
                                                       memory: Optional[BaseChatMemory],
                                                       model_instance: BaseLLM) -> List[PromptMessage]:
        raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
        conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']

        prompt_messages = []

        prompt_template = PromptTemplateParser(template=raw_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

        self._set_context_variable(context, prompt_template, prompt_inputs)

        self._set_query_variable(query, prompt_template, prompt_inputs)

        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs,
                                     model_instance)

        prompt = self._format_prompt(prompt_template, prompt_inputs)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))

        return prompt_messages
    def _get_chat_app_chat_model_prompt_messages(self,
                                                 app_model_config: AppModelConfig,
                                                 inputs: dict,
                                                 query: str,
                                                 files: List[PromptMessageFile],
                                                 context: Optional[str],
                                                 memory: Optional[BaseChatMemory],
                                                 model_instance: BaseLLM) -> List[PromptMessage]:
        raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

        prompt_messages = []

        for prompt_item in raw_prompt_list:
            raw_prompt = prompt_item['text']

            prompt_template = PromptTemplateParser(template=raw_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

            self._set_context_variable(context, prompt_template, prompt_inputs)

            prompt = self._format_prompt(prompt_template, prompt_inputs)

            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

        self._append_chat_histories(memory, prompt_messages, model_instance)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))

        return prompt_messages
    def _get_completion_app_completion_model_prompt_messages(self,
                                                             app_model_config: AppModelConfig,
                                                             inputs: dict,
                                                             files: List[PromptMessageFile],
                                                             context: Optional[str]) -> List[PromptMessage]:
        raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']

        prompt_messages = []

        prompt_template = PromptTemplateParser(template=raw_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

        self._set_context_variable(context, prompt_template, prompt_inputs)

        prompt = self._format_prompt(prompt_template, prompt_inputs)

        prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))

        return prompt_messages
    def _get_completion_app_chat_model_prompt_messages(self,
                                                       app_model_config: AppModelConfig,
                                                       inputs: dict,
                                                       files: List[PromptMessageFile],
                                                       context: Optional[str]) -> List[PromptMessage]:
        raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

        prompt_messages = []

        for prompt_item in raw_prompt_list:
            raw_prompt = prompt_item['text']

            prompt_template = PromptTemplateParser(template=raw_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

            self._set_context_variable(context, prompt_template, prompt_inputs)

            prompt = self._format_prompt(prompt_template, prompt_inputs)

            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

        # attach files to the most recent user message, if there is one
        for prompt_message in prompt_messages[::-1]:
            if prompt_message.type == MessageType.USER:
                prompt_message.files = files
                break

        return prompt_messages
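
# Minimal usage sketch (illustrative only): the caller, the concrete model_instance and its
# configuration are assumptions, not part of this module.
#
#   transform = PromptTransform()
#   prompt_messages, stops = transform.get_prompt(
#       app_mode='chat',
#       pre_prompt='You are a helpful assistant for {{product_name}}.',
#       inputs={'product_name': 'Acme'},
#       query='How do I reset my password?',
#       files=[],
#       context=None,
#       memory=None,
#       model_instance=model_instance,  # any configured BaseLLM subclass instance
#   )
#   # `prompt_messages` is then sent to the model; `stops`, when not None, is passed along
#   # as stop sequences for completion-style models.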