from typing import Optional, List, Union, Tuple

from langchain.callbacks import CallbackManager
from langchain.chat_models.base import BaseChatModel
from langchain.llms import BaseLLM
from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage

from core.constant import llm_constant
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.error import LLMBadRequestError
from core.llm.llm_builder import LLMBuilder
from core.chain.main_chain_builder import MainChainBuilder
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
    ReadOnlyConversationTokenDBStringBufferSharedMemory
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.model import App, AppModelConfig, Account, Conversation, Message


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
        """
        errors: ProviderTokenNotInitError
        """
        cls.validate_query_tokens(app.tenant_id, app_model_config, query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming
        )

        # build the main chain, including the agent
        main_chain = MainChainBuilder.to_langchain_components(
            tenant_id=app.tenant_id,
            agent_mode=app_model_config.agent_mode_dict,
            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
            conversation_message_task=conversation_message_task
        )

        chain_output = ''
        if main_chain:
            chain_output = main_chain.run(query)

        # run the final llm
        try:
            cls.run_final_llm(
                tenant_id=app.tenant_id,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                chain_output=chain_output,
                conversation_message_task=conversation_message_task,
                memory=memory,
                streaming=streaming
            )
        except ConversationTaskStoppedException:
            return

    @classmethod
    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                      chain_output: str,
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
        final_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict,
            streaming=streaming
        )

        # get llm prompt
        prompt, stop_words = cls.get_main_llm_prompt(
            mode=mode,
            llm=final_llm,
            pre_prompt=app_model_config.pre_prompt,
            query=query,
            inputs=inputs,
            chain_output=chain_output,
            memory=memory
        )

        final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=final_llm,
            prompt=prompt,
            mode=mode
        )

        response = final_llm.generate([prompt], stop_words)

        return response

    @classmethod
    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
                            chain_output: Optional[str],
                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
            Tuple[Union[str, List[BaseMessage]], Optional[List[str]]]:
        # disable template string in query
        query_params = OutLinePromptTemplate.from_template(template=query).input_variables
        if query_params:
            for query_param in query_params:
                if query_param not in inputs:
                    inputs[query_param] = '{' + query_param + '}'

        pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
        if mode == 'completion':
            prompt_template = OutLinePromptTemplate.from_template(
                template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
{context}
[END CONTEXT]
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
""" if chain_output else "")
                         + (pre_prompt + "\n" if pre_prompt else "")
                         + "{query}\n"
            )

            if chain_output:
                inputs['context'] = chain_output
                context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
                if context_params:
                    for context_param in context_params:
                        if context_param not in inputs:
                            inputs[context_param] = '{' + context_param + '}'

            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            prompt_content = prompt_template.format(
                query=query,
                **prompt_inputs
            )

            if isinstance(llm, BaseChatModel):
                # use chat llm as completion model
                return [HumanMessage(content=prompt_content)], None
            else:
                return prompt_content, None
        else:
            messages: List[BaseMessage] = []

            human_inputs = {
                "query": query
            }

            human_message_prompt = ""

            if pre_prompt:
                pre_prompt_inputs = {k: inputs[k] for k in
                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
                                     if k in inputs}

                if pre_prompt_inputs:
                    human_inputs.update(pre_prompt_inputs)

            if chain_output:
                human_inputs['context'] = chain_output
                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
{context}
[END CONTEXT]
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
"""

            if pre_prompt:
                human_message_prompt += pre_prompt

            query_prompt = "\nHuman: {query}\nAI: "

            if memory:
                # append chat histories
                tmp_human_message = PromptBuilder.to_human_message(
                    prompt_content=human_message_prompt + query_prompt,
                    inputs=human_inputs
                )

                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
                              - memory.llm.max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
                histories = cls.get_history_messages_from_memory(memory, rest_tokens)

                # disable template string in query
                histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
                if histories_params:
                    for histories_param in histories_params:
                        if histories_param not in human_inputs:
                            human_inputs[histories_param] = '{' + histories_param + '}'

                human_message_prompt += "\n\n" + histories

            human_message_prompt += query_prompt

            # construct main prompt
            human_message = PromptBuilder.to_human_message(
                prompt_content=human_message_prompt,
                inputs=human_inputs
            )

            messages.append(human_message)

            return messages, ['\nHuman:']
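
    # For reference, a rough sketch of the chat-mode prompt assembled above
    # (illustrative layout only; the exact wording comes from the string literals
    # in get_main_llm_prompt):
    #
    #   Use the following CONTEXT as your learned knowledge.
    #   [CONTEXT]
    #   <chain_output>
    #   [END CONTEXT]
    #   ...answering instructions...
    #   <pre_prompt>
    #
    #   <trimmed chat histories>
    #   Human: <query>
    #   AI: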

    @classmethod
    def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                                 streaming: bool,
                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
        if streaming:
            callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
        else:
            callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()]

        return CallbackManager(callback_handlers)

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # this llm is used only for token counting in the memory buffer
        memory_llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            llm=memory_llm,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str):
        llm = LLMBuilder.to_llm_from_model(
            tenant_id=tenant_id,
            model=app_model_config.model_dict
        )

        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
        max_tokens = llm.max_tokens

        if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0:
            raise LLMBadRequestError("Query is too long")

    @classmethod
    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                              prompt: Union[str, List[BaseMessage]], mode: str):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit
        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
        max_tokens = final_llm.max_tokens

        if mode == 'completion' and isinstance(final_llm, BaseLLM):
            prompt_tokens = final_llm.get_num_tokens(prompt)
        else:
            prompt_tokens = final_llm.get_messages_tokens(prompt)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)
            final_llm.max_tokens = max_tokens
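
    # Worked example with hypothetical numbers: for a 4,096-token context window,
    # a 3,700-token prompt and max_tokens=512, 3,700 + 512 > 4,096, so max_tokens
    # is lowered to max(4,096 - 3,700, 16) = 396 before generation.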

    @classmethod
    def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                app_model_config: AppModelConfig, user: Account, streaming: bool):
        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=app.tenant_id,
            model_name='gpt-3.5-turbo',
            streaming=streaming
        )

        # get llm prompt
        original_prompt, _ = cls.get_main_llm_prompt(
            mode="completion",
            llm=llm,
            pre_prompt=pre_prompt,
            query=message.query,
            inputs=message.inputs,
            chain_output=None,
            memory=None
        )

        original_completion = message.answer.strip()

        prompt = MORE_LIKE_THIS_GENERATE_PROMPT
        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            inputs=message.inputs,
            query=message.query,
            is_override=True if message.override_model_configs else False,
            streaming=streaming
        )

        llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task)

        cls.recale_llm_max_tokens(
            final_llm=llm,
            prompt=prompt,
            mode='completion'
        )

        llm.generate([prompt])
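

# Rough usage sketch (hypothetical names; assumes the App, AppModelConfig and
# Account rows are already loaded from the database):
#
#   Completion.generate(
#       task_id=str(uuid.uuid4()),
#       app=app,
#       app_model_config=app_model_config,
#       query="Summarize the last meeting notes.",
#       inputs={},
#       user=current_user,
#       conversation=None,
#       streaming=True,
#   )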