Browse Source

feat: support agent node

Yeuoly 4 months ago
parent
commit
ae72514cb4

+ 9 - 1
api/core/agent/entities.py

@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
 
 from pydantic import BaseModel
 
-from core.tools.entities.tool_entities import ToolProviderType
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
 
 
 class AgentToolEntity(BaseModel):
@@ -83,3 +83,11 @@ class AgentEntity(BaseModel):
     prompt: Optional[AgentPromptEntity] = None
     tools: Optional[list[AgentToolEntity]] = None
     max_iteration: int = 5
+
+
+class AgentInvokeMessage(ToolInvokeMessage):
+    """
+    Agent Invoke Message.
+    """
+
+    pass

+ 41 - 0
api/core/agent/plugin_entities.py

@@ -0,0 +1,41 @@
+from typing import Optional
+from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
+
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderIdentity
+
+
+class AgentProviderIdentity(ToolProviderIdentity):
+    pass
+
+
+class AgentParameter(ToolParameter):
+    pass
+
+
+class AgentProviderEntity(BaseModel):
+    identity: AgentProviderIdentity
+    plugin_id: Optional[str] = Field(None, description="The id of the plugin")
+
+
+class AgentIdentity(ToolIdentity):
+    pass
+
+
+class AgentStrategyEntity(BaseModel):
+    identity: AgentIdentity
+    parameters: list[AgentParameter] = Field(default_factory=list)
+    description: I18nObject = Field(..., description="The description of the agent strategy")
+    output_schema: Optional[dict] = None
+
+    # pydantic configs
+    model_config = ConfigDict(protected_namespaces=())
+
+    @field_validator("parameters", mode="before")
+    @classmethod
+    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentParameter]:
+        return v or []
+
+
+class AgentProviderEntityWithPlugin(AgentProviderEntity):
+    strategies: list[AgentStrategyEntity] = Field(default_factory=list)

+ 41 - 0
api/core/agent/strategy/base.py

