# base_agent_runner.py
  1. import json
  2. import logging
  3. import uuid
  4. from datetime import datetime
  5. from typing import Optional, Union, cast
  6. from core.agent.entities import AgentEntity, AgentToolEntity
  7. from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
  8. from core.app.apps.base_app_queue_manager import AppQueueManager
  9. from core.app.apps.base_app_runner import AppRunner
  10. from core.app.entities.app_invoke_entities import (
  11. AgentChatAppGenerateEntity,
  12. ModelConfigWithCredentialsEntity,
  13. )
  14. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  15. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  16. from core.memory.token_buffer_memory import TokenBufferMemory
  17. from core.model_manager import ModelInstance
  18. from core.model_runtime.entities.llm_entities import LLMUsage
  19. from core.model_runtime.entities.message_entities import (
  20. AssistantPromptMessage,
  21. PromptMessage,
  22. PromptMessageTool,
  23. SystemPromptMessage,
  24. ToolPromptMessage,
  25. UserPromptMessage,
  26. )
  27. from core.model_runtime.entities.model_entities import ModelFeature
  28. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  29. from core.model_runtime.utils.encoders import jsonable_encoder
  30. from core.tools.entities.tool_entities import (
  31. ToolInvokeMessage,
  32. ToolParameter,
  33. ToolRuntimeVariablePool,
  34. )
  35. from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
  36. from core.tools.tool.tool import Tool
  37. from core.tools.tool_manager import ToolManager
  38. from extensions.ext_database import db
  39. from models.model import Message, MessageAgentThought
  40. from models.tools import ToolConversationVariables
  41. logger = logging.getLogger(__name__)
  42. class BaseAgentRunner(AppRunner):
  43. def __init__(self, tenant_id: str,
  44. application_generate_entity: AgentChatAppGenerateEntity,
  45. app_config: AgentChatAppConfig,
  46. model_config: ModelConfigWithCredentialsEntity,
  47. config: AgentEntity,
  48. queue_manager: AppQueueManager,
  49. message: Message,
  50. user_id: str,
  51. memory: Optional[TokenBufferMemory] = None,
  52. prompt_messages: Optional[list[PromptMessage]] = None,
  53. variables_pool: Optional[ToolRuntimeVariablePool] = None,
  54. db_variables: Optional[ToolConversationVariables] = None,
  55. model_instance: ModelInstance = None
  56. ) -> None:
  57. """
  58. Agent runner
  59. :param tenant_id: tenant id
  60. :param app_config: app generate entity
  61. :param model_config: model config
  62. :param config: dataset config
  63. :param queue_manager: queue manager
  64. :param message: message
  65. :param user_id: user id
  66. :param agent_llm_callback: agent llm callback
  67. :param callback: callback
  68. :param memory: memory
  69. """
  70. self.tenant_id = tenant_id
  71. self.application_generate_entity = application_generate_entity
  72. self.app_config = app_config
  73. self.model_config = model_config
  74. self.config = config
  75. self.queue_manager = queue_manager
  76. self.message = message
  77. self.user_id = user_id
  78. self.memory = memory
  79. self.history_prompt_messages = self.organize_agent_history(
  80. prompt_messages=prompt_messages or []
  81. )
  82. self.variables_pool = variables_pool
  83. self.db_variables_pool = db_variables
  84. self.model_instance = model_instance
  85. # init callback
  86. self.agent_callback = DifyAgentCallbackHandler()
  87. # init dataset tools
  88. hit_callback = DatasetIndexToolCallbackHandler(
  89. queue_manager=queue_manager,
  90. app_id=self.app_config.app_id,
  91. message_id=message.id,
  92. user_id=user_id,
  93. invoke_from=self.application_generate_entity.invoke_from,
  94. )
  95. self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
  96. tenant_id=tenant_id,
  97. dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
  98. retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
  99. return_resource=app_config.additional_features.show_retrieve_source,
  100. invoke_from=application_generate_entity.invoke_from,
  101. hit_callback=hit_callback
  102. )
  103. # get how many agent thoughts have been created
  104. self.agent_thought_count = db.session.query(MessageAgentThought).filter(
  105. MessageAgentThought.message_id == self.message.id,
  106. ).count()
  107. db.session.close()
  108. # check if model supports stream tool call
  109. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  110. model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
  111. if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
  112. self.stream_tool_call = True
  113. else:
  114. self.stream_tool_call = False
  115. def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
  116. -> AgentChatAppGenerateEntity:
  117. """
  118. Repack app generate entity
  119. """
  120. if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
  121. app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
  122. return app_generate_entity
  123. def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
  124. """
  125. Handle tool response
  126. """
  127. result = ''
  128. for response in tool_response:
  129. if response.type == ToolInvokeMessage.MessageType.TEXT:
  130. result += response.message
  131. elif response.type == ToolInvokeMessage.MessageType.LINK:
  132. result += f"result link: {response.message}. please tell user to check it."
  133. elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
  134. response.type == ToolInvokeMessage.MessageType.IMAGE:
  135. result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
  136. else:
  137. result += f"tool response: {response.message}."
  138. return result
  139. def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
  140. """
  141. convert tool to prompt message tool
  142. """
  143. tool_entity = ToolManager.get_agent_tool_runtime(
  144. tenant_id=self.tenant_id,
  145. agent_tool=tool,
  146. )
  147. tool_entity.load_variables(self.variables_pool)
  148. message_tool = PromptMessageTool(
  149. name=tool.tool_name,
  150. description=tool_entity.description.llm,
  151. parameters={
  152. "type": "object",
  153. "properties": {},
  154. "required": [],
  155. }
  156. )
  157. parameters = tool_entity.get_all_runtime_parameters()
  158. for parameter in parameters:
  159. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  160. continue
  161. parameter_type = 'string'
  162. enum = []
  163. if parameter.type == ToolParameter.ToolParameterType.STRING:
  164. parameter_type = 'string'
  165. elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
  166. parameter_type = 'boolean'
  167. elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
  168. parameter_type = 'number'
  169. elif parameter.type == ToolParameter.ToolParameterType.SELECT:
  170. for option in parameter.options:
  171. enum.append(option.value)
  172. parameter_type = 'string'
  173. else:
  174. raise ValueError(f"parameter type {parameter.type} is not supported")
  175. message_tool.parameters['properties'][parameter.name] = {
  176. "type": parameter_type,
  177. "description": parameter.llm_description or '',
  178. }
  179. if len(enum) > 0:
  180. message_tool.parameters['properties'][parameter.name]['enum'] = enum
  181. if parameter.required:
  182. message_tool.parameters['required'].append(parameter.name)
  183. return message_tool, tool_entity
  184. def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
  185. """
  186. convert dataset retriever tool to prompt message tool
  187. """
  188. prompt_tool = PromptMessageTool(
  189. name=tool.identity.name,
  190. description=tool.description.llm,
  191. parameters={
  192. "type": "object",
  193. "properties": {},
  194. "required": [],
  195. }
  196. )
  197. for parameter in tool.get_runtime_parameters():
  198. parameter_type = 'string'
  199. prompt_tool.parameters['properties'][parameter.name] = {
  200. "type": parameter_type,
  201. "description": parameter.llm_description or '',
  202. }
  203. if parameter.required:
  204. if parameter.name not in prompt_tool.parameters['required']:
  205. prompt_tool.parameters['required'].append(parameter.name)
  206. return prompt_tool
  207. def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
  208. """
  209. update prompt message tool
  210. """
  211. # try to get tool runtime parameters
  212. tool_runtime_parameters = tool.get_runtime_parameters() or []
  213. for parameter in tool_runtime_parameters:
  214. if parameter.form != ToolParameter.ToolParameterForm.LLM:
  215. continue
  216. parameter_type = 'string'
  217. enum = []
  218. if parameter.type == ToolParameter.ToolParameterType.STRING:
  219. parameter_type = 'string'
  220. elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
  221. parameter_type = 'boolean'
  222. elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
  223. parameter_type = 'number'
  224. elif parameter.type == ToolParameter.ToolParameterType.SELECT:
  225. for option in parameter.options:
  226. enum.append(option.value)
  227. parameter_type = 'string'
  228. else:
  229. raise ValueError(f"parameter type {parameter.type} is not supported")
  230. prompt_tool.parameters['properties'][parameter.name] = {
  231. "type": parameter_type,
  232. "description": parameter.llm_description or '',
  233. }
  234. if len(enum) > 0:
  235. prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
  236. if parameter.required:
  237. if parameter.name not in prompt_tool.parameters['required']:
  238. prompt_tool.parameters['required'].append(parameter.name)
  239. return prompt_tool
  240. def create_agent_thought(self, message_id: str, message: str,
  241. tool_name: str, tool_input: str, messages_ids: list[str]
  242. ) -> MessageAgentThought:
  243. """
  244. Create agent thought
  245. """
  246. thought = MessageAgentThought(
  247. message_id=message_id,
  248. message_chain_id=None,
  249. thought='',
  250. tool=tool_name,
  251. tool_labels_str='{}',
  252. tool_meta_str='{}',
  253. tool_input=tool_input,
  254. message=message,
  255. message_token=0,
  256. message_unit_price=0,
  257. message_price_unit=0,
  258. message_files=json.dumps(messages_ids) if messages_ids else '',
  259. answer='',
  260. observation='',
  261. answer_token=0,
  262. answer_unit_price=0,
  263. answer_price_unit=0,
  264. tokens=0,
  265. total_price=0,
  266. position=self.agent_thought_count + 1,
  267. currency='USD',
  268. latency=0,
  269. created_by_role='account',
  270. created_by=self.user_id,
  271. )
  272. db.session.add(thought)
  273. db.session.commit()
  274. db.session.refresh(thought)
  275. db.session.close()
  276. self.agent_thought_count += 1
  277. return thought
  278. def save_agent_thought(self,
  279. agent_thought: MessageAgentThought,
  280. tool_name: str,
  281. tool_input: Union[str, dict],
  282. thought: str,
  283. observation: Union[str, str],
  284. tool_invoke_meta: Union[str, dict],
  285. answer: str,
  286. messages_ids: list[str],
  287. llm_usage: LLMUsage = None) -> MessageAgentThought:
  288. """
  289. Save agent thought
  290. """
  291. agent_thought = db.session.query(MessageAgentThought).filter(
  292. MessageAgentThought.id == agent_thought.id
  293. ).first()
  294. if thought is not None:
  295. agent_thought.thought = thought
  296. if tool_name is not None:
  297. agent_thought.tool = tool_name
  298. if tool_input is not None:
  299. if isinstance(tool_input, dict):
  300. try:
  301. tool_input = json.dumps(tool_input, ensure_ascii=False)
  302. except Exception as e:
  303. tool_input = json.dumps(tool_input)
  304. agent_thought.tool_input = tool_input
  305. if observation is not None:
  306. if isinstance(observation, dict):
  307. try:
  308. observation = json.dumps(observation, ensure_ascii=False)
  309. except Exception as e:
  310. observation = json.dumps(observation)
  311. agent_thought.observation = observation
  312. if answer is not None:
  313. agent_thought.answer = answer
  314. if messages_ids is not None and len(messages_ids) > 0:
  315. agent_thought.message_files = json.dumps(messages_ids)
  316. if llm_usage:
  317. agent_thought.message_token = llm_usage.prompt_tokens
  318. agent_thought.message_price_unit = llm_usage.prompt_price_unit
  319. agent_thought.message_unit_price = llm_usage.prompt_unit_price
  320. agent_thought.answer_token = llm_usage.completion_tokens
  321. agent_thought.answer_price_unit = llm_usage.completion_price_unit
  322. agent_thought.answer_unit_price = llm_usage.completion_unit_price
  323. agent_thought.tokens = llm_usage.total_tokens
  324. agent_thought.total_price = llm_usage.total_price
  325. # check if tool labels is not empty
  326. labels = agent_thought.tool_labels or {}
  327. tools = agent_thought.tool.split(';') if agent_thought.tool else []
  328. for tool in tools:
  329. if not tool:
  330. continue
  331. if tool not in labels:
  332. tool_label = ToolManager.get_tool_label(tool)
  333. if tool_label:
  334. labels[tool] = tool_label.to_dict()
  335. else:
  336. labels[tool] = {'en_US': tool, 'zh_Hans': tool}
  337. agent_thought.tool_labels_str = json.dumps(labels)
  338. if tool_invoke_meta is not None:
  339. if isinstance(tool_invoke_meta, dict):
  340. try:
  341. tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
  342. except Exception as e:
  343. tool_invoke_meta = json.dumps(tool_invoke_meta)
  344. agent_thought.tool_meta_str = tool_invoke_meta
  345. db.session.commit()
  346. db.session.close()
  347. def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
  348. """
  349. convert tool variables to db variables
  350. """
  351. db_variables = db.session.query(ToolConversationVariables).filter(
  352. ToolConversationVariables.conversation_id == self.message.conversation_id,
  353. ).first()
  354. db_variables.updated_at = datetime.utcnow()
  355. db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
  356. db.session.commit()
  357. db.session.close()
  358. def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
  359. """
  360. Organize agent history
  361. """
  362. result = []
  363. # check if there is a system message in the beginning of the conversation
  364. if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
  365. result.append(prompt_messages[0])
  366. messages: list[Message] = db.session.query(Message).filter(
  367. Message.conversation_id == self.message.conversation_id,
  368. ).order_by(Message.created_at.asc()).all()
  369. for message in messages:
  370. result.append(UserPromptMessage(content=message.query))
  371. agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
  372. if agent_thoughts:
  373. for agent_thought in agent_thoughts:
  374. tools = agent_thought.tool
  375. if tools:
  376. tools = tools.split(';')
  377. tool_calls: list[AssistantPromptMessage.ToolCall] = []
  378. tool_call_response: list[ToolPromptMessage] = []
  379. try:
  380. tool_inputs = json.loads(agent_thought.tool_input)
  381. except Exception as e:
  382. tool_inputs = { tool: {} for tool in tools }
  383. try:
  384. tool_responses = json.loads(agent_thought.observation)
  385. except Exception as e:
  386. tool_responses = { tool: agent_thought.observation for tool in tools }
  387. for tool in tools:
  388. # generate a uuid for tool call
  389. tool_call_id = str(uuid.uuid4())
  390. tool_calls.append(AssistantPromptMessage.ToolCall(
  391. id=tool_call_id,
  392. type='function',
  393. function=AssistantPromptMessage.ToolCall.ToolCallFunction(
  394. name=tool,
  395. arguments=json.dumps(tool_inputs.get(tool, {})),
  396. )
  397. ))
  398. tool_call_response.append(ToolPromptMessage(
  399. content=tool_responses.get(tool, agent_thought.observation),
  400. name=tool,
  401. tool_call_id=tool_call_id,
  402. ))
  403. result.extend([
  404. AssistantPromptMessage(
  405. content=agent_thought.thought,
  406. tool_calls=tool_calls,
  407. ),
  408. *tool_call_response
  409. ])
  410. if not tools:
  411. result.append(AssistantPromptMessage(content=agent_thought.thought))
  412. else:
  413. if message.answer:
  414. result.append(AssistantPromptMessage(content=message.answer))
  415. db.session.close()
  416. return result