completion.py

import logging
from typing import Optional, List, Union, Tuple

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, HumanMessage
from requests.exceptions import ChunkedEncodingError

from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
        """
        Generate a completion for the given query.

        errors: ProviderTokenNotInitError
        """
        query = PromptBuilder.process_template(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            tenant_id=app.tenant_id,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        # build the main chain, including the agent
        main_chain = MainChainBuilder.to_langchain_components(
            tenant_id=app.tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            rest_tokens=rest_tokens_for_context_and_memory,
            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
            conversation_message_task=conversation_message_task
        )

        chain_output = ''
        if main_chain:
            chain_output = main_chain.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                chain_output=chain_output,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            return
        except ChunkedEncodingError as e:
            # response interrupted by the LLM provider (e.g. OpenAI); end the task gracefully
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return

    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      chain_output: str,
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
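        """
        Build the final LLM from the app's model config, assemble the prompt (optionally including
        the chain output as context and the conversation memory), attach callbacks, shrink
        max_tokens so the request fits the model's context window, and run the generation.
        """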
        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt, stop_words = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=chain_output,
            memory=memory
        )

        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=final_llm,
            prompt=prompt,
            mode=mode
        )

        response = final_llm.generate([prompt], stop_words)

        return response

    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                            chain_output: Optional[str],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[Union[str, List[BaseMessage]], Optional[List[str]]]:
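        """
        Build the prompt for the final LLM call, together with the stop words to use with it.

        In 'completion' mode this returns a rendered prompt string (wrapped in a single HumanMessage
        for chat models) built from the optional chain-output context, the pre_prompt and the query.
        Otherwise it returns a list with one human message that may also include as much conversation
        history as the remaining context window allows, plus a stop word that cuts generation at the
        next simulated human turn.
        """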
        # disable template string in query
        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
        # if query_params:
        #     for query_param in query_params:
        #         if query_param not in inputs:
        #             inputs[query_param] = '{{' + query_param + '}}'

        if mode == 'completion':
            prompt_template = JinjaPromptTemplate.from_template(
                template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
{{context}}
[END CONTEXT]

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if chain_output else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{{query}}\n"
            )

            if chain_output:
                inputs['context'] = chain_output
                # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
                # if context_params:
                #     for context_param in context_params:
                #         if context_param not in inputs:
                #             inputs[context_param] = '{{' + context_param + '}}'

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)], None
            else:
                return prompt_content, None
        else:
            messages: List[BaseMessage] = []

            human_inputs = {
                "query": query
            }

            human_message_prompt = ""

            if pre_prompt:
                pre_prompt_inputs = {k: inputs[k] for k in
                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
                                     if k in inputs}

                if pre_prompt_inputs:
                    human_inputs.update(pre_prompt_inputs)

            if chain_output:
                human_inputs['context'] = chain_output
                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
{{context}}
[END CONTEXT]

When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""

            if pre_prompt:
                human_message_prompt += pre_prompt

            query_prompt = "\nHuman: {{query}}\nAI: "

            if memory:
                # append chat histories
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=human_message_prompt + query_prompt,
                    inputs=human_inputs
                )

                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
                    - memory.llm.max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                histories = cls.get_history_messages_from_memory(memory, rest_tokens)

                # disable template string in histories
                # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
                # if histories_params:
                #     for histories_param in histories_params:
                #         if histories_param not in human_inputs:
                #             human_inputs[histories_param] = '{{' + histories_param + '}}'

                human_message_prompt += "\n\n" + histories

            human_message_prompt += query_prompt

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content=human_message_prompt,
                inputs=human_inputs
            )

            messages.append(human_message)

            return messages, ['\nHuman:']

    @classmethod
    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                          streaming: bool,
                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            return [llm_callback_handler, DifyStdOutCallbackHandler()]

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
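        """
        Build a read-only token-buffer memory for the conversation. The LLM constructed here is
        only used for counting tokens; buffer limits and keys can be overridden via kwargs.
        """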
        # only used for counting tokens in memory
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
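        """
        Validate that the prompt (built without context and memory) still fits within the model's
        context window after reserving max_tokens for the answer, and return how many tokens are
        left for context and memory. Raises LLMBadRequestError when nothing is left.
        """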
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
        max_tokens = llm.max_tokens

        # get prompt without memory and context
        prompt, _ = cls.get_main_llm_prompt(
            mode=mode,
            llm=llm,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=None,
            memory=None
        )

        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
            else llm.get_num_tokens_from_messages(prompt)

        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                     "or shrink the max token, or switch to a llm with a larger token limit size.")

        return rest_tokens

    @classmethod
    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                              prompt: Union[str, List[BaseMessage]], mode: str):
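        """
        Recalculate final_llm.max_tokens: if prompt tokens plus the configured max_tokens would
        exceed the model's context length, shrink max_tokens to whatever space remains
        (but never below 16).
        """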
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
        max_tokens = final_llm.max_tokens

        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            prompt_tokens = final_llm.get_messages_tokens(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
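        """
        Generate an alternative completion for an existing message: rebuild the original prompt,
        wrap it together with the original answer in MORE_LIKE_THIS_GENERATE_PROMPT, and run it
        through gpt-3.5-turbo with the usual callbacks and max_tokens recalculation.
        """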
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=app.tenant_id,
            model_name='gpt-3.5-turbo',
            streaming=streaming
        )

        # get llm prompt
        original_prompt, _ = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming
        )

        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=llm,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])