@@ -0,0 +1,41 @@
+from abc import ABC, abstractmethod
+from typing import Any, Generator, Optional, Sequence
+
+from core.agent.entities import AgentInvokeMessage
+from core.agent.plugin_entities import AgentParameter
+
+
+class BaseAgentStrategy(ABC):
+    """
+    Agent Strategy
+    """
+
+    def invoke(
+        self,
+        params: dict[str, Any],
+        user_id: str,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[AgentInvokeMessage, None, None]:
+        """
+        Invoke the agent strategy.
+        """
+        yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
+
+    def get_parameters(self) -> Sequence[AgentParameter]:
+        """
+        Get the parameters for the agent strategy.
+        """
+        return []
+
+    @abstractmethod
+    def _invoke(
+        self,
+        params: dict[str, Any],
+        user_id: str,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[AgentInvokeMessage, None, None]:
+        pass

+ 52 - 0
api/core/agent/strategy/plugin.py

@@ -0,0 +1,52 @@
+from typing import Any, Generator, Optional, Sequence
+
+from core.agent.entities import AgentInvokeMessage
+from core.agent.plugin_entities import AgentParameter, AgentStrategyEntity
+from core.agent.strategy.base import BaseAgentStrategy
+from core.plugin.manager.agent import PluginAgentManager
+from core.tools.plugin_tool.tool import PluginTool
+
+
+class PluginAgentStrategy(BaseAgentStrategy):
+    """
+    Agent Strategy
+    """
+
+    tenant_id: str
+    plugin_unique_identifier: str
+    declaration: AgentStrategyEntity
+
+    def __init__(self, tenant_id: str, plugin_unique_identifier: str, declaration: AgentStrategyEntity):
+        self.tenant_id = tenant_id
+        self.plugin_unique_identifier = plugin_unique_identifier
+        self.declaration = declaration
+
+    def get_parameters(self) -> Sequence[AgentParameter]:
+        return self.declaration.parameters
+
+    def _invoke(
+        self,
+        params: dict[str, Any],
+        user_id: str,
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[AgentInvokeMessage, None, None]:
+        """
+        Invoke the agent strategy.
+        """
+        manager = PluginAgentManager()
+
+        # convert agent parameters with File type to PluginFileEntity
+        params = PluginTool._transform_image_parameters(params)
+
+        yield from manager.invoke(
+            tenant_id=self.tenant_id,
+            user_id=user_id,
+            agent_provider=self.declaration.identity.provider,
+            agent_strategy=self.declaration.identity.name,
+            agent_params=params,
+            conversation_id=conversation_id,
+            app_id=app_id,
+            message_id=message_id,
+        )

+ 8 - 0
api/core/plugin/entities/plugin_daemon.py

@@ -5,6 +5,7 @@ from typing import Generic, Optional, TypeVar
 
 from pydantic import BaseModel, ConfigDict, Field
 
+from core.agent.plugin_entities import AgentProviderEntityWithPlugin
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.plugin.entities.base import BasePluginEntity
@@ -46,6 +47,13 @@ class PluginToolProviderEntity(BaseModel):
     declaration: ToolProviderEntityWithPlugin
 
 
+class PluginAgentProviderEntity(BaseModel):
+    provider: str
+    plugin_unique_identifier: str
+    plugin_id: str
+    declaration: AgentProviderEntityWithPlugin
+
+
 class PluginBasicBooleanResponse(BaseModel):
     """
     Basic boolean response from plugin daemon.

+ 110 - 0
api/core/plugin/manager/agent.py

@@ -0,0 +1,110 @@
+from collections.abc import Generator
+from typing import Any, Optional
+
+
+from core.agent.entities import AgentInvokeMessage
+from core.plugin.entities.plugin import GenericProviderID
+from core.plugin.entities.plugin_daemon import (
+    PluginAgentProviderEntity,
+)
+from core.plugin.manager.base import BasePluginManager
+
+
+class PluginAgentManager(BasePluginManager):
+    def fetch_agent_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
+        """
+        Fetch agent providers for the given tenant.
+        """
+
+        def transformer(json_response: dict[str, Any]) -> dict:
+            for provider in json_response.get("data", []):
+                declaration = provider.get("declaration", {}) or {}
+                provider_name = declaration.get("identity", {}).get("name")
+                for tool in declaration.get("tools", []):
+                    tool["identity"]["provider"] = provider_name
+
+            return json_response
+
+        response = self._request_with_plugin_daemon_response(
+            "GET",
+            f"plugin/{tenant_id}/management/agents",
+            list[PluginAgentProviderEntity],
+            params={"page": 1, "page_size": 256},
+            transformer=transformer,
+        )
+
+        for provider in response:
+            provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
+
+            # override the provider name for each tool to plugin_id/provider_name
+            for strategy in provider.declaration.strategies:
+                strategy.identity.provider = provider.declaration.identity.name
+
+        return response
+
+    def fetch_agent_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
+        """
+        Fetch tool provider for the given tenant and plugin.
+        """
+        agent_provider_id = GenericProviderID(provider)
+
+        def transformer(json_response: dict[str, Any]) -> dict:
+            for strategy in json_response.get("data", {}).get("declaration", {}).get("strategies", []):
+                strategy["identity"]["provider"] = agent_provider_id.provider_name
+
+            return json_response
+
+        response = self._request_with_plugin_daemon_response(
+            "GET",
+            f"plugin/{tenant_id}/management/agent",
+            PluginAgentProviderEntity,
+            params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
+            transformer=transformer,
+        )
+
+        response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
+
+        # override the provider name for each tool to plugin_id/provider_name
+        for strategy in response.declaration.strategies:
+            strategy.identity.provider = response.declaration.identity.name
+
+        return response
+
+    def invoke(
+        self,
+        tenant_id: str,
+        user_id: str,
+        agent_provider: str,
+        agent_strategy: str,
+        agent_params: dict[str, Any],
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[AgentInvokeMessage, None, None]:
+        """
+        Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
+        """
+
+        agent_provider_id = GenericProviderID(agent_provider)
+
+        response = self._request_with_plugin_daemon_response_stream(
+            "POST",
+            f"plugin/{tenant_id}/dispatch/agent/invoke",
+            AgentInvokeMessage,
+            data={
+                "user_id": user_id,
+                "conversation_id": conversation_id,
+                "app_id": app_id,
+                "message_id": message_id,
+                "data": {
+                    "provider": agent_provider_id.provider_name,
+                    "strategy": agent_strategy,
+                    "agent_params": agent_params,
+                },
+            },
+            headers={
+                "X-Plugin-ID": agent_provider_id.plugin_id,
+                "Content-Type": "application/json",
+            },
+        )
+        return response

+ 0 - 0
api/core/tools/entities/agent_entities.py


+ 21 - 16
api/core/tools/plugin_tool/tool.py

@@ -27,23 +27,14 @@ class PluginTool(Tool):
     def tool_provider_type(self) -> ToolProviderType:
         return ToolProviderType.PLUGIN
 
-    def _invoke(
-        self,
-        user_id: str,
-        tool_parameters: dict[str, Any],
-        conversation_id: Optional[str] = None,
-        app_id: Optional[str] = None,
-        message_id: Optional[str] = None,
-    ) -> Generator[ToolInvokeMessage, None, None]:
-        manager = PluginToolManager()
-
-        # convert tool parameters with File type to PluginFileEntity
-        for parameter_name, parameter in tool_parameters.items():
+    @classmethod
+    def _transform_image_parameters(cls, parameters: dict[str, Any]) -> dict[str, Any]:
+        for parameter_name, parameter in parameters.items():
             if isinstance(parameter, File):
                 url = parameter.generate_url()
                 if url is None:
                     raise ValueError(f"File {parameter.id} does not have a valid URL")
-                tool_parameters[parameter_name] = PluginFileEntity(
+                parameters[parameter_name] = PluginFileEntity(
                     url=url,
                     mime_type=parameter.mime_type,
                     type=parameter.type,
@@ -52,13 +43,13 @@ class PluginTool(Tool):
                     size=parameter.size,
                 ).model_dump()
             elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter):
-                tool_parameters[parameter_name] = []
+                parameters[parameter_name] = []
                 for p in parameter:
                     assert isinstance(p, File)
                     url = p.generate_url()
                     if url is None:
                         raise ValueError(f"File {p.id} does not have a valid URL")
-                    tool_parameters[parameter_name].append(
+                    parameters[parameter_name].append(
                         PluginFileEntity(
                             url=url,
                             mime_type=p.mime_type,
@@ -68,8 +59,22 @@ class PluginTool(Tool):
                             size=p.size,
                         ).model_dump()
                     )
+        return parameters
+
+    def _invoke(
+        self,
+        user_id: str,
+        tool_parameters: dict[str, Any],
+        conversation_id: Optional[str] = None,
+        app_id: Optional[str] = None,
+        message_id: Optional[str] = None,
+    ) -> Generator[ToolInvokeMessage, None, None]:
+        manager = PluginToolManager()
+
+        # convert tool parameters with File type to PluginFileEntity
+        tool_parameters = self._transform_image_parameters(tool_parameters)
 
-        return manager.invoke(
+        yield from manager.invoke(
             tenant_id=self.tenant_id,
             user_id=user_id,
             tool_provider=self.entity.identity.provider,

+ 3 - 0
api/core/workflow/nodes/agent/__init__.py

@@ -0,0 +1,3 @@
+from .agent_node import AgentNode
+
+__all__ = ["AgentNode"]

+ 85 - 0
api/core/workflow/nodes/agent/agent_node.py

@@ -0,0 +1,85 @@
+from collections.abc import Generator
+from typing import cast
+from core.plugin.manager.exc import PluginDaemonClientSideError
+from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.nodes.agent.entities import AgentNodeData
+from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.event.event import RunCompletedEvent
+from core.workflow.nodes.tool.tool_node import ToolNode
+from factories.agent_factory import get_plugin_agent_strategy
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class AgentNode(ToolNode):
+    """
+    Agent Node
+    """
+
+    _node_data_cls = AgentNodeData
+    _node_type = NodeType.AGENT
+
+    def _run(self) -> Generator:
+        """
+        Run the agent node
+        """
+        node_data = cast(AgentNodeData, self.node_data)
+
+        try:
+            strategy = get_plugin_agent_strategy(
+                tenant_id=self.tenant_id,
+                plugin_unique_identifier=node_data.plugin_unique_identifier,
+                agent_strategy_provider_name=node_data.agent_strategy_provider_name,
+                agent_strategy_name=node_data.agent_strategy_name,
+            )
+        except Exception as e:
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs={},
+                    error=f"Failed to get agent strategy: {str(e)}",
+                )
+            )
+            return
+
+        agent_parameters = strategy.get_parameters()
+
+        # get parameters
+        parameters = self._generate_parameters(
+            tool_parameters=agent_parameters,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            node_data=self.node_data,
+        )
+        parameters_for_log = self._generate_parameters(
+            tool_parameters=agent_parameters,
+            variable_pool=self.graph_runtime_state.variable_pool,
+            node_data=self.node_data,
+            for_log=True,
+        )
+
+        try:
+            message_stream = strategy.invoke(
+                params=parameters,
+                user_id=self.user_id,
+                app_id=self.app_id,
+                # TODO: conversation id and message id
+            )
+        except Exception as e:
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    error=f"Failed to invoke agent: {str(e)}",
+                )
+            )
+
+        try:
+            # convert tool messages
+            yield from self._transform_message(message_stream, {}, parameters_for_log)
+        except PluginDaemonClientSideError as e:
+            yield RunCompletedEvent(
+                run_result=NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    inputs=parameters_for_log,
+                    error=f"Failed to transform agent message: {str(e)}",
+                )
+            )

+ 51 - 0
api/core/workflow/nodes/agent/entities.py

@@ -0,0 +1,51 @@
+from typing import Any, Literal, Union
+from pydantic import BaseModel, ValidationInfo, field_validator
+
+from core.workflow.nodes.base.entities import BaseNodeData
+
+
+class AgentEntity(BaseModel):
+    agent_strategy_provider_name: str  # redundancy
+    agent_strategy_name: str
+    agent_strategy_label: str  # redundancy
+    agent_configurations: dict[str, Any]
+    plugin_unique_identifier: str
+
+    @field_validator("agent_configurations", mode="before")
+    @classmethod
+    def validate_agent_configurations(cls, value, values: ValidationInfo):
+        if not isinstance(value, dict):
+            raise ValueError("agent_configurations must be a dictionary")
+
+        for key in values.data.get("agent_configurations", {}):
+            value = values.data.get("agent_configurations", {}).get(key)
+            if not isinstance(value, str | int | float | bool):
+                raise ValueError(f"{key} must be a string")
+
+        return value
+
+
+class AgentNodeData(BaseNodeData, AgentEntity):
+    class AgentInput(BaseModel):
+        # TODO: check this type
+        value: Union[Any, list[str]]
+        type: Literal["mixed", "variable", "constant"]
+
+        @field_validator("type", mode="before")
+        @classmethod
+        def check_type(cls, value, validation_info: ValidationInfo):
+            typ = value
+            value = validation_info.data.get("value")
+            if typ == "mixed" and not isinstance(value, str):
+                raise ValueError("value must be a string")
+            elif typ == "variable":
+                if not isinstance(value, list):
+                    raise ValueError("value must be a list")
+                for val in value:
+                    if not isinstance(val, str):
+                        raise ValueError("value must be a list of strings")
+            elif typ == "constant" and not isinstance(value, str | int | float | bool):
+                raise ValueError("value must be a string, int, float, or bool")
+            return typ
+
+    agent_parameters: dict[str, AgentInput]

+ 1 - 0
api/core/workflow/nodes/enums.py

@@ -22,3 +22,4 @@ class NodeType(StrEnum):
     VARIABLE_ASSIGNER = "assigner"
     DOCUMENT_EXTRACTOR = "document-extractor"
     LIST_OPERATOR = "list-operator"
+    AGENT = "agent"

+ 5 - 0
api/core/workflow/nodes/node_mapping.py

@@ -1,5 +1,6 @@
 from collections.abc import Mapping
 
+from core.workflow.nodes.agent.agent_node import AgentNode
 from core.workflow.nodes.answer import AnswerNode
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.code import CodeNode
@@ -101,4 +102,8 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
         LATEST_VERSION: ListOperatorNode,
         "1": ListOperatorNode,
     },
+    NodeType.AGENT: {
+        LATEST_VERSION: AgentNode,
+        "1": AgentNode,
+    },
 }

+ 2 - 2
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -234,8 +234,8 @@ class ParameterExtractorNode(LLMNode):
         if not isinstance(invoke_result, LLMResult):
             raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
 
-        text = invoke_result.message.content
-        if not isinstance(text, str | None):
+        text = invoke_result.message.content or ""
+        if not isinstance(text, str):
             raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
 
         usage = invoke_result.usage

+ 15 - 0
api/factories/agent_factory.py

@@ -0,0 +1,15 @@
+from core.agent.strategy.plugin import PluginAgentStrategy
+from core.plugin.manager.agent import PluginAgentManager
+
+
+def get_plugin_agent_strategy(
+    tenant_id: str, plugin_unique_identifier: str, agent_strategy_provider_name: str, agent_strategy_name: str
+) -> PluginAgentStrategy:
+    # TODO: use contexts to cache the agent provider
+    manager = PluginAgentManager()
+    agent_provider = manager.fetch_agent_provider(tenant_id, agent_strategy_provider_name)
+    for agent_strategy in agent_provider.declaration.strategies:
+        if agent_strategy.identity.name == agent_strategy_name:
+            return PluginAgentStrategy(tenant_id, plugin_unique_identifier, agent_strategy)
+
+    raise ValueError(f"Agent strategy {agent_strategy_name} not found")