# base_agent_runner.py

import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
    ToolParameter,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.model import Conversation, Message, MessageAgentThought, MessageFile

logger = logging.getLogger(__name__)


class BaseAgentRunner(AppRunner):
    """Shared base for agent app runners.

    Sets up dataset retrieval tools, reconstructs the agent conversation
    history, tracks persisted agent thoughts, and detects model features
    (streaming tool calls, vision) used when running the agent.
    """

    def __init__(
        self,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        model_instance: ModelInstance,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
    ) -> None:
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()

        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )

        # get how many agent thoughts have been created
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

        # check if model supports vision
        if model_schema and ModelFeature.VISION in (model_schema.features or []):
            self.files = application_generate_entity.files
        else:
            self.files = []

        self.query = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity
        """
        if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""

        return app_generate_entity

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        assert tool_entity.entity.description

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        parameters = tool_entity.get_merged_runtime_parameters()
        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue

            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            message_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                message_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                message_tool.parameters["required"].append(parameter.name)

        return message_tool, tool_entity
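
    # For reference, the `parameters` payload assembled above follows a JSON-Schema-like
    # shape used for function/tool calling. For a tool exposing one required string
    # parameter named "query", it would look roughly like this (illustrative values only):
    #
    #   {
    #       "type": "object",
    #       "properties": {
    #           "query": {"type": "string", "description": "the search query"},
    #       },
    #       "required": ["query"],
    #   }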

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool
        """
        assert tool.entity.description

        prompt_tool = PromptMessageTool(
            name=tool.entity.identity.name,
            description=tool.entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )

        for parameter in tool.get_runtime_parameters():
            parameter_type = "string"

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
        """
        Init tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools or [] if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.entity.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools
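
    # Minimal sketch of how a concrete runner might consume the pair returned by
    # `_init_prompt_tools` (the loop below is illustrative and not part of this base class):
    #
    #   tool_instances, prompt_messages_tools = self._init_prompt_tools()
    #   for prompt_tool in prompt_messages_tools:
    #       if prompt_tool.name in tool_instances:
    #           self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)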

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool
        """
        # try to get tool runtime parameters
        tool_runtime_parameters = tool.get_runtime_parameters() or []

        for parameter in tool_runtime_parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in {
                ToolParameter.ToolParameterType.SYSTEM_FILES,
                ToolParameter.ToolParameterType.FILE,
                ToolParameter.ToolParameterType.FILES,
            }:
                continue

            enum = []
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]

            prompt_tool.parameters["properties"][parameter.name] = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }

            if len(enum) > 0:
                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

            if parameter.required:
                if parameter.name not in prompt_tool.parameters["required"]:
                    prompt_tool.parameters["required"].append(parameter.name)

        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: str | None,
        tool_input: Union[str, dict, None],
        thought: str | None,
        observation: Union[str, dict, None],
        tool_invoke_meta: Union[str, dict, None],
        answer: str | None,
        messages_ids: list[str],
        llm_usage: LLMUsage | None = None,
    ):
        """
        Save agent thought
        """
        updated_agent_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        if not updated_agent_thought:
            raise ValueError("agent thought not found")

        if thought is not None:
            updated_agent_thought.thought = thought

        if tool_name is not None:
            updated_agent_thought.tool = tool_name

        if tool_input is not None:
            if isinstance(tool_input, dict):
                try:
                    tool_input = json.dumps(tool_input, ensure_ascii=False)
                except Exception:
                    tool_input = json.dumps(tool_input)

            updated_agent_thought.tool_input = tool_input

        if observation is not None:
            if isinstance(observation, dict):
                try:
                    observation = json.dumps(observation, ensure_ascii=False)
                except Exception:
                    observation = json.dumps(observation)

            updated_agent_thought.observation = observation

        if answer is not None:
            updated_agent_thought.answer = answer

        if messages_ids is not None and len(messages_ids) > 0:
            updated_agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            updated_agent_thought.message_token = llm_usage.prompt_tokens
            updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
            updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
            updated_agent_thought.answer_token = llm_usage.completion_tokens
            updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
            updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
            updated_agent_thought.tokens = llm_usage.total_tokens
            updated_agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = updated_agent_thought.tool_labels or {}
        tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
        for tool in tools:
            if not tool:
                continue
            if tool not in labels:
                tool_label = ToolManager.get_tool_label(tool)
                if tool_label:
                    labels[tool] = tool_label.to_dict()
                else:
                    labels[tool] = {"en_US": tool, "zh_Hans": tool}

        updated_agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            if isinstance(tool_invoke_meta, dict):
                try:
                    tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
                except Exception:
                    tool_invoke_meta = json.dumps(tool_invoke_meta)

            updated_agent_thought.tool_meta_str = tool_invoke_meta

        db.session.commit()
        db.session.close()
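
    # A concrete runner would typically pair `create_agent_thought` and `save_agent_thought`
    # once per reasoning step; the sketch below is illustrative, and the argument values
    # depend entirely on the subclass:
    #
    #   agent_thought = self.create_agent_thought(
    #       message_id=self.message.id, message="", tool_name="", tool_input="", messages_ids=[]
    #   )
    #   # ... invoke the model and/or tool ...
    #   self.save_agent_thought(
    #       agent_thought=agent_thought, tool_name=tool_name, tool_input=tool_input,
    #       thought=thought, observation=observation, tool_invoke_meta=meta,
    #       answer=answer, messages_ids=[], llm_usage=usage,
    #   )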

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )
        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception:
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception:
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result
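
    # For a past turn that invoked a tool, the reconstruction above contributes roughly
    # this sequence to the history (shape only, not literal output):
    #
    #   UserPromptMessage(content=<user query>)
    #   AssistantPromptMessage(content=<thought>, tool_calls=[ToolCall(...)])
    #   ToolPromptMessage(content=<observation>, name=<tool>, tool_call_id=<uuid>)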

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if files:
            assert message.app_model_config
            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())

            if file_extra_config:
                file_objs = file_factory.build_from_message_files(
                    message_files=files, tenant_id=self.tenant_id, config=file_extra_config
                )
            else:
                file_objs = []

            if not file_objs:
                return UserPromptMessage(content=message.query)
            else:
                prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
                for file_obj in file_objs:
                    prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))

                return UserPromptMessage(content=prompt_message_contents)
        else:
            return UserPromptMessage(content=message.query)