fc_agent_runner.py

import json
import logging
from collections.abc import Generator
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message, MessageAgentThought

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        app_generate_entity = self.application_generate_entity
        app_config = self.app_config

        prompt_template = app_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self.organize_prompt_messages(
            prompt_template=prompt_template,
            query=query,
            prompt_messages=prompt_messages
        )

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in app_config.agent.tools if app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # the API tool may have been deleted; skip it
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset retriever tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool
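
        # After the two loops above, `prompt_messages_tools` holds the tool schemas the
        # model sees, while `tool_instances` maps each tool name to the entity that is
        # invoked when the model emits a matching call, e.g. (hypothetical tool name):
        #   prompt_messages_tools[0].name -> 'weather'
        #   tool_instances['weather']     -> entity passed to ToolEngine.agent_invoke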

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1  # +1 for a final tool-free round

        # keep looping until the model stops emitting tool calls
        function_call_state = True
        agent_thoughts: list[MessageAgentThought] = []
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                # accumulate onto the usage object stored on the first call
                total_usage = final_llm_usage_dict['usage']
                total_usage.prompt_tokens += usage.prompt_tokens
                total_usage.completion_tokens += usage.completion_tokens
                total_usage.prompt_price += usage.prompt_price
                total_usage.completion_price += usage.completion_price
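
        # `increase_usage` mutates the shared dict in place: the first call stores the
        # usage object, and later calls add their token and price counts onto it, so
        # `llm_usage['usage']` ends up holding the totals for the whole agent run.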

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration: remove all tools so the model must answer directly
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalculate llm max tokens
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_config.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )
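
            # `invoke_llm` returns a generator of LLMResultChunk when stream=True and a
            # single LLMResult when stream=False, hence the Union annotation on `chunks`
            # and the two branches below.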

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(QueueAgentThoughtEvent(
                            agent_thought_id=agent_thought.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except (TypeError, ValueError):
                            # fall back to the default ensure_ascii=True to avoid encoding errors
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            })

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except (TypeError, ValueError):
                        # fall back to the default ensure_ascii=True to avoid encoding errors
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )
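
            # Whichever branch ran, the state is now normalized: `response` holds the
            # assistant text, `tool_calls` holds (id, name, args) triples, and
            # `current_llm_usage` holds this round's usage.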

            if tool_calls:
                prompt_messages.append(AssistantPromptMessage(
                    content='',
                    name='',
                    tool_calls=[AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls]
                ))

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish(QueueAgentThoughtEvent(
                agent_thought_id=agent_thought.id
            ), PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'

            # update prompt messages
            if response.strip():
                prompt_messages.append(AssistantPromptMessage(
                    content=response,
                ))

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict()
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                    )
                    # publish files produced by the tool
                    for message_file, save_as in message_files:
                        if save_as:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(QueueMessageFileEvent(
                            message_file_id=message_file.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        # add message file ids
                        message_file_ids.append(message_file.id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict()
                    }

                tool_responses.append(tool_response)
                # feed the tool result back into the conversation as a tool message
                prompt_messages = self.organize_prompt_messages(
                    prompt_template=prompt_template,
                    query=None,
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response['tool_call_name']: tool_response['meta']
                        for tool_response in tool_responses
                    },
                    observation={
                        tool_response['tool_call_name']: tool_response['tool_response']
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1
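
        # The loop ends when a round produces no tool call (function_call_state stays
        # False) or when the iteration budget, including the final tool-free round,
        # is exhausted.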

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        )), PublishFrom.APPLICATION_MANAGER)

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))

        return tool_calls
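
    # Example of the extracted triple (hypothetical values): a tool call with
    # id='call_0', function name 'weather' and arguments '{"city": "Berlin"}'
    # becomes ('call_0', 'weather', {'city': 'Berlin'}). json.loads raises a
    # ValueError subclass if the model emits malformed argument JSON.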

    def organize_prompt_messages(self, prompt_template: str,
                                 query: Optional[str] = None,
                                 tool_call_id: Optional[str] = None,
                                 tool_call_name: Optional[str] = None,
                                 tool_response: Optional[str] = None,
                                 prompt_messages: Optional[list[PromptMessage]] = None
                                 ) -> list[PromptMessage]:
        """
        Organize prompt messages
        """
        if not prompt_messages:
            # first call: seed the conversation with the system prompt and the user query
            prompt_messages = [
                SystemPromptMessage(content=prompt_template),
                UserPromptMessage(content=query),
            ]
        else:
            if tool_response:
                prompt_messages = prompt_messages.copy()
                prompt_messages.append(
                    ToolPromptMessage(
                        content=tool_response,
                        tool_call_id=tool_call_id,
                        name=tool_call_name,
                    )
                )

        return prompt_messages
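

# Minimal usage sketch (assumptions: `runner` is an already-constructed
# FunctionCallAgentRunner, and `conversation` / `message` are loaded ORM rows;
# constructor arguments are defined by BaseAgentRunner and omitted here):
#
#     for chunk in runner.run(conversation=conversation, message=message,
#                             query='What is the weather in Berlin?'):
#         if isinstance(chunk.delta.message.content, str):
#             print(chunk.delta.message.content, end='')
#
# Tool rounds happen between yielded chunks until the model stops emitting tool
# calls, after which a QueueMessageEndEvent carrying the final answer is published.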