Переглянути джерело

refactor: rename agent to agent strategy

Yeuoly 4 місяців тому
батько
коміт
3c628d0c26

+ 9 - 9
api/core/agent/plugin_entities.py

@@ -6,26 +6,26 @@ from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderIdentity
 
 
-class AgentProviderIdentity(ToolProviderIdentity):
+class AgentStrategyProviderIdentity(ToolProviderIdentity):
     pass
 
 
-class AgentParameter(ToolParameter):
+class AgentStrategyParameter(ToolParameter):
     pass
 
 
-class AgentProviderEntity(BaseModel):
-    identity: AgentProviderIdentity
+class AgentStrategyProviderEntity(BaseModel):
+    identity: AgentStrategyProviderIdentity
     plugin_id: Optional[str] = Field(None, description="The id of the plugin")
 
 
-class AgentIdentity(ToolIdentity):
+class AgentStrategyIdentity(ToolIdentity):
     pass
 
 
 class AgentStrategyEntity(BaseModel):
-    identity: AgentIdentity
-    parameters: list[AgentParameter] = Field(default_factory=list)
+    identity: AgentStrategyIdentity
+    parameters: list[AgentStrategyParameter] = Field(default_factory=list)
     description: I18nObject = Field(..., description="The description of the agent strategy")
     output_schema: Optional[dict] = None
 
@@ -34,9 +34,9 @@ class AgentStrategyEntity(BaseModel):
 
     @field_validator("parameters", mode="before")
     @classmethod
-    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentParameter]:
+    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
         return v or []
 
 
-class AgentProviderEntityWithPlugin(AgentProviderEntity):
+class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
     strategies: list[AgentStrategyEntity] = Field(default_factory=list)

+ 2 - 2
api/core/agent/strategy/base.py

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Generator, Optional, Sequence
 
 from core.agent.entities import AgentInvokeMessage
-from core.agent.plugin_entities import AgentParameter
+from core.agent.plugin_entities import AgentStrategyParameter
 
 
 class BaseAgentStrategy(ABC):
