123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532 |
- import json
- import logging
- import uuid
- from collections.abc import Mapping, Sequence
- from datetime import datetime, timezone
- from typing import Optional, Union, cast
- from core.agent.entities import AgentEntity, AgentToolEntity
- from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
- from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
- from core.app.apps.base_app_queue_manager import AppQueueManager
- from core.app.apps.base_app_runner import AppRunner
- from core.app.entities.app_invoke_entities import (
- AgentChatAppGenerateEntity,
- ModelConfigWithCredentialsEntity,
- )
- from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
- from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
- from core.file import file_manager
- from core.memory.token_buffer_memory import TokenBufferMemory
- from core.model_manager import ModelInstance
- from core.model_runtime.entities import (
- AssistantPromptMessage,
- LLMUsage,
- PromptMessage,
- PromptMessageContent,
- PromptMessageTool,
- SystemPromptMessage,
- TextPromptMessageContent,
- ToolPromptMessage,
- UserPromptMessage,
- )
- from core.model_runtime.entities.model_entities import ModelFeature
- from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
- from core.model_runtime.utils.encoders import jsonable_encoder
- from core.prompt.utils.extract_thread_messages import extract_thread_messages
- from core.tools.entities.tool_entities import (
- ToolParameter,
- ToolRuntimeVariablePool,
- )
- from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
- from core.tools.tool.tool import Tool
- from core.tools.tool_manager import ToolManager
- from extensions.ext_database import db
- from factories import file_factory
- from models.model import Conversation, Message, MessageAgentThought, MessageFile
- from models.tools import ToolConversationVariables
- logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
    """
    Base class for agent runners.

    On construction it collects the state needed for one agent invocation:
    the conversation history converted to prompt messages, the dataset
    retriever tools, the number of agent thoughts already stored for the
    message, and the capability flags of the selected model. It also provides
    helpers to turn tools into LLM tool schemas and to persist agent thoughts.
    """

    def __init__(
        self,
        tenant_id: str,
        application_generate_entity: AgentChatAppGenerateEntity,
        conversation: Conversation,
        app_config: AgentChatAppConfig,
        model_config: ModelConfigWithCredentialsEntity,
        config: AgentEntity,
        queue_manager: AppQueueManager,
        message: Message,
        user_id: str,
        memory: Optional[TokenBufferMemory] = None,
        prompt_messages: Optional[list[PromptMessage]] = None,
        variables_pool: Optional[ToolRuntimeVariablePool] = None,
        db_variables: Optional[ToolConversationVariables] = None,
        model_instance: Optional[ModelInstance] = None,
    ) -> None:
        """
        Initialize the runner state.

        Besides storing its arguments, this queries the database for the
        conversation history and for the number of existing agent thoughts of
        ``message``, builds the dataset retriever tools and reads the model
        schema to detect stream-tool-call and vision support.

        :param prompt_messages: initial prompt messages; only system messages
            are carried over into the history
        :param model_instance: model to run against; keeps its ``None`` default
            for signature compatibility but must be provided
        :raises ValueError: if ``model_instance`` is ``None``
        """
        if model_instance is None:
            # Fail before any database work: the schema lookup below cannot
            # run without a model instance.
            raise ValueError("model_instance is required")

        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        # organize_agent_history reads self.message and self.tenant_id, so it
        # must run after those attributes are assigned.
        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()

        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        dataset_config = app_config.dataset
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=dataset_config.dataset_ids if dataset_config else [],
            retrieve_config=dataset_config.retrieve_config if dataset_config else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback,
        )

        # get how many agent thoughts have been created
        self.agent_thought_count = (
            db.session.query(MessageAgentThought)
            .filter(
                MessageAgentThought.message_id == self.message.id,
            )
            .count()
        )
        db.session.close()

        # read the model capabilities once and derive both flags from them
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        features = (model_schema.features or []) if model_schema else []

        # check if model supports stream tool call
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features

        # files are only kept when the model supports vision
        self.files = application_generate_entity.files if ModelFeature.VISION in features else []

        self.query: Optional[str] = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(
        self, app_generate_entity: AgentChatAppGenerateEntity
    ) -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity

        Replaces a ``None`` simple prompt template with an empty string. The
        entity is modified in place and returned.
        """
        prompt_template = app_generate_entity.app_config.prompt_template
        if prompt_template.simple_prompt_template is None:
            prompt_template.simple_prompt_template = ""

        return app_generate_entity

    @staticmethod
    def _fill_llm_parameters(prompt_tool: PromptMessageTool, parameters: Sequence[ToolParameter]) -> None:
        """
        Write the LLM-facing parameters of a tool into a prompt tool schema.

        Only parameters whose form is ``LLM`` are considered, and file-typed
        parameters (``SYSTEM_FILES``, ``FILE``, ``FILES``) are skipped. Each
        remaining parameter becomes an entry of ``parameters["properties"]``;
        ``SELECT`` parameters also get an ``enum`` of their option values, and
        required parameters are appended to ``parameters["required"]`` once.

        :param prompt_tool: prompt tool whose ``parameters`` dict is updated in place
        :param parameters: tool parameters to translate
        """
        file_types = {
            ToolParameter.ToolParameterType.SYSTEM_FILES,
            ToolParameter.ToolParameterType.FILE,
            ToolParameter.ToolParameterType.FILES,
        }
        properties = prompt_tool.parameters["properties"]
        required = prompt_tool.parameters["required"]

        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            parameter_type = parameter.type.as_normal_type()
            if parameter.type in file_types:
                continue

            schema = {
                "type": parameter_type,
                "description": parameter.llm_description or "",
            }
            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options]
                if enum:
                    schema["enum"] = enum
            properties[parameter.name] = schema

            # keep "required" free of duplicates even if a name shows up twice
            if parameter.required and parameter.name not in required:
                required.append(parameter.name)

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool

        Resolves the runtime tool for ``tool`` through ``ToolManager``, loads
        the runner's variable pool into it and builds the schema from all of
        its runtime parameters.

        :param tool: agent tool configuration
        :return: the prompt message tool and the resolved runtime tool
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )
        self._fill_llm_parameters(message_tool, tool_entity.get_all_runtime_parameters())

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool

        Every runtime parameter of the dataset tool is exposed with the JSON
        type ``"string"``; required parameters are listed once in ``required``.

        :param tool: dataset retriever tool
        :return: the prompt message tool describing it
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            },
        )
        properties = prompt_tool.parameters["properties"]
        required = prompt_tool.parameters["required"]

        for parameter in tool.get_runtime_parameters():
            properties[parameter.name] = {
                "type": "string",
                "description": parameter.llm_description or "",
            }
            if parameter.required and parameter.name not in required:
                required.append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
        """
        Init tools

        Builds prompt tools for the configured agent tools and for the dataset
        tools. An agent tool that fails to convert is skipped and logged.

        :return: mapping of tool name to tool instance, and the prompt tools
        """
        tool_instances: dict[str, Tool] = {}
        prompt_messages_tools: list[PromptMessageTool] = []

        agent_tools = self.app_config.agent.tools if self.app_config.agent else []
        for tool in agent_tools:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted: keep going without it, but leave a
                # trace instead of dropping the tool silently
                logger.warning("Skipping agent tool %s: could not be converted", tool.tool_name, exc_info=True)
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_messages_tools.append(self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool))
            tool_instances[dataset_tool.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool

        Refreshes ``prompt_tool`` in place from the current runtime parameters
        of ``tool`` and returns it.
        """
        # the tool may report no runtime parameters at all
        self._fill_llm_parameters(prompt_tool, tool.get_runtime_parameters() or [])
        return prompt_tool

    def create_agent_thought(
        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
    ) -> MessageAgentThought:
        """
        Create agent thought

        Inserts a new ``MessageAgentThought`` row with zeroed usage fields at
        position ``agent_thought_count + 1``, commits it and increments the
        counter.

        :param message_id: id of the message the thought belongs to
        :param message: stored in the thought's ``message`` field
        :param tool_name: stored in the thought's ``tool`` field
        :param tool_input: stored in the thought's ``tool_input`` field
        :param messages_ids: file ids, stored as JSON (empty string when empty)
        :return: the created thought, refreshed before the session is closed
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought="",
            tool=tool_name,
            tool_labels_str="{}",
            tool_meta_str="{}",
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else "",
            answer="",
            observation="",
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency="USD",
            latency=0,
            created_by_role="account",
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        # reload generated columns while the session is still open
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    @staticmethod
    def _to_json_text(value: Union[str, dict]) -> str:
        """
        Serialize dict values to JSON text; return any other value unchanged.

        Non-ASCII characters are kept as-is when possible, with a plain
        ``json.dumps`` as fallback.
        """
        if not isinstance(value, dict):
            return value
        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return json.dumps(value)

    def save_agent_thought(
        self,
        agent_thought: MessageAgentThought,
        tool_name: Optional[str],
        tool_input: Union[str, dict, None],
        thought: Optional[str],
        observation: Union[str, dict, None],
        tool_invoke_meta: Union[str, dict, None],
        answer: Optional[str],
        messages_ids: list[str],
        llm_usage: Optional[LLMUsage] = None,
    ) -> None:
        """
        Save agent thought

        Reloads the thought by id and updates every field whose argument is
        not ``None``; dict values are stored as JSON text. Tool labels are
        recomputed for the tools named in the thought, then the change is
        committed.

        :param agent_thought: thought to update (only its ``id`` is used)
        :param messages_ids: file ids; stored only when non-empty
        :param llm_usage: token and price usage to record, if any
        :raises ValueError: if the thought no longer exists in the database
        """
        stored_thought = (
            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
        )
        if stored_thought is None:
            db.session.close()
            raise ValueError("agent thought not found")
        agent_thought = stored_thought

        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            agent_thought.tool_input = self._to_json_text(tool_input)

        if observation is not None:
            agent_thought.observation = self._to_json_text(observation)

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # make sure every tool named in the thought has a label
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(";") if agent_thought.tool else []
        for tool in tools:
            if not tool or tool in labels:
                continue
            tool_label = ToolManager.get_tool_label(tool)
            # fall back to the raw tool name when no label is registered
            labels[tool] = tool_label.to_dict() if tool_label else {"en_US": tool, "zh_Hans": tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            agent_thought.tool_meta_str = self._to_json_text(tool_invoke_meta)

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables

        The stored row is looked up again by the conversation id of the current
        message; the ``db_variables`` argument itself is not used.

        :param tool_variables: variable pool whose ``pool`` is serialized
        :param db_variables: unused, kept for interface compatibility
        :raises ValueError: if the conversation has no stored tool variables
        """
        stored_variables = (
            db.session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == self.message.conversation_id,
            )
            .first()
        )
        if stored_variables is None:
            db.session.close()
            raise ValueError("tool conversation variables not found")

        # stored as a naive datetime holding the current UTC time
        stored_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        stored_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
        db.session.commit()
        db.session.close()

    @staticmethod
    def _load_json_object(raw: Optional[str]) -> Optional[dict]:
        """
        Parse ``raw`` as JSON and return it only if it is a JSON object.

        Returns ``None`` when ``raw`` cannot be parsed or parses to anything
        other than a dict (for example a number, string or list).
        """
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError, RecursionError):
            return None
        return parsed if isinstance(parsed, dict) else None

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history

        Starts with the system messages found in ``prompt_messages``, then
        walks the thread of the current conversation (excluding the current
        message) and emits, per message, the user prompt followed by either
        its agent thoughts (as assistant tool calls plus tool responses) or
        its plain answer.

        :param prompt_messages: prompt messages to take system messages from
        :return: the history as prompt messages
        """
        result: list[PromptMessage] = [
            prompt_message for prompt_message in prompt_messages if isinstance(prompt_message, SystemPromptMessage)
        ]

        messages: list[Message] = (
            db.session.query(Message)
            .filter(
                Message.conversation_id == self.message.conversation_id,
            )
            .order_by(Message.created_at.desc())
            .all()
        )
        # the query is newest-first; reverse the extracted thread for the prompt
        messages = list(reversed(extract_thread_messages(messages)))

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))

            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if not agent_thoughts:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))
                continue

            for agent_thought in agent_thoughts:
                tools = agent_thought.tool.split(";") if agent_thought.tool else []
                if not tools:
                    result.append(AssistantPromptMessage(content=agent_thought.thought))
                    continue

                # Stored values may be missing, plain text, or JSON that is not
                # an object; in all of those cases fall back to per-tool
                # defaults instead of failing on ``.get``.
                tool_inputs = self._load_json_object(agent_thought.tool_input)
                if tool_inputs is None:
                    tool_inputs = {tool: {} for tool in tools}
                tool_responses = self._load_json_object(agent_thought.observation)
                if tool_responses is None:
                    tool_responses = dict.fromkeys(tools, agent_thought.observation)

                tool_calls: list[AssistantPromptMessage.ToolCall] = []
                tool_call_response: list[ToolPromptMessage] = []
                for tool in tools:
                    # generate a uuid for tool call
                    tool_call_id = str(uuid.uuid4())
                    tool_calls.append(
                        AssistantPromptMessage.ToolCall(
                            id=tool_call_id,
                            type="function",
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=tool,
                                arguments=json.dumps(tool_inputs.get(tool, {})),
                            ),
                        )
                    )
                    tool_call_response.append(
                        ToolPromptMessage(
                            content=tool_responses.get(tool, agent_thought.observation),
                            name=tool,
                            tool_call_id=tool_call_id,
                        )
                    )

                result.append(
                    AssistantPromptMessage(
                        content=agent_thought.thought,
                        tool_calls=tool_calls,
                    )
                )
                result.extend(tool_call_response)

        db.session.close()

        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        """
        Build the user prompt message for a stored message.

        The prompt is the plain query unless the message has files that can be
        built with the file upload config of its app model config; in that
        case the content is the query text followed by one entry per file.

        :param message: stored message to convert
        :return: the user prompt message
        """
        files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
        if not files:
            return UserPromptMessage(content=message.query)

        file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
        if file_extra_config:
            file_objs = file_factory.build_from_message_files(
                message_files=files, tenant_id=self.tenant_id, config=file_extra_config
            )
        else:
            file_objs = []

        if not file_objs:
            return UserPromptMessage(content=message.query)

        prompt_message_contents: list[PromptMessageContent] = [TextPromptMessageContent(data=message.query)]
        prompt_message_contents.extend(file_manager.to_prompt_message_content(file_obj) for file_obj in file_objs)
        return UserPromptMessage(content=prompt_message_contents)
|