# assistant_app_runner.py
  1. import json
  2. import logging
  3. from typing import cast
  4. from core.app_runner.app_runner import AppRunner
  5. from core.features.assistant_cot_runner import AssistantCotApplicationRunner
  6. from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
  7. from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
  8. AgentEntity
  9. from core.application_queue_manager import ApplicationQueueManager, PublishFrom
  10. from core.memory.token_buffer_memory import TokenBufferMemory
  11. from core.model_manager import ModelInstance
  12. from core.model_runtime.entities.llm_entities import LLMUsage
  13. from core.model_runtime.entities.model_entities import ModelFeature
  14. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  15. from core.moderation.base import ModerationException
  16. from core.tools.entities.tool_entities import ToolRuntimeVariablePool
  17. from extensions.ext_database import db
  18. from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
  19. from models.tools import ToolConversationVariables
  20. logger = logging.getLogger(__name__)
  21. class AssistantApplicationRunner(AppRunner):
  22. """
  23. Assistant Application Runner
  24. """
  25. def run(self, application_generate_entity: ApplicationGenerateEntity,
  26. queue_manager: ApplicationQueueManager,
  27. conversation: Conversation,
  28. message: Message) -> None:
  29. """
  30. Run assistant application
  31. :param application_generate_entity: application generate entity
  32. :param queue_manager: application queue manager
  33. :param conversation: conversation
  34. :param message: message
  35. :return:
  36. """
  37. app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
  38. if not app_record:
  39. raise ValueError(f"App not found")
  40. app_orchestration_config = application_generate_entity.app_orchestration_config_entity
  41. inputs = application_generate_entity.inputs
  42. query = application_generate_entity.query
  43. files = application_generate_entity.files
  44. # Pre-calculate the number of tokens of the prompt messages,
  45. # and return the rest number of tokens by model context token size limit and max token size limit.
  46. # If the rest number of tokens is not enough, raise exception.
  47. # Include: prompt template, inputs, query(optional), files(optional)
  48. # Not Include: memory, external data, dataset context
  49. self.get_pre_calculate_rest_tokens(
  50. app_record=app_record,
  51. model_config=app_orchestration_config.model_config,
  52. prompt_template_entity=app_orchestration_config.prompt_template,
  53. inputs=inputs,
  54. files=files,
  55. query=query
  56. )
  57. memory = None
  58. if application_generate_entity.conversation_id:
  59. # get memory of conversation (read-only)
  60. model_instance = ModelInstance(
  61. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  62. model=app_orchestration_config.model_config.model
  63. )
  64. memory = TokenBufferMemory(
  65. conversation=conversation,
  66. model_instance=model_instance
  67. )
  68. # organize all inputs and template to prompt messages
  69. # Include: prompt template, inputs, query(optional), files(optional)
  70. # memory(optional)
  71. prompt_messages, _ = self.organize_prompt_messages(
  72. app_record=app_record,
  73. model_config=app_orchestration_config.model_config,
  74. prompt_template_entity=app_orchestration_config.prompt_template,
  75. inputs=inputs,
  76. files=files,
  77. query=query,
  78. memory=memory
  79. )
  80. # moderation
  81. try:
  82. # process sensitive_word_avoidance
  83. _, inputs, query = self.moderation_for_inputs(
  84. app_id=app_record.id,
  85. tenant_id=application_generate_entity.tenant_id,
  86. app_orchestration_config_entity=app_orchestration_config,
  87. inputs=inputs,
  88. query=query,
  89. )
  90. except ModerationException as e:
  91. self.direct_output(
  92. queue_manager=queue_manager,
  93. app_orchestration_config=app_orchestration_config,
  94. prompt_messages=prompt_messages,
  95. text=str(e),
  96. stream=application_generate_entity.stream
  97. )
  98. return
  99. if query:
  100. # annotation reply
  101. annotation_reply = self.query_app_annotations_to_reply(
  102. app_record=app_record,
  103. message=message,
  104. query=query,
  105. user_id=application_generate_entity.user_id,
  106. invoke_from=application_generate_entity.invoke_from
  107. )
  108. if annotation_reply:
  109. queue_manager.publish_annotation_reply(
  110. message_annotation_id=annotation_reply.id,
  111. pub_from=PublishFrom.APPLICATION_MANAGER
  112. )
  113. self.direct_output(
  114. queue_manager=queue_manager,
  115. app_orchestration_config=app_orchestration_config,
  116. prompt_messages=prompt_messages,
  117. text=annotation_reply.content,
  118. stream=application_generate_entity.stream
  119. )
  120. return
  121. # fill in variable inputs from external data tools if exists
  122. external_data_tools = app_orchestration_config.external_data_variables
  123. if external_data_tools:
  124. inputs = self.fill_in_inputs_from_external_data_tools(
  125. tenant_id=app_record.tenant_id,
  126. app_id=app_record.id,
  127. external_data_tools=external_data_tools,
  128. inputs=inputs,
  129. query=query
  130. )
  131. # reorganize all inputs and template to prompt messages
  132. # Include: prompt template, inputs, query(optional), files(optional)
  133. # memory(optional), external data, dataset context(optional)
  134. prompt_messages, _ = self.organize_prompt_messages(
  135. app_record=app_record,
  136. model_config=app_orchestration_config.model_config,
  137. prompt_template_entity=app_orchestration_config.prompt_template,
  138. inputs=inputs,
  139. files=files,
  140. query=query,
  141. memory=memory
  142. )
  143. # check hosting moderation
  144. hosting_moderation_result = self.check_hosting_moderation(
  145. application_generate_entity=application_generate_entity,
  146. queue_manager=queue_manager,
  147. prompt_messages=prompt_messages
  148. )
  149. if hosting_moderation_result:
  150. return
  151. agent_entity = app_orchestration_config.agent
  152. # load tool variables
  153. tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
  154. user_id=application_generate_entity.user_id,
  155. tenant_id=application_generate_entity.tenant_id)
  156. # convert db variables to tool variables
  157. tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
  158. message_chain = self._init_message_chain(
  159. message=message,
  160. query=query
  161. )
  162. # init model instance
  163. model_instance = ModelInstance(
  164. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  165. model=app_orchestration_config.model_config.model
  166. )
  167. prompt_message, _ = self.organize_prompt_messages(
  168. app_record=app_record,
  169. model_config=app_orchestration_config.model_config,
  170. prompt_template_entity=app_orchestration_config.prompt_template,
  171. inputs=inputs,
  172. files=files,
  173. query=query,
  174. memory=memory,
  175. )
  176. # change function call strategy based on LLM model
  177. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  178. model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
  179. if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
  180. agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
  181. # start agent runner
  182. if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
  183. assistant_cot_runner = AssistantCotApplicationRunner(
  184. tenant_id=application_generate_entity.tenant_id,
  185. application_generate_entity=application_generate_entity,
  186. app_orchestration_config=app_orchestration_config,
  187. model_config=app_orchestration_config.model_config,
  188. config=agent_entity,
  189. queue_manager=queue_manager,
  190. message=message,
  191. user_id=application_generate_entity.user_id,
  192. memory=memory,
  193. prompt_messages=prompt_message,
  194. variables_pool=tool_variables,
  195. db_variables=tool_conversation_variables,
  196. model_instance=model_instance
  197. )
  198. invoke_result = assistant_cot_runner.run(
  199. conversation=conversation,
  200. message=message,
  201. query=query,
  202. )
  203. elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
  204. assistant_fc_runner = AssistantFunctionCallApplicationRunner(
  205. tenant_id=application_generate_entity.tenant_id,
  206. application_generate_entity=application_generate_entity,
  207. app_orchestration_config=app_orchestration_config,
  208. model_config=app_orchestration_config.model_config,
  209. config=agent_entity,
  210. queue_manager=queue_manager,
  211. message=message,
  212. user_id=application_generate_entity.user_id,
  213. memory=memory,
  214. prompt_messages=prompt_message,
  215. variables_pool=tool_variables,
  216. db_variables=tool_conversation_variables,
  217. model_instance=model_instance
  218. )
  219. invoke_result = assistant_fc_runner.run(
  220. conversation=conversation,
  221. message=message,
  222. query=query,
  223. )
  224. # handle invoke result
  225. self._handle_invoke_result(
  226. invoke_result=invoke_result,
  227. queue_manager=queue_manager,
  228. stream=application_generate_entity.stream,
  229. agent=True
  230. )
  231. def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
  232. """
  233. load tool variables from database
  234. """
  235. tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
  236. ToolConversationVariables.conversation_id == conversation_id,
  237. ToolConversationVariables.tenant_id == tenant_id
  238. ).first()
  239. if tool_variables:
  240. # save tool variables to session, so that we can update it later
  241. db.session.add(tool_variables)
  242. else:
  243. # create new tool variables
  244. tool_variables = ToolConversationVariables(
  245. conversation_id=conversation_id,
  246. user_id=user_id,
  247. tenant_id=tenant_id,
  248. variables_str='[]',
  249. )
  250. db.session.add(tool_variables)
  251. db.session.commit()
  252. return tool_variables
  253. def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
  254. """
  255. convert db variables to tool variables
  256. """
  257. return ToolRuntimeVariablePool(**{
  258. 'conversation_id': db_variables.conversation_id,
  259. 'user_id': db_variables.user_id,
  260. 'tenant_id': db_variables.tenant_id,
  261. 'pool': db_variables.variables
  262. })
  263. def _init_message_chain(self, message: Message, query: str) -> MessageChain:
  264. """
  265. Init MessageChain
  266. :param message: message
  267. :param query: query
  268. :return:
  269. """
  270. message_chain = MessageChain(
  271. message_id=message.id,
  272. type="AgentExecutor",
  273. input=json.dumps({
  274. "input": query
  275. })
  276. )
  277. db.session.add(message_chain)
  278. db.session.commit()
  279. return message_chain
  280. def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
  281. """
  282. Save MessageChain
  283. :param message_chain: message chain
  284. :param output_text: output text
  285. :return:
  286. """
  287. message_chain.output = json.dumps({
  288. "output": output_text
  289. })
  290. db.session.commit()
  291. def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
  292. message: Message) -> LLMUsage:
  293. """
  294. Get usage of all agent thoughts
  295. :param model_config: model config
  296. :param message: message
  297. :return:
  298. """
  299. agent_thoughts = (db.session.query(MessageAgentThought)
  300. .filter(MessageAgentThought.message_id == message.id).all())
  301. all_message_tokens = 0
  302. all_answer_tokens = 0
  303. for agent_thought in agent_thoughts:
  304. all_message_tokens += agent_thought.message_tokens
  305. all_answer_tokens += agent_thought.answer_tokens
  306. model_type_instance = model_config.provider_model_bundle.model_type_instance
  307. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  308. return model_type_instance._calc_response_usage(
  309. model_config.model,
  310. model_config.credentials,
  311. all_message_tokens,
  312. all_answer_tokens
  313. )