Parcourir la source

fix: tool node

Yeuoly il y a 10 mois
Parent
commit
531ffaec4f
1 fichiers modifiés avec 32 ajouts et 13 suppressions
  1. 32 13
      api/core/workflow/nodes/tool/tool_node.py

+ 32 - 13
api/core/workflow/nodes/tool/tool_node.py

@@ -128,8 +128,10 @@ class ToolNode(BaseNode):
             else:
                 tool_input = node_data.tool_parameters[parameter_name]
                 if tool_input.type == 'variable':
-                    # TODO: check if the variable exists in the variable pool
-                    parameter_value = variable_pool.get(tool_input.value).value
+                    parameter_value_segment = variable_pool.get(tool_input.value)
+                    if not parameter_value_segment:
+                        raise Exception("input variable dose not exists")
+                    parameter_value = parameter_value_segment.value
                 else:
                     segment_group = parser.convert_template(
                         template=str(tool_input.value),
@@ -163,7 +165,7 @@ class ToolNode(BaseNode):
 
         return plain_text, files, json
 
-    def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
+    def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
         """
         Extract tool response binary
         """
@@ -172,7 +174,10 @@ class ToolNode(BaseNode):
         for response in tool_response:
             if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                     response.type == ToolInvokeMessage.MessageType.IMAGE:
-                url = response.message
+                assert isinstance(response.message, ToolInvokeMessage.TextMessage)
+                assert response.meta
+
+                url = response.message.text
                 ext = path.splitext(url)[1]
                 mimetype = response.meta.get('mime_type', 'image/jpeg')
                 filename = response.save_as or url.split('/')[-1]
@@ -192,7 +197,10 @@ class ToolNode(BaseNode):
                 ))
             elif response.type == ToolInvokeMessage.MessageType.BLOB:
                 # get tool file id
-                tool_file_id = response.message.split('/')[-1].split('.')[0]
+                assert isinstance(response.message, ToolInvokeMessage.TextMessage)
+                assert response.meta
+
+                tool_file_id = response.message.text.split('/')[-1].split('.')[0]
                 result.append(FileVar(
                     tenant_id=self.tenant_id,
                     type=FileType.IMAGE,
@@ -207,18 +215,28 @@ class ToolNode(BaseNode):
 
         return result
 
-    def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
+    def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
         """
         Extract tool response text
         """
-        return '\n'.join([
-            f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else
-            f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else ''
-            for message in tool_response
-        ])
+        result: list[str] = []
+        for message in tool_response:
+            if message.type == ToolInvokeMessage.MessageType.TEXT:
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                result.append(message.message.text)
+            elif message.type == ToolInvokeMessage.MessageType.LINK:
+                assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+                result.append(f'Link: {message.message.text}')
 
-    def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
-        return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
+        return '\n'.join(result)
+
+    def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
+        result: list[dict] = []
+        for message in tool_response:
+            if message.type == ToolInvokeMessage.MessageType.JSON:
+                assert isinstance(message, ToolInvokeMessage.JsonMessage)
+                result.append(message.json_object)
+        return result
 
     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
@@ -231,6 +249,7 @@ class ToolNode(BaseNode):
         for parameter_name in node_data.tool_parameters:
             input = node_data.tool_parameters[parameter_name]
             if input.type == 'mixed':
+                assert isinstance(input.value, str)
                 selectors = VariableTemplateParser(input.value).extract_variable_selectors()
                 for selector in selectors:
                     result[selector.variable] = selector.value_selector