# assistant_app_runner.py
  1. import logging
  2. from typing import cast
  3. from core.app_runner.app_runner import AppRunner
  4. from core.application_queue_manager import ApplicationQueueManager, PublishFrom
  5. from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
  6. from core.features.assistant_cot_runner import AssistantCotApplicationRunner
  7. from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
  8. from core.memory.token_buffer_memory import TokenBufferMemory
  9. from core.model_manager import ModelInstance
  10. from core.model_runtime.entities.llm_entities import LLMUsage
  11. from core.model_runtime.entities.model_entities import ModelFeature
  12. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  13. from core.moderation.base import ModerationException
  14. from core.tools.entities.tool_entities import ToolRuntimeVariablePool
  15. from extensions.ext_database import db
  16. from models.model import App, Conversation, Message, MessageAgentThought
  17. from models.tools import ToolConversationVariables
  18. logger = logging.getLogger(__name__)
  19. class AssistantApplicationRunner(AppRunner):
  20. """
  21. Assistant Application Runner
  22. """
  23. def run(self, application_generate_entity: ApplicationGenerateEntity,
  24. queue_manager: ApplicationQueueManager,
  25. conversation: Conversation,
  26. message: Message) -> None:
  27. """
  28. Run assistant application
  29. :param application_generate_entity: application generate entity
  30. :param queue_manager: application queue manager
  31. :param conversation: conversation
  32. :param message: message
  33. :return:
  34. """
  35. app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
  36. if not app_record:
  37. raise ValueError("App not found")
  38. app_orchestration_config = application_generate_entity.app_orchestration_config_entity
  39. inputs = application_generate_entity.inputs
  40. query = application_generate_entity.query
  41. files = application_generate_entity.files
  42. # Pre-calculate the number of tokens of the prompt messages,
  43. # and return the rest number of tokens by model context token size limit and max token size limit.
  44. # If the rest number of tokens is not enough, raise exception.
  45. # Include: prompt template, inputs, query(optional), files(optional)
  46. # Not Include: memory, external data, dataset context
  47. self.get_pre_calculate_rest_tokens(
  48. app_record=app_record,
  49. model_config=app_orchestration_config.model_config,
  50. prompt_template_entity=app_orchestration_config.prompt_template,
  51. inputs=inputs,
  52. files=files,
  53. query=query
  54. )
  55. memory = None
  56. if application_generate_entity.conversation_id:
  57. # get memory of conversation (read-only)
  58. model_instance = ModelInstance(
  59. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  60. model=app_orchestration_config.model_config.model
  61. )
  62. memory = TokenBufferMemory(
  63. conversation=conversation,
  64. model_instance=model_instance
  65. )
  66. # organize all inputs and template to prompt messages
  67. # Include: prompt template, inputs, query(optional), files(optional)
  68. # memory(optional)
  69. prompt_messages, _ = self.organize_prompt_messages(
  70. app_record=app_record,
  71. model_config=app_orchestration_config.model_config,
  72. prompt_template_entity=app_orchestration_config.prompt_template,
  73. inputs=inputs,
  74. files=files,
  75. query=query,
  76. memory=memory
  77. )
  78. # moderation
  79. try:
  80. # process sensitive_word_avoidance
  81. _, inputs, query = self.moderation_for_inputs(
  82. app_id=app_record.id,
  83. tenant_id=application_generate_entity.tenant_id,
  84. app_orchestration_config_entity=app_orchestration_config,
  85. inputs=inputs,
  86. query=query,
  87. )
  88. except ModerationException as e:
  89. self.direct_output(
  90. queue_manager=queue_manager,
  91. app_orchestration_config=app_orchestration_config,
  92. prompt_messages=prompt_messages,
  93. text=str(e),
  94. stream=application_generate_entity.stream
  95. )
  96. return
  97. if query:
  98. # annotation reply
  99. annotation_reply = self.query_app_annotations_to_reply(
  100. app_record=app_record,
  101. message=message,
  102. query=query,
  103. user_id=application_generate_entity.user_id,
  104. invoke_from=application_generate_entity.invoke_from
  105. )
  106. if annotation_reply:
  107. queue_manager.publish_annotation_reply(
  108. message_annotation_id=annotation_reply.id,
  109. pub_from=PublishFrom.APPLICATION_MANAGER
  110. )
  111. self.direct_output(
  112. queue_manager=queue_manager,
  113. app_orchestration_config=app_orchestration_config,
  114. prompt_messages=prompt_messages,
  115. text=annotation_reply.content,
  116. stream=application_generate_entity.stream
  117. )
  118. return
  119. # fill in variable inputs from external data tools if exists
  120. external_data_tools = app_orchestration_config.external_data_variables
  121. if external_data_tools:
  122. inputs = self.fill_in_inputs_from_external_data_tools(
  123. tenant_id=app_record.tenant_id,
  124. app_id=app_record.id,
  125. external_data_tools=external_data_tools,
  126. inputs=inputs,
  127. query=query
  128. )
  129. # reorganize all inputs and template to prompt messages
  130. # Include: prompt template, inputs, query(optional), files(optional)
  131. # memory(optional), external data, dataset context(optional)
  132. prompt_messages, _ = self.organize_prompt_messages(
  133. app_record=app_record,
  134. model_config=app_orchestration_config.model_config,
  135. prompt_template_entity=app_orchestration_config.prompt_template,
  136. inputs=inputs,
  137. files=files,
  138. query=query,
  139. memory=memory
  140. )
  141. # check hosting moderation
  142. hosting_moderation_result = self.check_hosting_moderation(
  143. application_generate_entity=application_generate_entity,
  144. queue_manager=queue_manager,
  145. prompt_messages=prompt_messages
  146. )
  147. if hosting_moderation_result:
  148. return
  149. agent_entity = app_orchestration_config.agent
  150. # load tool variables
  151. tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
  152. user_id=application_generate_entity.user_id,
  153. tenant_id=application_generate_entity.tenant_id)
  154. # convert db variables to tool variables
  155. tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
  156. # init model instance
  157. model_instance = ModelInstance(
  158. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  159. model=app_orchestration_config.model_config.model
  160. )
  161. prompt_message, _ = self.organize_prompt_messages(
  162. app_record=app_record,
  163. model_config=app_orchestration_config.model_config,
  164. prompt_template_entity=app_orchestration_config.prompt_template,
  165. inputs=inputs,
  166. files=files,
  167. query=query,
  168. memory=memory,
  169. )
  170. # change function call strategy based on LLM model
  171. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  172. model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
  173. if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
  174. agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
  175. db.session.refresh(conversation)
  176. db.session.refresh(message)
  177. db.session.close()
  178. # start agent runner
  179. if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
  180. assistant_cot_runner = AssistantCotApplicationRunner(
  181. tenant_id=application_generate_entity.tenant_id,
  182. application_generate_entity=application_generate_entity,
  183. app_orchestration_config=app_orchestration_config,
  184. model_config=app_orchestration_config.model_config,
  185. config=agent_entity,
  186. queue_manager=queue_manager,
  187. message=message,
  188. user_id=application_generate_entity.user_id,
  189. memory=memory,
  190. prompt_messages=prompt_message,
  191. variables_pool=tool_variables,
  192. db_variables=tool_conversation_variables,
  193. model_instance=model_instance
  194. )
  195. invoke_result = assistant_cot_runner.run(
  196. conversation=conversation,
  197. message=message,
  198. query=query,
  199. inputs=inputs,
  200. )
  201. elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
  202. assistant_fc_runner = AssistantFunctionCallApplicationRunner(
  203. tenant_id=application_generate_entity.tenant_id,
  204. application_generate_entity=application_generate_entity,
  205. app_orchestration_config=app_orchestration_config,
  206. model_config=app_orchestration_config.model_config,
  207. config=agent_entity,
  208. queue_manager=queue_manager,
  209. message=message,
  210. user_id=application_generate_entity.user_id,
  211. memory=memory,
  212. prompt_messages=prompt_message,
  213. variables_pool=tool_variables,
  214. db_variables=tool_conversation_variables,
  215. model_instance=model_instance
  216. )
  217. invoke_result = assistant_fc_runner.run(
  218. conversation=conversation,
  219. message=message,
  220. query=query,
  221. )
  222. # handle invoke result
  223. self._handle_invoke_result(
  224. invoke_result=invoke_result,
  225. queue_manager=queue_manager,
  226. stream=application_generate_entity.stream,
  227. agent=True
  228. )
  229. def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
  230. """
  231. load tool variables from database
  232. """
  233. tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
  234. ToolConversationVariables.conversation_id == conversation_id,
  235. ToolConversationVariables.tenant_id == tenant_id
  236. ).first()
  237. if tool_variables:
  238. # save tool variables to session, so that we can update it later
  239. db.session.add(tool_variables)
  240. else:
  241. # create new tool variables
  242. tool_variables = ToolConversationVariables(
  243. conversation_id=conversation_id,
  244. user_id=user_id,
  245. tenant_id=tenant_id,
  246. variables_str='[]',
  247. )
  248. db.session.add(tool_variables)
  249. db.session.commit()
  250. return tool_variables
  251. def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
  252. """
  253. convert db variables to tool variables
  254. """
  255. return ToolRuntimeVariablePool(**{
  256. 'conversation_id': db_variables.conversation_id,
  257. 'user_id': db_variables.user_id,
  258. 'tenant_id': db_variables.tenant_id,
  259. 'pool': db_variables.variables
  260. })
  261. def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
  262. message: Message) -> LLMUsage:
  263. """
  264. Get usage of all agent thoughts
  265. :param model_config: model config
  266. :param message: message
  267. :return:
  268. """
  269. agent_thoughts = (db.session.query(MessageAgentThought)
  270. .filter(MessageAgentThought.message_id == message.id).all())
  271. all_message_tokens = 0
  272. all_answer_tokens = 0
  273. for agent_thought in agent_thoughts:
  274. all_message_tokens += agent_thought.message_tokens
  275. all_answer_tokens += agent_thought.answer_tokens
  276. model_type_instance = model_config.provider_model_bundle.model_type_instance
  277. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  278. return model_type_instance._calc_response_usage(
  279. model_config.model,
  280. model_config.credentials,
  281. all_message_tokens,
  282. all_answer_tokens
  283. )