cot_chat_agent_runner.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import json
  2. from core.agent.cot_agent_runner import CotAgentRunner
  3. from core.file import file_manager
  4. from core.model_runtime.entities import (
  5. AssistantPromptMessage,
  6. PromptMessage,
  7. PromptMessageContent,
  8. SystemPromptMessage,
  9. TextPromptMessageContent,
  10. UserPromptMessage,
  11. )
  12. from core.model_runtime.entities.message_entities import ImagePromptMessageContent
  13. from core.model_runtime.utils.encoders import jsonable_encoder
  14. class CotChatAgentRunner(CotAgentRunner):
  15. def _organize_system_prompt(self) -> SystemPromptMessage:
  16. """
  17. Organize system prompt
  18. """
  19. assert self.app_config.agent
  20. assert self.app_config.agent.prompt
  21. prompt_entity = self.app_config.agent.prompt
  22. if not prompt_entity:
  23. raise ValueError("Agent prompt configuration is not set")
  24. first_prompt = prompt_entity.first_prompt
  25. system_prompt = (
  26. first_prompt.replace("{{instruction}}", self._instruction)
  27. .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
  28. .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
  29. )
  30. return SystemPromptMessage(content=system_prompt)
  31. def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
  32. """
  33. Organize user query
  34. """
  35. if self.files:
  36. prompt_message_contents: list[PromptMessageContent] = []
  37. prompt_message_contents.append(TextPromptMessageContent(data=query))
  38. # get image detail config
  39. image_detail_config = (
  40. self.application_generate_entity.file_upload_config.image_config.detail
  41. if (
  42. self.application_generate_entity.file_upload_config
  43. and self.application_generate_entity.file_upload_config.image_config
  44. )
  45. else None
  46. )
  47. image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
  48. for file in self.files:
  49. prompt_message_contents.append(
  50. file_manager.to_prompt_message_content(
  51. file,
  52. image_detail_config=image_detail_config,
  53. )
  54. )
  55. prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
  56. else:
  57. prompt_messages.append(UserPromptMessage(content=query))
  58. return prompt_messages
  59. def _organize_prompt_messages(self) -> list[PromptMessage]:
  60. """
  61. Organize
  62. """
  63. # organize system prompt
  64. system_message = self._organize_system_prompt()
  65. # organize current assistant messages
  66. agent_scratchpad = self._agent_scratchpad
  67. if not agent_scratchpad:
  68. assistant_messages = []
  69. else:
  70. assistant_message = AssistantPromptMessage(content="")
  71. assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
  72. for unit in agent_scratchpad:
  73. if unit.is_final():
  74. assert isinstance(assistant_message.content, str)
  75. assistant_message.content += f"Final Answer: {unit.agent_response}"
  76. else:
  77. assert isinstance(assistant_message.content, str)
  78. assistant_message.content += f"Thought: {unit.thought}\n\n"
  79. if unit.action_str:
  80. assistant_message.content += f"Action: {unit.action_str}\n\n"
  81. if unit.observation:
  82. assistant_message.content += f"Observation: {unit.observation}\n\n"
  83. assistant_messages = [assistant_message]
  84. # query messages
  85. query_messages = self._organize_user_query(self._query, [])
  86. if assistant_messages:
  87. # organize historic prompt messages
  88. historic_messages = self._organize_historic_prompt_messages(
  89. [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
  90. )
  91. messages = [
  92. system_message,
  93. *historic_messages,
  94. *query_messages,
  95. *assistant_messages,
  96. UserPromptMessage(content="continue"),
  97. ]
  98. else:
  99. # organize historic prompt messages
  100. historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
  101. messages = [system_message, *historic_messages, *query_messages]
  102. # join all messages
  103. return messages