|
@@ -1,5 +1,6 @@
|
|
|
import json
|
|
|
import logging
|
|
|
+import uuid
|
|
|
from datetime import datetime
|
|
|
from mimetypes import guess_extension
|
|
|
from typing import Optional, Union, cast
|
|
@@ -20,7 +21,14 @@ from core.file.message_file_parser import FileTransferMethod
|
|
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
|
|
from core.model_manager import ModelInstance
|
|
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
|
|
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
|
|
+from core.model_runtime.entities.message_entities import (
|
|
|
+ AssistantPromptMessage,
|
|
|
+ PromptMessage,
|
|
|
+ PromptMessageTool,
|
|
|
+ SystemPromptMessage,
|
|
|
+ ToolPromptMessage,
|
|
|
+ UserPromptMessage,
|
|
|
+)
|
|
|
from core.model_runtime.entities.model_entities import ModelFeature
|
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
@@ -77,7 +85,9 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
self.message = message
|
|
|
self.user_id = user_id
|
|
|
self.memory = memory
|
|
|
- self.history_prompt_messages = prompt_messages
|
|
|
+ self.history_prompt_messages = self.organize_agent_history(
|
|
|
+ prompt_messages=prompt_messages or []
|
|
|
+ )
|
|
|
self.variables_pool = variables_pool
|
|
|
self.db_variables_pool = db_variables
|
|
|
self.model_instance = model_instance
|
|
@@ -504,17 +514,6 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
agent_thought.tool_labels_str = json.dumps(labels)
|
|
|
|
|
|
db.session.commit()
|
|
|
-
|
|
|
- def get_history_prompt_messages(self) -> list[PromptMessage]:
|
|
|
- """
|
|
|
- Get history prompt messages
|
|
|
- """
|
|
|
- if self.history_prompt_messages is None:
|
|
|
- self.history_prompt_messages = db.session.query(PromptMessage).filter(
|
|
|
- PromptMessage.message_id == self.message.id,
|
|
|
- ).order_by(PromptMessage.position.asc()).all()
|
|
|
-
|
|
|
- return self.history_prompt_messages
|
|
|
|
|
|
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
|
|
|
"""
|
|
@@ -589,4 +588,54 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|
|
"""
|
|
|
db_variables.updated_at = datetime.utcnow()
|
|
|
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
|
|
- db.session.commit()
|
|
|
+ db.session.commit()
|
|
|
+
|
|
|
+    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Rebuild this conversation's agent history as prompt messages.
+
+        Carries over a leading SystemPromptMessage from *prompt_messages* (if
+        present), then replays every earlier message of the conversation as a
+        user -> assistant(tool_calls) -> tool message sequence.
+
+        :param prompt_messages: prompt messages organized so far; only a
+            leading system message is reused from it.
+        :return: organized history prompt messages
+        """
+        result = []
+        # check if there is a system message in the beginning of the conversation
+        if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
+            result.append(prompt_messages[0])
+
+        messages: list[Message] = db.session.query(Message).filter(
+            Message.conversation_id == self.message.conversation_id,
+        ).order_by(Message.created_at.asc()).all()
+
+        for message in messages:
+            # skip the message currently being answered: its query is supplied
+            # as the live prompt, so keeping it here would duplicate it
+            if message.id == self.message.id:
+                continue
+
+            result.append(UserPromptMessage(content=message.query))
+            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
+            if not agent_thoughts:
+                # no agent round was recorded for this message; fall back to
+                # the stored answer so the assistant turn is not lost
+                if message.answer:
+                    result.append(AssistantPromptMessage(content=message.answer))
+                continue
+
+            for agent_thought in agent_thoughts:
+                tools = agent_thought.tool
+                if not tools:
+                    continue
+                tools = tools.split(';')
+                tool_calls: list[AssistantPromptMessage.ToolCall] = []
+                tool_call_response: list[ToolPromptMessage] = []
+                try:
+                    tool_inputs = json.loads(agent_thought.tool_input)
+                except (TypeError, ValueError):
+                    # tool_input may be empty or non-JSON for legacy rows
+                    tool_inputs = {tool: {} for tool in tools}
+                for tool in tools:
+                    # generate a uuid for tool call
+                    tool_call_id = str(uuid.uuid4())
+                    tool_calls.append(AssistantPromptMessage.ToolCall(
+                        id=tool_call_id,
+                        type='function',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=tool,
+                            arguments=json.dumps(tool_inputs.get(tool, {})),
+                        )
+                    ))
+                    tool_call_response.append(ToolPromptMessage(
+                        # observation can be NULL in the DB; the tool message
+                        # content must still be a string
+                        content=agent_thought.observation or '',
+                        name=tool,
+                        tool_call_id=tool_call_id,
+                    ))
+
+                # emit this round's assistant turn and its tool results while
+                # tool_calls / tool_call_response still belong to THIS thought
+                # (extending once per message after the loop would reuse only
+                # the last thought's lists and NameError when no thought
+                # carried tools)
+                result.extend([
+                    AssistantPromptMessage(
+                        content=agent_thought.thought,
+                        tool_calls=tool_calls,
+                    ),
+                    *tool_call_response
+                ])
+
+        return result
|