fc_agent_runner.py

import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)
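

# The runner below drives an OpenAI-style function-calling loop: invoke the model
# with tool definitions, execute any tool calls it returns, feed the results back
# into the prompt as ToolPromptMessage entries, and repeat until the model answers
# without requesting a tool.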
class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity
        app_config = self.app_config

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        assert app_config.agent

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there is no tool call left
        function_call_state = True
        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
        final_answer = ""

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager

        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance
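
        # One extra step beyond the configured budget: on the last pass the tool
        # list is emptied, so the model can only reply with a plain-text answer.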
        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None

            if isinstance(chunks, Generator):
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except json.JSONDecodeError:
                            # ensure ascii to avoid encoding error
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except json.JSONDecodeError:
                        # ensure ascii to avoid encoding error
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )
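
            # Record the turn: tool calls (if any) become the assistant message's
            # tool_calls, so the next prompt shows the model its own requests.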
            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"
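
            # Unknown tool names are turned into error observations rather than
            # raising, so a single bad call does not abort the whole agent loop.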
            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                        app_id=self.application_generate_entity.app_config.app_id,
                        message_id=self.message.id,
                        conversation_id=self.conversation.id,
                    )
                    # publish files
                    for message_file_id, save_as in message_files:
                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id.id), PublishFrom.APPLICATION_MANAGER
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id.id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=tool_response["tool_response"],
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )
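
    # A minimal driving sketch (hypothetical wiring; the real runner is constructed
    # by the app layer with queue manager, model config, tools, etc.):
    #
    #     runner = FunctionCallAgentRunner(...)          # illustrative only
    #     for chunk in runner.run(message, query="hi"):  # run() is a generator
    #         text = chunk.delta.message.content or ""
    #         print(text, end="", flush=True)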

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False
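
    # Illustrative example (made-up values): a tool call with id "call_abc",
    # function name "get_weather" and arguments '{"city": "Berlin"}' is extracted
    # below as ("call_abc", "get_weather", {"city": "Berlin"}).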
    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages

    def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )
            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        For now, GPT supports both function calling and vision at the first iteration,
        so image contents are replaced with text placeholders after the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages
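
    # Illustrative effect (made-up content): UserPromptMessage(content=[text("hi"),
    # image(...)]) collapses to UserPromptMessage(content="hi\n[image]").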

    def _organize_prompt_messages(self):
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query, [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
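
        # Final layout each iteration: [system + trimmed history, current user query,
        # assistant tool calls and tool observations accumulated so far].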
        return prompt_messages