cot_chat_agent_runner.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import json
  2. from core.agent.cot_agent_runner import CotAgentRunner
  3. from core.file import file_manager
  4. from core.model_runtime.entities import (
  5. AssistantPromptMessage,
  6. PromptMessage,
  7. PromptMessageContent,
  8. SystemPromptMessage,
  9. TextPromptMessageContent,
  10. UserPromptMessage,
  11. )
  12. from core.model_runtime.utils.encoders import jsonable_encoder
  13. class CotChatAgentRunner(CotAgentRunner):
  14. def _organize_system_prompt(self) -> SystemPromptMessage:
  15. """
  16. Organize system prompt
  17. """
  18. assert self.app_config.agent
  19. assert self.app_config.agent.prompt
  20. prompt_entity = self.app_config.agent.prompt
  21. first_prompt = prompt_entity.first_prompt
  22. system_prompt = (
  23. first_prompt.replace("{{instruction}}", self._instruction)
  24. .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
  25. .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
  26. )
  27. return SystemPromptMessage(content=system_prompt)
  28. def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
  29. """
  30. Organize user query
  31. """
  32. if self.files:
  33. prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=query)]
  34. for file_obj in self.files:
  35. prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
  36. prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
  37. else:
  38. prompt_messages.append(UserPromptMessage(content=query))
  39. return prompt_messages
  40. def _organize_prompt_messages(self) -> list[PromptMessage]:
  41. """
  42. Organize
  43. """
  44. # organize system prompt
  45. system_message = self._organize_system_prompt()
  46. # organize current assistant messages
  47. agent_scratchpad = self._agent_scratchpad
  48. if not agent_scratchpad:
  49. assistant_messages = []
  50. else:
  51. assistant_message = AssistantPromptMessage(content="")
  52. for unit in agent_scratchpad:
  53. if unit.is_final():
  54. assert isinstance(assistant_message.content, str)
  55. assistant_message.content += f"Final Answer: {unit.agent_response}"
  56. else:
  57. assert isinstance(assistant_message.content, str)
  58. assistant_message.content += f"Thought: {unit.thought}\n\n"
  59. if unit.action_str:
  60. assistant_message.content += f"Action: {unit.action_str}\n\n"
  61. if unit.observation:
  62. assistant_message.content += f"Observation: {unit.observation}\n\n"
  63. assistant_messages = [assistant_message]
  64. # query messages
  65. query_messages = self._organize_user_query(self._query, [])
  66. if assistant_messages:
  67. # organize historic prompt messages
  68. historic_messages = self._organize_historic_prompt_messages(
  69. [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
  70. )
  71. messages = [
  72. system_message,
  73. *historic_messages,
  74. *query_messages,
  75. *assistant_messages,
  76. UserPromptMessage(content="continue"),
  77. ]
  78. else:
  79. # organize historic prompt messages
  80. historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
  81. messages = [system_message, *historic_messages, *query_messages]
  82. # join all messages
  83. return messages