Преглед на файлове

Refactor agent history organization and initialization of agent scratchpad (#2495)

Yeuoly преди 1 година
родител
ревизия
ae3ad59b16
променени са 2 файла, в които са добавени 98 реда и са изтрити 14 реда
  1. 63 14
      api/core/features/assistant_base_runner.py
  2. 35 0
      api/core/features/assistant_cot_runner.py

+ 63 - 14
api/core/features/assistant_base_runner.py

@@ -1,5 +1,6 @@
 import json
 import logging
+import uuid
 from datetime import datetime
 from mimetypes import guess_extension
 from typing import Optional, Union, cast
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = prompt_messages
+        self.history_prompt_messages = self.organize_agent_history(
+            prompt_messages=prompt_messages or []
+        )
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -504,17 +514,6 @@ class BaseAssistantApplicationRunner(AppRunner):
         agent_thought.tool_labels_str = json.dumps(labels)
 
         db.session.commit()
-
-    def get_history_prompt_messages(self) -> list[PromptMessage]:
-        """
-        Get history prompt messages
-        """
-        if self.history_prompt_messages is None:
-            self.history_prompt_messages = db.session.query(PromptMessage).filter(
-                PromptMessage.message_id == self.message.id,
-            ).order_by(PromptMessage.position.asc()).all()
-
-        return self.history_prompt_messages
     
     def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
         """
@@ -589,4 +588,54 @@ class BaseAssistantApplicationRunner(AppRunner):
         """
         db_variables.updated_at = datetime.utcnow()
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
-        db.session.commit()
+        db.session.commit()
+
+    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Rebuild the agent prompt history from the messages stored for the conversation.
+
+        Only the first element of ``prompt_messages`` is inspected: if it is a
+        system message it is kept as the first entry of the result. The rest of
+        the history is regenerated from the ``Message`` rows that share the
+        current message's ``conversation_id``, ordered by creation time.
+
+        For every stored message a user prompt carrying ``message.query`` is
+        emitted. Each agent thought that recorded at least one tool then adds one
+        assistant message (the thought plus its tool calls) followed by one tool
+        message per tool. Agent thoughts without tools are not added.
+
+        :param prompt_messages: prompt messages passed to the runner
+        :return: the reorganized list of prompt messages
+        """
+        result: list[PromptMessage] = []
+        # check if there is a system message in the beginning of the conversation
+        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
+            result.append(prompt_messages[0])
+
+        messages: list[Message] = db.session.query(Message).filter(
+            Message.conversation_id == self.message.conversation_id,
+        ).order_by(Message.created_at.asc()).all()
+
+        for message in messages:
+            result.append(UserPromptMessage(content=message.query))
+            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
+            for agent_thought in agent_thoughts:
+                if not agent_thought.tool:
+                    continue
+
+                # several tool names are stored in one string, separated by ';'
+                tools = agent_thought.tool.split(';')
+
+                # A missing or malformed stored input must not abort the whole
+                # history rebuild; fall back to empty arguments for every tool.
+                try:
+                    tool_inputs = json.loads(agent_thought.tool_input)
+                except (TypeError, ValueError):
+                    tool_inputs = {}
+                if not isinstance(tool_inputs, dict):
+                    tool_inputs = {}
+
+                tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                tool_call_response: list[ToolPromptMessage] = []
+                for tool in tools:
+                    # generate a uuid for tool call; the same id links the
+                    # assistant tool call to its tool response below
+                    tool_call_id = str(uuid.uuid4())
+                    tool_calls.append(AssistantPromptMessage.ToolCall(
+                        id=tool_call_id,
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool,
+                            arguments=json.dumps(tool_inputs.get(tool, {})),
+                        )
+                    ))
+                    # every tool response carries the thought's whole observation
+                    tool_call_response.append(ToolPromptMessage(
+                        content=agent_thought.observation,
+                        name=tool,
+                        tool_call_id=tool_call_id,
+                    ))
+
+                result.extend([
+                    AssistantPromptMessage(
+                        content=agent_thought.thought,
+                        tool_calls=tool_calls,
+                    ),
+                    *tool_call_response
+                ])
+
+        return result

+ 35 - 0
api/core/features/assistant_cot_runner.py

@@ -12,6 +12,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -39,6 +40,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
         self._repack_app_orchestration_config(app_orchestration_config)
 
         agent_scratchpad: list[AgentScratchpadUnit] = []
+        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
 
         # check model mode
         if self.app_orchestration_config.model_config.mode == "completion":
@@ -327,6 +329,39 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                 continue
 
         return instruction
+    
+    def _init_agent_scratchpad(self,
+                               agent_scratchpad: list[AgentScratchpadUnit],
+                               messages: list[PromptMessage]
+                               ) -> list[AgentScratchpadUnit]:
+        """
+        Seed the agent scratchpad from previous prompt messages.
+
+        Every assistant message becomes one scratchpad unit whose response and
+        thought are both the message content. When the assistant message has
+        tool calls, only the first one is turned into the unit's action. A tool
+        message sets the observation of the most recently created unit; tool
+        messages seen before any assistant message are ignored. Other message
+        types are skipped.
+
+        :param agent_scratchpad: list to append the units to (modified in place)
+        :param messages: prompt messages to read the history from
+        :return: the same ``agent_scratchpad`` list
+        """
+        # unit created for the latest assistant message, None until one is seen
+        current_scratchpad = None
+        for message in messages:
+            if isinstance(message, AssistantPromptMessage):
+                current_scratchpad = AgentScratchpadUnit(
+                    agent_response=message.content,
+                    thought=message.content,
+                    action_str='',
+                    action=None,
+                    observation=None
+                )
+                if message.tool_calls:
+                    first_call = message.tool_calls[0]
+                    try:
+                        current_scratchpad.action = AgentScratchpadUnit.Action(
+                            action_name=first_call.function.name,
+                            action_input=json.loads(first_call.function.arguments)
+                        )
+                    except (TypeError, ValueError):
+                        # Best effort: arguments that are missing or not valid
+                        # JSON leave the unit without an action. Presumably a
+                        # validation failure in Action is a ValueError too -
+                        # TODO confirm against the model definition.
+                        pass
+
+                agent_scratchpad.append(current_scratchpad)
+            elif isinstance(message, ToolPromptMessage):
+                if current_scratchpad:
+                    current_scratchpad.observation = message.content
+
+        return agent_scratchpad
 
     def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
         """