Browse Source

fix: agent node

Yeuoly 6 months ago
parent
commit
296fd82bbf
1 changed files with 50 additions and 5 deletions
  1. 50 5
      api/core/workflow/nodes/agent/agent_node.py

+ 50 - 5
api/core/workflow/nodes/agent/agent_node.py

@@ -1,8 +1,10 @@
 from collections.abc import Generator
-from typing import cast
+from typing import Any, Sequence, cast
 
+from core.agent.plugin_entities import AgentParameter
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.agent.entities import AgentNodeData
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
@@ -46,14 +48,14 @@ class AgentNode(ToolNode):
 
         # get parameters
         parameters = self._generate_parameters(
-            tool_parameters=agent_parameters,
+            agent_parameters=agent_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
-            node_data=self.node_data,
+            node_data=node_data,
         )
         parameters_for_log = self._generate_parameters(
-            tool_parameters=agent_parameters,
+            agent_parameters=agent_parameters,
             variable_pool=self.graph_runtime_state.variable_pool,
-            node_data=self.node_data,
+            node_data=node_data,
             for_log=True,
         )
 
@@ -84,3 +86,46 @@ class AgentNode(ToolNode):
                     error=f"Failed to transform agent message: {str(e)}",
                 )
             )
+
+    def _generate_parameters(
+        self,
+        *,
+        agent_parameters: Sequence[AgentParameter],
+        variable_pool: VariablePool,
+        node_data: AgentNodeData,
+        for_log: bool = False,
+    ) -> dict[str, Any]:
+        """
+        Generate parameters based on the given tool parameters, variable pool, and node data.
+
+        Args:
+            tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
+            variable_pool (VariablePool): The variable pool containing the variables.
+            node_data (ToolNodeData): The data associated with the tool node.
+
+        Returns:
+            Mapping[str, Any]: A dictionary containing the generated parameters.
+
+        """
+        agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
+
+        result = {}
+        for parameter_name in node_data.agent_parameters:
+            parameter = agent_parameters_dictionary.get(parameter_name)
+            if not parameter:
+                result[parameter_name] = None
+                continue
+            agent_input = node_data.agent_parameters[parameter_name]
+            if agent_input.type == "variable":
+                variable = variable_pool.get(agent_input.value)
+                if variable is None:
+                    raise ValueError(f"Variable {agent_input.value} does not exist")
+                parameter_value = variable.value
+            elif agent_input.type in {"mixed", "constant"}:
+                segment_group = variable_pool.convert_template(str(agent_input.value))
+                parameter_value = segment_group.log if for_log else segment_group.text
+            else:
+                raise ValueError(f"Unknown agent input type '{agent_input.type}'")
+            result[parameter_name] = parameter_value
+
+        return result