# base_agent_runner.py

import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file.message_file_parser import MessageFileParser
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import (
    ToolInvokeMessage,
    ToolParameter,
    ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)

class BaseAgentRunner(AppRunner):
    def __init__(self, tenant_id: str,
                 application_generate_entity: AgentChatAppGenerateEntity,
                 conversation: Conversation,
                 app_config: AgentChatAppConfig,
                 model_config: ModelConfigWithCredentialsEntity,
                 config: AgentEntity,
                 queue_manager: AppQueueManager,
                 message: Message,
                 user_id: str,
                 memory: Optional[TokenBufferMemory] = None,
                 prompt_messages: Optional[list[PromptMessage]] = None,
                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
                 db_variables: Optional[ToolConversationVariables] = None,
                 model_instance: ModelInstance = None
                 ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param application_generate_entity: app generate entity
        :param conversation: conversation
        :param app_config: agent chat app config
        :param model_config: model config with credentials
        :param config: agent config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param memory: token buffer memory
        :param prompt_messages: existing prompt messages
        :param variables_pool: tool runtime variables pool
        :param db_variables: tool conversation variables stored in the database
        :param model_instance: model instance
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(
            prompt_messages=prompt_messages or []
        )
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.message_id == self.message.id,
        ).count()
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

        # check if model supports vision
        if model_schema and ModelFeature.VISION in (model_schema.features or []):
            self.files = application_generate_entity.files
        else:
            self.files = []

        self.query = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
            -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ''

        return app_generate_entity

    def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
        """
        Convert tool response messages into a single string for the model
        """
        result = ''
        for response in tool_response:
            if response.type == ToolInvokeMessage.MessageType.TEXT:
                result += response.message
            elif response.type == ToolInvokeMessage.MessageType.LINK:
                result += f"result link: {response.message}. please tell user to check it."
            elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                    response.type == ToolInvokeMessage.MessageType.IMAGE:
                result += "image has been created and sent to user already, " \
                          "you do not need to create it, just tell the user to check it now."
            else:
                result += f"tool response: {response.message}."

        return result
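
    # Illustrative sketch, not part of the original module: the pieces above are joined
    # with no separator, so a TEXT message followed by a LINK message would roughly yield
    #   "Sunny, 24Cresult link: https://example.invalid/report. please tell user to check it."
    # (the weather text and URL are made-up values for illustration only).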

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        Convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        parameters = tool_entity.get_all_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            message_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if len(enum) > 0:
                message_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                message_tool.parameters['required'].append(parameter.name)

        return message_tool, tool_entity
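
    # Illustrative sketch, not part of the original module: the parameters dict assembled
    # above follows the JSON-Schema-style function-calling shape. For a hypothetical
    # weather tool with a required "city" string and a SELECT parameter "unit", it would
    # look roughly like:
    #
    #   {
    #       "type": "object",
    #       "properties": {
    #           "city": {"type": "string", "description": "City to look up"},
    #           "unit": {"type": "string", "description": "Temperature unit",
    #                    "enum": ["celsius", "fahrenheit"]},
    #       },
    #       "required": ["city"],
    #   }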

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        Convert dataset retriever tool to prompt message tool
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = 'string'

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Init tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools
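
    # Illustrative usage sketch (assumed caller, not part of this module): a concrete
    # agent runner would typically rebuild the tool list once per iteration, hand the
    # PromptMessageTool list to the LLM, and resolve the model's tool call back to a
    # Tool instance by name, e.g.:
    #
    #   tool_instances, prompt_messages_tools = self._init_prompt_tools()
    #   # ... model emits a tool call named `tool_call_name` (hypothetical variable) ...
    #   tool_instance = tool_instances.get(tool_call_name)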

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        Update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            prompt_tool.parameters['properties'][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or '',
            }

            if len(enum) > 0:
                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters['required']:
                    prompt_tool.parameters['required'].append(parameter.name)

        return prompt_tool

    def create_agent_thought(self, message_id: str, message: str,
                             tool_name: str, tool_input: str, messages_ids: list[str]
                             ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought='',
            tool=tool_name,
            tool_labels_str='{}',
            tool_meta_str='{}',
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else '',
            answer='',
            observation='',
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency='USD',
            latency=0,
            created_by_role='account',
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(self,
                           agent_thought: MessageAgentThought,
                           tool_name: str,
                           tool_input: Union[str, dict],
                           thought: str,
                           observation: Union[str, dict],
                           tool_invoke_meta: Union[str, dict],
                           answer: str,
                           messages_ids: list[str],
                           llm_usage: Optional[LLMUsage] = None) -> None:
        """
        Save agent thought
        """
        agent_thought = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.id == agent_thought.id
        ).first()

        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            agent_thought.tool_input = tool_input

        if observation is not None:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            agent_thought.observation = observation

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(';') if agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {'en_US': tool, 'zh_Hans': tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        Convert tool runtime variables to db variables
        """
        db_variables = db.session.query(ToolConversationVariables).filter(
            ToolConversationVariables.conversation_id == self.message.conversation_id,
        ).first()

        db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = db.session.query(Message).filter(
            Message.conversation_id == self.message.conversation_id,
        ).order_by(Message.created_at.asc()).all()

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(';')
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            tool_responses = {tool: agent_thought.observation for tool in tools}

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(AssistantPromptMessage.ToolCall(
                                id=tool_call_id,
                                type='function',
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool,
                                    arguments=json.dumps(tool_inputs.get(tool, {})),
                                )
                            ))
                            tool_call_response.append(ToolPromptMessage(
                                content=tool_responses.get(tool, agent_thought.observation),
                                name=tool,
                                tool_call_id=tool_call_id,
                            ))

                        result.extend([
                            AssistantPromptMessage(
                                content=agent_thought.thought,
                                tool_calls=tool_calls,
                            ),
                            *tool_call_response
                        ])
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result
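
    # Illustrative sketch with made-up data, not part of the original module: for an
    # earlier message whose first agent thought called a "weather" tool and whose final
    # thought used no tool, the rebuilt history would read roughly as:
    #
    #   UserPromptMessage(content="What's the weather in Paris?")
    #   AssistantPromptMessage(content="<thought>", tool_calls=[<ToolCall name="weather">])
    #   ToolPromptMessage(content="<observation>", name="weather", tool_call_id="<uuid>")
    #   AssistantPromptMessage(content="<final thought>")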

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        message_file_parser = MessageFileParser(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
        )

        files = message.message_files
        if files:
            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

            if file_extra_config:
                file_objs = message_file_parser.transform_message_files(
                    files,
                    file_extra_config
                )
            else:
                file_objs = []

            if not file_objs:
                return UserPromptMessage(content=message.query)
            else:
                prompt_message_contents = [TextPromptMessageContent(data=message.query)]
                for file_obj in file_objs:
                    prompt_message_contents.append(file_obj.prompt_message_content)

                return UserPromptMessage(content=prompt_message_contents)
        else:
            return UserPromptMessage(content=message.query)