# base_agent_runner.py
  1. import json
  2. import logging
  3. import uuid
  4. from collections.abc import Mapping, Sequence
  5. from typing import Optional, Union, cast
  6. from core.agent.entities import AgentEntity, AgentToolEntity
  7. from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
  8. from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
  9. from core.app.apps.base_app_queue_manager import AppQueueManager
  10. from core.app.apps.base_app_runner import AppRunner
  11. from core.app.entities.app_invoke_entities import (
  12. AgentChatAppGenerateEntity,
  13. ModelConfigWithCredentialsEntity,
  14. )
  15. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  16. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  17. from core.file.message_file_parser import MessageFileParser
  18. from core.memory.token_buffer_memory import TokenBufferMemory
  19. from core.model_manager import ModelInstance
  20. from core.model_runtime.entities.llm_entities import LLMUsage
  21. from core.model_runtime.entities.message_entities import (
  22. AssistantPromptMessage,
  23. PromptMessage,
  24. PromptMessageContent,
  25. PromptMessageTool,
  26. SystemPromptMessage,
  27. TextPromptMessageContent,
  28. ToolPromptMessage,
  29. UserPromptMessage,
  30. )
  31. from core.model_runtime.entities.model_entities import ModelFeature
  32. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  33. from core.tools.__base.tool import Tool
  34. from core.tools.entities.tool_entities import (
  35. ToolParameter,
  36. )
  37. from core.tools.tool_manager import ToolManager
  38. from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
  39. from core.tools.utils.tool_parameter_converter import ToolParameterConverter
  40. from extensions.ext_database import db
  41. from models.model import Conversation, Message, MessageAgentThought
  42. logger = logging.getLogger(__name__)
  43. class BaseAgentRunner(AppRunner):