@@ -23,7 +23,7 @@ class BaseAgentStrategy(ABC):
         """
         yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
 
-    def get_parameters(self) -> Sequence[AgentParameter]:
+    def get_parameters(self) -> Sequence[AgentStrategyParameter]:
         """
         Get the parameters for the agent strategy.
         """

+ 2 - 2
api/core/agent/strategy/plugin.py

@@ -1,7 +1,7 @@
 from typing import Any, Generator, Optional, Sequence
 
 from core.agent.entities import AgentInvokeMessage
-from core.agent.plugin_entities import AgentParameter, AgentStrategyEntity
+from core.agent.plugin_entities import AgentStrategyParameter, AgentStrategyEntity
 from core.agent.strategy.base import BaseAgentStrategy
 from core.plugin.manager.agent import PluginAgentManager
 from core.tools.plugin_tool.tool import PluginTool
@@ -21,7 +21,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
         self.plugin_unique_identifier = plugin_unique_identifier
         self.declaration = declaration
 
-    def get_parameters(self) -> Sequence[AgentParameter]:
+    def get_parameters(self) -> Sequence[AgentStrategyParameter]:
         return self.declaration.parameters
 
     def _invoke(

+ 1 - 1
api/core/plugin/backwards_invocation/model.py

@@ -43,7 +43,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
         # invoke model
         response = model_instance.invoke_llm(
             prompt_messages=payload.prompt_messages,
-            model_parameters=payload.model_parameters,
+            model_parameters=payload.completion_params,
             tools=payload.tools,
             stop=payload.stop,
             stream=payload.stream or True,

+ 5 - 0
api/core/plugin/entities/plugin.py

@@ -6,6 +6,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel, Field, model_validator
 
+from core.agent.plugin_entities import AgentStrategyProviderEntity
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.plugin.entities.base import BasePluginEntity
 from core.plugin.entities.endpoint import EndpointProviderDeclaration
@@ -59,6 +60,7 @@ class PluginCategory(enum.StrEnum):
     Tool = "tool"
     Model = "model"
     Extension = "extension"
+    AgentStrategy = "agent_strategy"
 
 
 class PluginDeclaration(BaseModel):
@@ -82,6 +84,7 @@ class PluginDeclaration(BaseModel):
     tool: Optional[ToolProviderEntity] = None
     model: Optional[ProviderEntity] = None
     endpoint: Optional[EndpointProviderDeclaration] = None
+    agent_strategy: Optional[AgentStrategyProviderEntity] = None
 
     @model_validator(mode="before")
     @classmethod
@@ -91,6 +94,8 @@ class PluginDeclaration(BaseModel):
             values["category"] = PluginCategory.Tool
         elif values.get("model"):
             values["category"] = PluginCategory.Model
+        elif values.get("agent_strategy"):
+            values["category"] = PluginCategory.AgentStrategy
         else:
             values["category"] = PluginCategory.Extension
         return values

+ 1 - 1
api/core/plugin/entities/request.py

@@ -53,7 +53,7 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
 
     model_type: ModelType = ModelType.LLM
     mode: str
-    model_parameters: dict[str, Any] = Field(default_factory=dict)
+    completion_params: dict[str, Any] = Field(default_factory=dict)
     prompt_messages: list[PromptMessage] = Field(default_factory=list)
     tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
     stop: Optional[list[str]] = Field(default_factory=list)

+ 7 - 7
api/core/plugin/manager/agent.py

@@ -10,7 +10,7 @@ from core.plugin.manager.base import BasePluginManager
 
 
 class PluginAgentManager(BasePluginManager):
-    def fetch_agent_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
+    def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
         """
         Fetch agent providers for the given tenant.
         """
@@ -26,7 +26,7 @@ class PluginAgentManager(BasePluginManager):
 
         response = self._request_with_plugin_daemon_response(
             "GET",
-            f"plugin/{tenant_id}/management/agents",
+            f"plugin/{tenant_id}/management/agent_strategies",
             list[PluginAgentProviderEntity],
             params={"page": 1, "page_size": 256},
             transformer=transformer,
@@ -41,7 +41,7 @@ class PluginAgentManager(BasePluginManager):
 
         return response
 
-    def fetch_agent_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
+    def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
         """
         Fetch tool provider for the given tenant and plugin.
         """
@@ -55,7 +55,7 @@ class PluginAgentManager(BasePluginManager):
 
         response = self._request_with_plugin_daemon_response(
             "GET",
-            f"plugin/{tenant_id}/management/agent",
+            f"plugin/{tenant_id}/management/agent_strategy",
             PluginAgentProviderEntity,
             params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
             transformer=transformer,
@@ -96,9 +96,9 @@ class PluginAgentManager(BasePluginManager):
                 "app_id": app_id,
                 "message_id": message_id,
                 "data": {
-                    "provider": agent_provider_id.provider_name,
-                    "strategy": agent_strategy,
-                    "agent_params": agent_params,
+                    "agent_strategy_provider": agent_provider_id.provider_name,
+                    "agent_strategy": agent_strategy,
+                    "agent_strategy_params": agent_params,
                 },
             },
             headers={

+ 2 - 2
api/core/workflow/nodes/agent/agent_node.py

@@ -1,7 +1,7 @@
 from collections.abc import Generator
 from typing import Any, Sequence, cast
 
-from core.agent.plugin_entities import AgentParameter
+from core.agent.plugin_entities import AgentStrategyParameter
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
@@ -90,7 +90,7 @@ class AgentNode(ToolNode):
     def _generate_parameters(
         self,
         *,
-        agent_parameters: Sequence[AgentParameter],
+        agent_parameters: Sequence[AgentStrategyParameter],
         variable_pool: VariablePool,
         node_data: AgentNodeData,
         for_log: bool = False,

+ 1 - 1
api/core/workflow/nodes/tool/tool_node.py

@@ -246,7 +246,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                 )
             elif message.type == ToolInvokeMessage.MessageType.TEXT:
                 assert isinstance(message.message, ToolInvokeMessage.TextMessage)
-                text += message.message.text + "\n"
+                text += message.message.text
                 yield RunStreamChunkEvent(
                     chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
                 )

+ 1 - 1
api/factories/agent_factory.py

@@ -7,7 +7,7 @@ def get_plugin_agent_strategy(
 ) -> PluginAgentStrategy:
     # TODO: use contexts to cache the agent provider
     manager = PluginAgentManager()
-    agent_provider = manager.fetch_agent_provider(tenant_id, agent_strategy_provider_name)
+    agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name)
     for agent_strategy in agent_provider.declaration.strategies:
         if agent_strategy.identity.name == agent_strategy_name:
             return PluginAgentStrategy(tenant_id, plugin_unique_identifier, agent_strategy)

+ 2 - 2
api/services/agent_service.py

@@ -160,7 +160,7 @@ class AgentService:
         List agent providers
         """
         manager = PluginAgentManager()
-        return manager.fetch_agent_providers(tenant_id)
+        return manager.fetch_agent_strategy_providers(tenant_id)
 
     @classmethod
     def get_agent_provider(cls, user_id: str, tenant_id: str, provider_name: str):
@@ -168,4 +168,4 @@ class AgentService:
         Get agent provider
         """
         manager = PluginAgentManager()
-        return manager.fetch_agent_provider(tenant_id, provider_name)
+        return manager.fetch_agent_strategy_provider(tenant_id, provider_name)