completion.py

import logging
import re
from typing import Optional, List, Union, Tuple

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.fake import FakeLLM
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
                 is_override: bool = False):
        """
        Run a completion for the given query: load read-only conversation memory, validate the
        remaining token budget, optionally run the agent executor, then invoke the final LLM.

        errors: ProviderTokenNotInitError
        """
        query = PromptBuilder.process_template(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            tenant_id=app.tenant_id,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        # parse sensitive_word_avoidance_chain
        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
        if sensitive_word_avoidance_chain:
            query = sensitive_word_avoidance_chain.run(query)

        # get agent executor
        agent_executor = orchestrator_rule_parser.to_agent_executor(
            conversation_message_task=conversation_message_task,
            memory=memory,
            rest_tokens=rest_tokens_for_context_and_memory,
            chain_callback=chain_callback
        )

        # run agent executor
        agent_execute_result = None
        if agent_executor:
            should_use_agent = agent_executor.should_use_agent(query)
            if should_use_agent:
                agent_execute_result = agent_executor.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # Stream interrupted by the LLM provider (e.g. OpenAI); end the task gracefully.
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
        # When no pre prompt is specified, the agent's output can be used directly
        # as the final answer without calling the LLM again.
        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                and agent_execute_result.strategy != PlanningStrategy.ROUTER:
            final_llm = FakeLLM(response=agent_execute_result.output,
                                origin_llm=agent_execute_result.configuration.llm,
                                streaming=streaming)
            final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
            response = final_llm.generate([[HumanMessage(content=query)]])
            return response

        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt, stop_words = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            model=app_model_config.model_dict,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            agent_execute_result=agent_execute_result,
            memory=memory
        )

        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=final_llm,
            model=app_model_config.model_dict,
            prompt=prompt,
            mode=mode
        )

        response = final_llm.generate([prompt], stop_words)

        return response

    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
                            pre_prompt: str, query: str, inputs: dict,
                            agent_execute_result: Optional[AgentExecuteResult],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[Union[str, List[BaseMessage]], Optional[List[str]]]:
        if mode == 'completion':
            prompt_template = JinjaPromptTemplate.from_template(
                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
</context>
When answering the user:
- If you don't know, just say that you don't know.
- If you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if agent_execute_result else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{{query}}\n"
            )

            if agent_execute_result:
                inputs['context'] = agent_execute_result.output

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)], None
            else:
                return prompt_content, None
        else:
            messages: List[BaseMessage] = []

            human_inputs = {
                "query": query
            }

            human_message_prompt = ""

            if pre_prompt:
                pre_prompt_inputs = {k: inputs[k] for k in
                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
                                     if k in inputs}

                if pre_prompt_inputs:
                    human_inputs.update(pre_prompt_inputs)

            if agent_execute_result:
                human_inputs['context'] = agent_execute_result.output
                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
{{context}}
</context>
When answering the user:
- If you don't know, just say that you don't know.
- If you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""

            if pre_prompt:
                human_message_prompt += pre_prompt

            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "

            if memory:
                # append chat histories
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=human_message_prompt + query_prompt,
                    inputs=human_inputs
                )

                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
                model_name = model['name']
                max_tokens = model.get("completion_params").get('max_tokens')
                rest_tokens = llm_constant.max_context_token_length[model_name] \
                    - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                histories = cls.get_history_messages_from_memory(memory, rest_tokens)

                human_message_prompt += "\n\n" if human_message_prompt else ""
                human_message_prompt += "Here is the chat history between the human and the assistant, " \
                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
                human_message_prompt += histories + "\n</histories>"

            human_message_prompt += query_prompt

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content=human_message_prompt,
                inputs=human_inputs
            )

            messages.append(human_message)

            for message in messages:
                # strip special model tokens such as <|endoftext|> from the prompt
                message.content = re.sub(r'<\|.*?\|>', '', message.content)

            return messages, ['\nHuman:', '</histories>']
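
    # Illustrative sketch (not taken verbatim from this file): in chat mode, with retrieved
    # context, a pre_prompt and memory all present, the human message assembled above is
    # roughly shaped like:
    #
    #   Use the following context as your learned knowledge, inside <context></context> XML tags.
    #   <context>...agent/dataset output...</context>
    #   <pre_prompt, with its {{variables}} resolved from inputs>
    #   Here is the chat history between the human and the assistant, inside <histories></histories> XML tags.
    #   <histories>...token-bounded history...</histories>
    #
    #   Human: <query>
    #
    #   Assistant:
    #
    # The stop words ['\nHuman:', '</histories>'] returned alongside the messages keep the
    # model from generating a new human turn or a stray closing histories tag.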

    @classmethod
    def get_llm_callbacks(cls, llm: BaseLanguageModel,
                          streaming: bool,
                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            return [llm_callback_handler, DifyStdOutCallbackHandler()]

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # this LLM is only used to count tokens for the memory buffer
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
        """
        Return the number of tokens left for context and memory once the base prompt and
        max_tokens are accounted for; raise LLMBadRequestError if the budget is already negative.
        """
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_name = app_model_config.model_dict.get("name")
        model_limited_tokens = llm_constant.max_context_token_length[model_name]
        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')

        # get prompt without memory and context
        prompt, _ = cls.get_main_llm_prompt(
            mode=mode,
            llm=llm,
            model=app_model_config.model_dict,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            agent_execute_result=None,
            memory=None
        )

        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
            else llm.get_num_tokens_from_messages(prompt)

        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long. You can shorten the prefix prompt, "
                                     "reduce max_tokens, or switch to an LLM with a larger context window.")

        return rest_tokens

    @classmethod
    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
                              prompt: Union[str, List[BaseMessage]], mode: str):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit
        model_name = model.get("name")
        model_limited_tokens = llm_constant.max_context_token_length[model_name]
        max_tokens = model.get("completion_params").get('max_tokens')

        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens
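
    # Worked example with assumed numbers (not from this file): if max_context_token_length
    # for the model is 4096 and completion_params set max_tokens=1024, a prompt measured at
    # 3500 tokens gives 3500 + 1024 > 4096, so max_tokens is lowered to
    # max(4096 - 3500, 16) = 596 before the final generate call; the floor of 16 keeps a
    # minimal completion budget even for very long prompts.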

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=app.tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        original_prompt, _ = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            model=app_model_config.model_dict,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            agent_execute_result=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=bool(message.override_model_configs),
            streaming=streaming
        )

        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=llm,
            model=app_model_config.model_dict,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])
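

# Usage sketch (illustrative only; `app`, `app_model_config`, `account`, and `message` are
# hypothetical instances the caller would load via the models layer, not objects defined here):
#
#   Completion.generate(
#       task_id=str(uuid.uuid4()),
#       app=app,                            # models.model.App
#       app_model_config=app_model_config,  # models.model.AppModelConfig
#       query="Summarize the latest release notes.",
#       inputs={},
#       user=account,                       # Account or EndUser
#       conversation=None,                  # None starts a new conversation
#       streaming=True,
#   )
#
# Regenerating an existing answer follows the same pattern through
# Completion.generate_more_like_this(...), passing the stored Message plus the app's
# pre_prompt and model config.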