completion.py

import logging
import re
from typing import Optional, List, Union, Tuple

from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
                 is_override: bool = False):
        """
        Generate a completion for the given app and query, optionally continuing an existing conversation.

        errors: ProviderTokenNotInitError
        """
        query = PromptBuilder.process_template(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming,
            model_instance=final_model_instance
        )

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        # parse sensitive_word_avoidance_chain
        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
        if sensitive_word_avoidance_chain:
            query = sensitive_word_avoidance_chain.run(query)

        # get agent executor
        agent_executor = orchestrator_rule_parser.to_agent_executor(
            conversation_message_task=conversation_message_task,
            memory=memory,
            rest_tokens=rest_tokens_for_context_and_memory,
            chain_callback=chain_callback
        )

        # run agent executor
        agent_execute_result = None
        if agent_executor:
            should_use_agent = agent_executor.should_use_agent(query)
            if should_use_agent:
                agent_execute_result = agent_executor.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # Interrupt by LLM (like OpenAI), handle it.
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
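        """
        Build the main prompt and run the final LLM call.

        If the agent already produced an output and no pre_prompt is configured, that output is
        passed to the model instance as a fake response instead of triggering another generation.
        """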
        # When no extra pre prompt is specified,
        # the output of the agent can be used directly as the main output content without calling LLM again
        fake_response = None
        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
            fake_response = agent_execute_result.output

        # get llm prompt
        prompt_messages, stop_words = cls.get_main_llm_prompt(
            mode=mode,
            model=app_model_config.model_dict,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            agent_execute_result=agent_execute_result,
            memory=memory
        )

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages,
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response

    @classmethod
    def get_main_llm_prompt(cls, mode: str, model: dict,
                            pre_prompt: str, query: str, inputs: dict,
                            agent_execute_result: Optional[AgentExecuteResult],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
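        """
        Assemble the prompt messages for the final LLM call.

        In completion mode a single prompt is rendered from the optional agent context, the
        pre_prompt and the query. In chat mode the context, pre_prompt, truncated chat histories
        and the query are combined into one human message, and stop words are returned to keep
        the model from emitting a new Human turn or a closing histories tag.
        """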
        if mode == 'completion':
            prompt_template = JinjaPromptTemplate.from_template(
                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.

<context>
{{context}}
</context>

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if agent_execute_result else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{{query}}\n"
            )

            if agent_execute_result:
                inputs['context'] = agent_execute_result.output

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            return [PromptMessage(content=prompt_content)], None
        else:
            messages: List[BaseMessage] = []

            human_inputs = {
                "query": query
            }

            human_message_prompt = ""

            if pre_prompt:
                pre_prompt_inputs = {k: inputs[k] for k in
                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
                                     if k in inputs}

                if pre_prompt_inputs:
                    human_inputs.update(pre_prompt_inputs)

            if agent_execute_result:
                human_inputs['context'] = agent_execute_result.output
                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.

<context>
{{context}}
</context>

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""

            if pre_prompt:
                human_message_prompt += pre_prompt

            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "

            if memory:
                # append chat histories
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=human_message_prompt + query_prompt,
                    inputs=human_inputs
                )
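                # Budget for histories: the model's context limit minus the reserved completion
                # max_tokens and the tokens already used by the current prompt; falls back to
                # 2000 tokens when the model declares no limit.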
                if memory.model_instance.model_rules.max_tokens.max:
                    curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
                    max_tokens = model.get("completion_params").get('max_tokens')
                    rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
                    rest_tokens = max(rest_tokens, 0)
                else:
                    rest_tokens = 2000

                histories = cls.get_history_messages_from_memory(memory, rest_tokens)
                human_message_prompt += "\n\n" if human_message_prompt else ""
                human_message_prompt += "Here is the chat histories between human and assistant, " \
                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
                human_message_prompt += histories + "\n</histories>"

            human_message_prompt += query_prompt

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content=human_message_prompt,
                inputs=human_inputs
            )

            messages.append(human_message)

            for message in messages:
                message.content = re.sub(r'<\|.*?\|>', '', message.content)

            return to_prompt_messages(messages), ['\nHuman:', '</histories>']

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Load the chat history from memory, truncated to at most max_token_limit tokens."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
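        """
        Build a read-only memory wrapper around the conversation's stored messages.

        A separate model instance is created only so the memory can count tokens; keyword
        arguments override the default token limit, memory key, message limit and I/O keys.
        """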
        # only for calc token in memory
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
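        """
        Return the number of tokens left for context and chat histories.

        The budget is the model's context limit minus the configured completion max_tokens and
        the tokens of the prompt built without memory or context. Returns -1 when the model
        declares no limit, and raises LLMBadRequestError when the budget is negative.
        """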
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        # get prompt without memory and context
        prompt_messages, _ = cls.get_main_llm_prompt(
            mode=mode,
            model=app_model_config.model_dict,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            agent_execute_result=None,
            memory=None
        )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                     "or shrink the max tokens, or switch to an LLM with a larger token limit.")

        return rest_tokens

    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalc max_tokens if sum(prompt_tokens + max_tokens) is over the model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

            # update model instance max tokens
            model_kwargs = model_instance.get_model_kwargs()
            model_kwargs.max_tokens = max_tokens
            model_instance.set_model_kwargs(model_kwargs)

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
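        """
        Regenerate an alternative answer for an existing completion message.

        The original prompt and answer are wrapped in MORE_LIKE_THIS_GENERATE_PROMPT and sent
        to the model as a fresh, conversation-less completion request.
        """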
        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        old_prompt_messages, _ = cls.get_main_llm_prompt(
            mode="completion",
            model=app_model_config.model_dict,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            agent_execute_result=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)

        prompt_messages = [PromptMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming,
            model_instance=final_model_instance
        )

        cls.recale_llm_max_tokens(
            model_instance=final_model_instance,
            prompt_messages=prompt_messages
        )

        final_model_instance.run(
            messages=prompt_messages,
            callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
        )