cot_chat_agent_runner.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import json
  2. from core.agent.cot_agent_runner import CotAgentRunner
  3. from core.model_runtime.entities.message_entities import (
  4. AssistantPromptMessage,
  5. PromptMessage,
  6. PromptMessageContent,
  7. SystemPromptMessage,
  8. TextPromptMessageContent,
  9. UserPromptMessage,
  10. )
  11. from core.model_runtime.utils.encoders import jsonable_encoder
  class CotChatAgentRunner(CotAgentRunner):
      """
      Chain-of-thought agent runner that builds its prompt as chat messages.

      The prompt consists of:
          - a system message rendered from the agent's first-prompt template;
          - the conversation history;
          - the current user query, with any attached files;
          - once the agent has produced scratchpad units, an assistant message
            replaying the Thought / Action / Observation / Final Answer trail,
            followed by a "continue" user turn.
      """

      def _organize_system_prompt(self) -> SystemPromptMessage:
          """
          Render the agent's first-prompt template into the system message.

          Placeholders are filled as follows:
              ``{{instruction}}`` -> ``self._instruction``
              ``{{tools}}``       -> JSON encoding of ``self._prompt_messages_tools``
              ``{{tool_names}}``  -> the tools' names joined with ", "

          All placeholders are substituted in one pass over the template.
          Placeholder-like text that appears *inside* a substituted value is
          kept verbatim rather than expanded again. Examples: an instruction
          mentioning ``{{tools}}``, or a tool description containing
          ``{{tool_names}}``.

          Raises:
              AssertionError: if the app config has no agent or no agent prompt.
          """
          import re  # function-scope import keeps this change local to the method

          assert self.app_config.agent
          assert self.app_config.agent.prompt

          template = self.app_config.agent.prompt.first_prompt
          tools = self._prompt_messages_tools
          values = {
              "instruction": self._instruction,
              "tools": json.dumps(jsonable_encoder(tools)),
              "tool_names": ", ".join(tool.name for tool in tools),
          }

          # A callable replacement inserts each value literally. A string
          # replacement would treat backslashes (e.g. in the JSON) as escapes.
          system_prompt = re.sub(
              r"\{\{(instruction|tools|tool_names)\}\}",
              lambda match: values[match.group(1)],
              template,
          )
          return SystemPromptMessage(content=system_prompt)

      def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
          """
          Append the current user query to ``prompt_messages`` and return it.

          When ``self.files`` is non-empty, the query is sent as multi-part
          content: the query text first, then each file's
          ``prompt_message_content``. Otherwise the query is sent as plain
          content.

          Note: ``prompt_messages`` is modified in place, and that same list
          object is returned.
          """
          if self.files:
              contents: list[PromptMessageContent] = [
                  TextPromptMessageContent(data=query),
                  *(file_obj.prompt_message_content for file_obj in self.files),
              ]
              prompt_messages.append(UserPromptMessage(content=contents))
          else:
              prompt_messages.append(UserPromptMessage(content=query))
          return prompt_messages

      def _organize_prompt_messages(self) -> list[PromptMessage]:
          """
          Assemble the full message list for the next model call.

          Layout before the agent has any scratchpad units:
              system, *history, *query
          Layout once scratchpad units exist:
              system, *history, *query, assistant(scratchpad), user("continue")

          The messages that follow the history are also passed to
          ``_organize_historic_prompt_messages``. Presumably this lets it size
          the history around them (TODO confirm in the base class).
          """
          system_message = self._organize_system_prompt()

          # Replay the agent's reasoning so far as a single assistant turn.
          assistant_messages: list[PromptMessage] = []
          agent_scratchpad = self._agent_scratchpad
          if agent_scratchpad:
              parts: list[str] = []
              for unit in agent_scratchpad:
                  if unit.is_final():
                      parts.append(f"Final Answer: {unit.agent_response}")
                  else:
                      parts.append(f"Thought: {unit.thought}\n\n")
                      if unit.action_str:
                          parts.append(f"Action: {unit.action_str}\n\n")
                      if unit.observation:
                          parts.append(f"Observation: {unit.observation}\n\n")
              assistant_messages.append(AssistantPromptMessage(content="".join(parts)))

          query_messages = self._organize_user_query(self._query, [])

          # Everything that comes after the history in the final prompt.
          current_turn: list[PromptMessage] = [*query_messages]
          if assistant_messages:
              # Ask the model to keep going from where the scratchpad stops.
              current_turn.extend([*assistant_messages, UserPromptMessage(content="continue")])

          historic_messages = self._organize_historic_prompt_messages([system_message, *current_turn])
          return [system_message, *historic_messages, *current_turn]