# base_agent_runner.py
  1. import json
  2. import logging
  3. import uuid
  4. from collections.abc import Mapping, Sequence
  5. from datetime import datetime, timezone
  6. from typing import Optional, Union, cast
  7. from core.agent.entities import AgentEntity, AgentToolEntity
  8. from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
  9. from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
  10. from core.app.apps.base_app_queue_manager import AppQueueManager
  11. from core.app.apps.base_app_runner import AppRunner
  12. from core.app.entities.app_invoke_entities import (
  13. AgentChatAppGenerateEntity,
  14. ModelConfigWithCredentialsEntity,
  15. )
  16. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  17. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  18. from core.file import file_manager
  19. from core.memory.token_buffer_memory import TokenBufferMemory
  20. from core.model_manager import ModelInstance
  21. from core.model_runtime.entities import (
  22. AssistantPromptMessage,
  23. LLMUsage,
  24. PromptMessage,
  25. PromptMessageContent,
  26. PromptMessageTool,
  27. SystemPromptMessage,
  28. TextPromptMessageContent,
  29. ToolPromptMessage,
  30. UserPromptMessage,
  31. )
  32. from core.model_runtime.entities.model_entities import ModelFeature
  33. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  34. from core.model_runtime.utils.encoders import jsonable_encoder
  35. from core.prompt.utils.extract_thread_messages import extract_thread_messages
  36. from core.tools.entities.tool_entities import (
  37. ToolParameter,
  38. ToolRuntimeVariablePool,
  39. )
  40. from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
  41. from core.tools.tool.tool import Tool
  42. from core.tools.tool_manager import ToolManager
  43. from extensions.ext_database import db
  44. from factories import file_factory
  45. from models.model import Conversation, Message, MessageAgentThought, MessageFile
  46. from models.tools import ToolConversationVariables
  47. logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
    # Shared base for agent-mode app runners: wires up tool runtimes,
    # conversation history, callbacks, and agent-thought bookkeeping
    # for the concrete runner subclasses.

    def __init__(
        self,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
        variables_pool: Optional[ToolRuntimeVariablePool] = None,
        db_variables: Optional[ToolConversationVariables] = None,
        model_instance: Optional[ModelInstance] = None,
    ) -> None:
        """Initialize the agent runner for a single generate request.

        Args:
            tenant_id: owner tenant of the app.
            application_generate_entity: the generate request being served.
            conversation: conversation this message belongs to.
            app_config: agent-chat app configuration.
            model_config: model configuration with resolved credentials.
            config: agent configuration (strategy, tools, ...).
            queue_manager: queue used to publish run events.
            message: the message row being answered.
            user_id: id of the requesting user.
            memory: optional token-buffer memory.
            prompt_messages: preliminary prompt messages (e.g. system prompt)
                used to seed the organized history.
            variables_pool: runtime variable pool shared with tools.
            db_variables: persisted tool conversation variables.
            model_instance: model instance used for feature checks.
                NOTE(review): annotated Optional to match the None default,
                but it is dereferenced unconditionally below — callers must
                pass a real instance; confirm against call sites.
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        # replay prior turns of this conversation as prompt messages
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )
        # get how many agent thoughts have been created so far for this
        # message; new thoughts continue the position sequence from here
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # check if model supports stream tool call
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
            self.stream_tool_call = True
        else:
            self.stream_tool_call = False

        # check if model supports vision; only then are uploaded files
        # forwarded to the model
        if model_schema and ModelFeature.VISION in (model_schema.features or []):
            self.files = application_generate_entity.files
        else:
            self.files = []
        self.query = None
        # prompt messages accumulated during the current iteration loop
        self._current_thoughts: list[PromptMessage] = []
  121. def _repack_app_generate_entity(
  122. self, app_generate_entity: AgentChatAppGenerateEntity
  123. ) -> AgentChatAppGenerateEntity:
  124. """
  125. Repack app generate entity
  126. """
  127. if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
  128. app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
  129. return app_generate_entity
  130. def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
  131. """
  132. convert tool to prompt message tool
  133. """
  134. tool_entity = ToolManager.get_agent_tool_runtime(
  135. tenant_id=self.tenant_id,
  136. app_id=self.app_config.app_id,
  137. agent_tool=tool,
  138. invoke_from=self.application_generate_entity.invoke_from,
  139. )
  140. tool_entity.load_variables(self.variables_pool)
  141. message_tool = PromptMessageTool(
  142. name=tool.tool_name,
  143. description=tool_entity.description.llm,
  144. parameters={
  145. "type": "object",
  146. "properties": {},
  147. "required": [],
  148. },
  149. )
  150. parameters = tool_entity.get_all_runtime_parameters()
  151. for parameter in parameters:
  152. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  153. continue
  154. parameter_type = parameter.type.as_normal_type()
  155. if parameter.type in {
  156. ToolParameter.ToolParameterType.SYSTEM_FILES,
  157. ToolParameter.ToolParameterType.FILE,
  158. ToolParameter.ToolParameterType.FILES,
  159. }:
  160. continue
  161. enum = []
  162. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  163. enum = [option.value for option in parameter.options]
  164. message_tool.parameters["properties"][parameter.name] = {
  165. "type": parameter_type,
  166. "description": parameter.llm_description or "",
  167. }
  168. if len(enum) > 0:
  169. message_tool.parameters["properties"][parameter.name]["enum"] = enum
  170. if parameter.required:
  171. message_tool.parameters["required"].append(parameter.name)
  172. return message_tool, tool_entity
  173. def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
  174. """
  175. convert dataset retriever tool to prompt message tool
  176. """
  177. prompt_tool = PromptMessageTool(
  178. name=tool.identity.name,
  179. description=tool.description.llm,
  180. parameters={
  181. "type": "object",
  182. "properties": {},
  183. "required": [],
  184. },
  185. )
  186. for parameter in tool.get_runtime_parameters():
  187. parameter_type = "string"
  188. prompt_tool.parameters["properties"][parameter.name] = {
  189. "type": parameter_type,
  190. "description": parameter.llm_description or "",
  191. }
  192. if parameter.required:
  193. if parameter.name not in prompt_tool.parameters["required"]:
  194. prompt_tool.parameters["required"].append(parameter.name)
  195. return prompt_tool
  196. def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
  197. """
  198. Init tools
  199. """
  200. tool_instances = {}
  201. prompt_messages_tools = []
  202. for tool in self.app_config.agent.tools if self.app_config.agent else []:
  203. try:
  204. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  205. except Exception:
  206. # api tool may be deleted
  207. continue
  208. # save tool entity
  209. tool_instances[tool.tool_name] = tool_entity
  210. # save prompt tool
  211. prompt_messages_tools.append(prompt_tool)
  212. # convert dataset tools into ModelRuntime Tool format
  213. for dataset_tool in self.dataset_tools:
  214. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  215. # save prompt tool
  216. prompt_messages_tools.append(prompt_tool)
  217. # save tool entity
  218. tool_instances[dataset_tool.identity.name] = dataset_tool
  219. return tool_instances, prompt_messages_tools
  220. def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
  221. """
  222. update prompt message tool
  223. """
  224. # try to get tool runtime parameters
  225. tool_runtime_parameters = tool.get_runtime_parameters() or []
  226. for parameter in tool_runtime_parameters:
  227. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  228. continue
  229. parameter_type = parameter.type.as_normal_type()
  230. if parameter.type in {
  231. ToolParameter.ToolParameterType.SYSTEM_FILES,
  232. ToolParameter.ToolParameterType.FILE,
  233. ToolParameter.ToolParameterType.FILES,
  234. }:
  235. continue
  236. enum = []
  237. if parameter.type == ToolParameter.ToolParameterType.SELECT:
  238. enum = [option.value for option in parameter.options]
  239. prompt_tool.parameters["properties"][parameter.name] = {
  240. "type": parameter_type,
  241. "description": parameter.llm_description or "",
  242. }
  243. if len(enum) > 0:
  244. prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
  245. if parameter.required:
  246. if parameter.name not in prompt_tool.parameters["required"]:
  247. prompt_tool.parameters["required"].append(parameter.name)
  248. return prompt_tool
    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """Create and persist a new (mostly empty) agent thought row.

        Token counts, prices, answer and observation start zeroed/empty and
        are filled in later via save_agent_thought.

        Args:
            message_id: id of the message this thought belongs to.
            message: the prompt message content recorded on the thought.
            tool_name: name(s) of the tool(s) about to be invoked.
            tool_input: serialized tool input.
            messages_ids: ids of files attached to this thought, if any.

        Returns:
            The freshly committed and refreshed MessageAgentThought.
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            # 1-based position of this thought within the message
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )
        db.session.add(thought)
        db.session.commit()
        # refresh so generated fields (e.g. id) are populated before close
        db.session.refresh(thought)
        db.session.close()
        self.agent_thought_count += 1
        return thought
  287. def save_agent_thought(
  288. self,
  289. agent_thought: MessageAgentThought,
  290. tool_name: str,
  291. tool_input: Union[str, dict],
  292. thought: str,
  293. observation: Union[str, dict],
  294. tool_invoke_meta: Union[str, dict],
  295. answer: str,
  296. messages_ids: list[str],
  297. llm_usage: LLMUsage = None,
  298. ) -> MessageAgentThought:
  299. """
  300. Save agent thought
  301. """
  302. agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
  303. if thought is not None:
  304. agent_thought.thought = thought
  305. if tool_name is not None:
  306. agent_thought.tool = tool_name
  307. if tool_input is not None:
  308. if isinstance(tool_input, dict):
  309. try:
  310. tool_input = json.dumps(tool_input, ensure_ascii=False)
  311. except Exception as e:
  312. tool_input = json.dumps(tool_input)
  313. agent_thought.tool_input = tool_input
  314. if observation is not None:
  315. if isinstance(observation, dict):
  316. try:
  317. observation = json.dumps(observation, ensure_ascii=False)
  318. except Exception as e:
  319. observation = json.dumps(observation)
  320. agent_thought.observation = observation
  321. if answer is not None:
  322. agent_thought.answer = answer
  323. if messages_ids is not None and len(messages_ids) > 0:
  324. agent_thought.message_files = json.dumps(messages_ids)
  325. if llm_usage:
  326. agent_thought.message_token = llm_usage.prompt_tokens
  327. agent_thought.message_price_unit = llm_usage.prompt_price_unit
  328. agent_thought.message_unit_price = llm_usage.prompt_unit_price
  329. agent_thought.answer_token = llm_usage.completion_tokens
  330. agent_thought.answer_price_unit = llm_usage.completion_price_unit
  331. agent_thought.answer_unit_price = llm_usage.completion_unit_price
  332. agent_thought.tokens = llm_usage.total_tokens
  333. agent_thought.total_price = llm_usage.total_price
  334. # check if tool labels is not empty
  335. labels = agent_thought.tool_labels or {}
  336. tools = agent_thought.tool.split(";") if agent_thought.tool else []
  337. for tool in tools:
  338. if not tool:
  339. continue
  340. if tool not in labels:
  341. tool_label = ToolManager.get_tool_label(tool)
  342. if tool_label:
  343. labels[tool] = tool_label.to_dict()
  344. else:
  345. labels[tool] = {"en_US": tool, "zh_Hans": tool}
  346. agent_thought.tool_labels_str = json.dumps(labels)
  347. if tool_invoke_meta is not None:
  348. if isinstance(tool_invoke_meta, dict):
  349. try:
  350. tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
  351. except Exception as e:
  352. tool_invoke_meta = json.dumps(tool_invoke_meta)
  353. agent_thought.tool_meta_str = tool_invoke_meta
  354. db.session.commit()
  355. db.session.close()
  356. def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
  357. """
  358. convert tool variables to db variables
  359. """
  360. db_variables = (
  361. db.session.query(ToolConversationVariables)
  362. .filter(
  363. ToolConversationVariables.conversation_id == self.message.conversation_id,
  364. )
  365. .first()
  366. )
  367. db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
  368. db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
  369. db.session.commit()
  370. db.session.close()
    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """Rebuild this conversation's prior turns as prompt messages.

        System messages from prompt_messages are carried over first; then
        every earlier message in the thread contributes its user prompt and,
        when agent thoughts exist, reconstructed assistant tool-call /
        tool-response message pairs.

        Args:
            prompt_messages: preliminary prompt messages; only the
                SystemPromptMessage entries are kept.

        Returns:
            The organized history as a flat list of prompt messages.
        """
        result = []
        # check if there is a system message in the beginning of the conversation
        for prompt_message in prompt_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                result.append(prompt_message)

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )

        # restore chronological order within the current thread
        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            # skip the message currently being answered
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        # multiple tool names are stored ";"-separated
                        tools = tools.split(";")
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        try:
                            tool_inputs = json.loads(agent_thought.tool_input)
                        except Exception as e:
                            # unparseable input: fall back to empty args per tool
                            tool_inputs = {tool: {} for tool in tools}
                        try:
                            tool_responses = json.loads(agent_thought.observation)
                        except Exception as e:
                            # unparseable observation: reuse the raw text for every tool
                            tool_responses = dict.fromkeys(tools, agent_thought.observation)

                        for tool in tools:
                            # generate a uuid for tool call; original call ids
                            # were not persisted, only the pairing matters here
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        # assistant message carrying the calls, then the responses
                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        # thought without tool use: plain assistant text
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        db.session.close()

        return result
  445. def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
  446. files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
  447. if files:
  448. file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
  449. if file_extra_config:
  450. file_objs = file_factory.build_from_message_files(
  451. message_files=files, tenant_id=self.tenant_id, config=file_extra_config
  452. )
  453. else:
  454. file_objs = []
  455. if not file_objs:
  456. return UserPromptMessage(content=message.query)
  457. else:
  458. prompt_message_contents: list[PromptMessageContent] = []
  459. prompt_message_contents.append(TextPromptMessageContent(data=message.query))
  460. for file_obj in file_objs:
  461. prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
  462. return UserPromptMessage(content=prompt_message_contents)
  463. else:
  464. return UserPromptMessage(content=message.query)