def __init__(
    self,
    tenant_id: str,
    application_generate_entity: AgentChatAppGenerateEntity,
    conversation: Conversation,
    app_config: AgentChatAppConfig,
    model_config: ModelConfigWithCredentialsEntity,
    config: AgentEntity,
    queue_manager: AppQueueManager,
    message: Message,
    user_id: str,
    model_instance: ModelInstance,
    memory: Optional[TokenBufferMemory] = None,
    prompt_messages: Optional[list[PromptMessage]] = None,
) -> None:
    """
    Base agent runner.

    Stores the per-run context, builds the agent conversation history,
    creates the dataset-retrieval tools, counts previously persisted agent
    thoughts for this message, and probes the model schema for
    stream-tool-call and vision support.

    :param tenant_id: tenant id
    :param application_generate_entity: application generate entity
    :param conversation: conversation record
    :param app_config: agent-chat app config
    :param model_config: model config with credentials
    :param config: agent config entity
    :param queue_manager: queue manager used by callbacks to publish events
    :param message: current message record
    :param user_id: user id
    :param model_instance: model instance
    :param memory: token buffer memory (optional)
    :param prompt_messages: prompt messages used to seed the agent history (optional)
    """
    self.tenant_id = tenant_id
    self.application_generate_entity = application_generate_entity
    self.conversation = conversation
    self.app_config = app_config
    self.model_config = model_config
    self.config = config
    self.queue_manager = queue_manager
    self.message = message
    self.user_id = user_id
    self.memory = memory
    # Rebuild the assistant/tool message history for this conversation up front.
    self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
    self.model_instance = model_instance

    # init callback
    self.agent_callback = DifyAgentCallbackHandler()
    # init dataset tools: hit_callback reports dataset retrievals back to the app queue
    hit_callback = DatasetIndexToolCallbackHandler(
        queue_manager=queue_manager,
        app_id=self.app_config.app_id,
        message_id=message.id,
        user_id=user_id,
        invoke_from=self.application_generate_entity.invoke_from,
    )
    self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
        tenant_id=tenant_id,
        dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
        retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
        return_resource=app_config.additional_features.show_retrieve_source,
        invoke_from=application_generate_entity.invoke_from,
        hit_callback=hit_callback,
    )
    # get how many agent thoughts have been created for this message already,
    # so new thoughts continue the position sequence
    self.agent_thought_count = (
        db.session.query(MessageAgentThought)
        .filter(
            MessageAgentThought.message_id == self.message.id,
        )
        .count()
    )
    # release the DB connection back to the pool; later accesses reopen a session
    db.session.close()

    # check if model supports stream tool call
    llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
    model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
    if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
        self.stream_tool_call = True
    else:
        self.stream_tool_call = False

    # check if model supports vision; only then are uploaded files passed through
    if model_schema and ModelFeature.VISION in (model_schema.features or []):
        self.files = application_generate_entity.files
    else:
        self.files = []
    self.query = None
    # prompt messages produced during the current agent iteration loop
    self._current_thoughts: list[PromptMessage] = []
  129. def _repack_app_generate_entity(
  130. self, app_generate_entity: AgentChatAppGenerateEntity
  131. ) -> AgentChatAppGenerateEntity:
  132. """
  133. Repack app generate entity
  134. """
  135. if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
  136. app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
  137. return app_generate_entity
  138. def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
  139. """
  140. convert tool to prompt message tool
  141. """
  142. tool_entity = ToolManager.get_agent_tool_runtime(
  143. tenant_id=self.tenant_id,
  144. app_id=self.app_config.app_id,
  145. agent_tool=tool,
  146. invoke_from=self.application_generate_entity.invoke_from,
  147. )
  148. assert tool_entity.entity.description
  149. message_tool = PromptMessageTool(
  150. name=tool.tool_name,
  151. description=tool_entity.entity.description.llm,
  152. parameters={
  153. "type": "object",
  154. "properties": {},
  155. "required": [],
  156. },
  157. )
  158. parameters = tool_entity.get_merged_runtime_parameters()
  159. for parameter in parameters:
  160. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  161. continue
  162. parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
  163. enum = []
  164. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  165. enum = [option.value for option in parameter.options]
  166. message_tool.parameters["properties"][parameter.name] = {
  167. "type": parameter_type,
  168. "description": parameter.llm_description or "",
  169. }
  170. if len(enum) > 0:
  171. message_tool.parameters["properties"][parameter.name]["enum"] = enum
  172. if parameter.required:
  173. message_tool.parameters["required"].append(parameter.name)
  174. return message_tool, tool_entity
  175. def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
  176. """
  177. convert dataset retriever tool to prompt message tool
  178. """
  179. assert tool.entity.description
  180. prompt_tool = PromptMessageTool(
  181. name=tool.entity.identity.name,
  182. description=tool.entity.description.llm,
  183. parameters={
  184. "type": "object",
  185. "properties": {},
  186. "required": [],
  187. },
  188. )
  189. for parameter in tool.get_runtime_parameters():
  190. parameter_type = "string"
  191. prompt_tool.parameters["properties"][parameter.name] = {
  192. "type": parameter_type,
  193. "description": parameter.llm_description or "",
  194. }
  195. if parameter.required:
  196. if parameter.name not in prompt_tool.parameters["required"]:
  197. prompt_tool.parameters["required"].append(parameter.name)
  198. return prompt_tool
  199. def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
  200. """
  201. Init tools
  202. """
  203. tool_instances = {}
  204. prompt_messages_tools = []
  205. for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
  206. try:
  207. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  208. except Exception:
  209. # api tool may be deleted
  210. continue
  211. # save tool entity
  212. tool_instances[tool.tool_name] = tool_entity
  213. # save prompt tool
  214. prompt_messages_tools.append(prompt_tool)
  215. # convert dataset tools into ModelRuntime Tool format
  216. for dataset_tool in self.dataset_tools:
  217. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  218. # save prompt tool
  219. prompt_messages_tools.append(prompt_tool)
  220. # save tool entity
  221. tool_instances[dataset_tool.entity.identity.name] = dataset_tool
  222. return tool_instances, prompt_messages_tools
  223. def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
  224. """
  225. update prompt message tool
  226. """
  227. # try to get tool runtime parameters
  228. tool_runtime_parameters = tool.get_runtime_parameters() or []
  229. for parameter in tool_runtime_parameters:
  230. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  231. continue
  232. parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
  233. enum = []
  234. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  235. enum = [option.value for option in parameter.options]
  236. prompt_tool.parameters["properties"][parameter.name] = {
  237. "type": parameter_type,
  238. "description": parameter.llm_description or "",
  239. }
  240. if len(enum) > 0:
  241. prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
  242. if parameter.required:
  243. if parameter.name not in prompt_tool.parameters["required"]:
  244. prompt_tool.parameters["required"].append(parameter.name)
  245. return prompt_tool
  246. def create_agent_thought(
  247. self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
  248. ) -> MessageAgentThought:
  249. """
  250. Create agent thought
  251. """
  252. thought = MessageAgentThought(
  253. message_id=message_id,
  254. message_chain_id=None,
  255. thought="",
  256. tool=tool_name,
  257. tool_labels_str="{}",
  258. tool_meta_str="{}",
  259. tool_input=tool_input,
  260. message=message,
  261. message_token=0,
  262. message_unit_price=0,
  263. message_price_unit=0,
  264. message_files=json.dumps(messages_ids) if messages_ids else "",
  265. answer="",
  266. observation="",
  267. answer_token=0,
  268. answer_unit_price=0,
  269. answer_price_unit=0,
  270. tokens=0,
  271. total_price=0,
  272. position=self.agent_thought_count + 1,
  273. currency="USD",
  274. latency=0,
  275. created_by_role="account",
  276. created_by=self.user_id,
  277. )
  278. db.session.add(thought)
  279. db.session.commit()
  280. db.session.refresh(thought)
  281. db.session.close()
  282. self.agent_thought_count += 1
  283. return thought
  284. def save_agent_thought(
  285. self,
  286. agent_thought: MessageAgentThought,
  287. tool_name: str | None,
  288. tool_input: Union[str, dict, None],
  289. thought: str | None,
  290. observation: Union[str, dict, None],
  291. tool_invoke_meta: Union[str, dict, None],
  292. answer: str | None,
  293. messages_ids: list[str],
  294. llm_usage: LLMUsage | None = None,
  295. ):
  296. """
  297. Save agent thought
  298. """
  299. updated_agent_thought = (
  300. db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
  301. )
  302. if not updated_agent_thought:
  303. raise ValueError("agent thought not found")
  304. if thought is not None:
  305. updated_agent_thought.thought = thought
  306. if tool_name is not None:
  307. updated_agent_thought.tool = tool_name
  308. if tool_input is not None:
  309. if isinstance(tool_input, dict):
  310. try:
  311. tool_input = json.dumps(tool_input, ensure_ascii=False)
  312. except Exception as e:
  313. tool_input = json.dumps(tool_input)
  314. updated_agent_thought.tool_input = tool_input
  315. if observation is not None:
  316. if isinstance(observation, dict):
  317. try:
  318. observation = json.dumps(observation, ensure_ascii=False)
  319. except Exception as e:
  320. observation = json.dumps(observation)
  321. updated_agent_thought.observation = observation
  322. if answer is not None:
  323. updated_agent_thought.answer = answer
  324. if messages_ids is not None and len(messages_ids) > 0:
  325. updated_agent_thought.message_files = json.dumps(messages_ids)
  326. if llm_usage:
  327. updated_agent_thought.message_token = llm_usage.prompt_tokens
  328. updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
  329. updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
  330. updated_agent_thought.answer_token = llm_usage.completion_tokens
  331. updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
  332. updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
  333. updated_agent_thought.tokens = llm_usage.total_tokens
  334. updated_agent_thought.total_price = llm_usage.total_price
  335. # check if tool labels is not empty
  336. labels = updated_agent_thought.tool_labels or {}
  337. tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
  338. for tool in tools:
  339. if not tool:
  340. continue
  341. if tool not in labels:
  342. tool_label = ToolManager.get_tool_label(tool)
  343. if tool_label:
  344. labels[tool] = tool_label.to_dict()
  345. else:
  346. labels[tool] = {"en_US": tool, "zh_Hans": tool}
  347. updated_agent_thought.tool_labels_str = json.dumps(labels)
  348. if tool_invoke_meta is not None:
  349. if isinstance(tool_invoke_meta, dict):
  350. try:
  351. tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
  352. except Exception as e:
  353. tool_invoke_meta = json.dumps(tool_invoke_meta)
  354. updated_agent_thought.tool_meta_str = tool_invoke_meta
  355. db.session.commit()
  356. db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
    """
    Rebuild the agent conversation history as prompt messages.

    Keeps any SystemPromptMessage from the supplied prompt messages, then
    replays every earlier message of this conversation: the user prompt,
    followed by one AssistantPromptMessage (with synthesized tool calls) and
    the matching ToolPromptMessage responses per agent thought. Messages
    without agent thoughts contribute only their plain answer.

    :param prompt_messages: current prompt messages; only system messages are carried over
    :return: prompt messages representing the conversation history
    """
    result = []
    # check if there is a system message in the beginning of the conversation
    for prompt_message in prompt_messages:
        if isinstance(prompt_message, SystemPromptMessage):
            result.append(prompt_message)

    # all messages of this conversation, oldest first
    messages: list[Message] = (
        db.session.query(Message)
        .filter(
            Message.conversation_id == self.message.conversation_id,
        )
        .order_by(Message.created_at.asc())
        .all()
    )

    for message in messages:
        # the current message is not history
        if message.id == self.message.id:
            continue

        result.append(self.organize_agent_user_prompt(message))
        agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
        if agent_thoughts:
            for agent_thought in agent_thoughts:
                # tool names are stored as a ";"-separated string on the thought
                tools = agent_thought.tool
                if tools:
                    tools = tools.split(";")
                    tool_calls: list[AssistantPromptMessage.ToolCall] = []
                    tool_call_response: list[ToolPromptMessage] = []
                    # tool_input is expected to be a JSON object keyed by tool name;
                    # fall back to empty args per tool when it doesn't parse
                    try:
                        tool_inputs = json.loads(agent_thought.tool_input)
                    except Exception as e:
                        tool_inputs = {tool: {} for tool in tools}
                    # observation likewise; fall back to the raw observation for each tool
                    try:
                        tool_responses = json.loads(agent_thought.observation)
                    except Exception as e:
                        tool_responses = dict.fromkeys(tools, agent_thought.observation)

                    for tool in tools:
                        # generate a uuid for tool call (original ids were not persisted)
                        tool_call_id = str(uuid.uuid4())
                        tool_calls.append(
                            AssistantPromptMessage.ToolCall(
                                id=tool_call_id,
                                type="function",
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool,
                                    arguments=json.dumps(tool_inputs.get(tool, {})),
                                ),
                            )
                        )
                        tool_call_response.append(
                            ToolPromptMessage(
                                content=tool_responses.get(tool, agent_thought.observation),
                                name=tool,
                                tool_call_id=tool_call_id,
                            )
                        )

                    # assistant message with its tool calls, then the tool responses
                    result.extend(
                        [
                            AssistantPromptMessage(
                                content=agent_thought.thought,
                                tool_calls=tool_calls,
                            ),
                            *tool_call_response,
                        ]
                    )
                if not tools:
                    # thought without any tool invocation: plain assistant text
                    result.append(AssistantPromptMessage(content=agent_thought.thought))
        else:
            if message.answer:
                result.append(AssistantPromptMessage(content=message.answer))

    db.session.close()

    return result
  430. def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
  431. message_file_parser = MessageFileParser(
  432. tenant_id=self.tenant_id,
  433. app_id=self.app_config.app_id,
  434. )
  435. files = message.message_files
  436. if files:
  437. assert message.app_model_config
  438. file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
  439. if file_extra_config:
  440. file_objs = message_file_parser.transform_message_files(files, file_extra_config)
  441. else:
  442. file_objs = []
  443. if not file_objs:
  444. return UserPromptMessage(content=message.query)
  445. else:
  446. prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
  447. for file_obj in file_objs:
  448. prompt_message_contents.append(file_obj.prompt_message_content)
  449. return UserPromptMessage(content=prompt_message_contents)
  450. else:
  451. return UserPromptMessage(content=message.query)