Browse Source

refactor(tool-node): introduce specific exceptions for tool node errors (#10357)

-LAN- 5 months ago
parent
commit
35d3da9697
2 changed files with 31 additions and 9 deletions
  1. 16 0
      api/core/workflow/nodes/tool/exc.py
  2. 15 9
      api/core/workflow/nodes/tool/tool_node.py

+ 16 - 0
api/core/workflow/nodes/tool/exc.py

@@ -0,0 +1,16 @@
+class ToolNodeError(ValueError):
+    """Base exception for tool node errors."""
+
+    pass
+
+
+class ToolParameterError(ToolNodeError):
+    """Exception raised for errors in tool parameters."""
+
+    pass
+
+
+class ToolFileError(ToolNodeError):
+    """Exception raised for errors related to tool files."""
+
+    pass

+ 15 - 9
api/core/workflow/nodes/tool/tool_node.py

@@ -6,7 +6,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
-from core.file.models import File, FileTransferMethod, FileType
+from core.file import File, FileTransferMethod, FileType
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
 from core.tools.tool_engine import ToolEngine
 from core.tools.tool_manager import ToolManager
@@ -15,12 +15,18 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.tool.entities import ToolNodeData
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from extensions.ext_database import db
 from models import ToolFile
 from models.workflow import WorkflowNodeExecutionStatus
 
+from .entities import ToolNodeData
+from .exc import (
+    ToolFileError,
+    ToolNodeError,
+    ToolParameterError,
+)
+
 
 class ToolNode(BaseNode[ToolNodeData]):
     """
@@ -42,7 +48,7 @@ class ToolNode(BaseNode[ToolNodeData]):
             tool_runtime = ToolManager.get_workflow_tool_runtime(
                 self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
             )
-        except Exception as e:
+        except ToolNodeError as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs={},
@@ -75,7 +81,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                 workflow_call_depth=self.workflow_call_depth,
                 thread_pool_id=self.thread_pool_id,
             )
-        except Exception as e:
+        except ToolNodeError as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=parameters_for_log,
@@ -133,13 +139,13 @@ class ToolNode(BaseNode[ToolNodeData]):
             if tool_input.type == "variable":
                 variable = variable_pool.get(tool_input.value)
                 if variable is None:
-                    raise ValueError(f"variable {tool_input.value} not exists")
+                    raise ToolParameterError(f"Variable {tool_input.value} does not exist")
                 parameter_value = variable.value
             elif tool_input.type in {"mixed", "constant"}:
                 segment_group = variable_pool.convert_template(str(tool_input.value))
                 parameter_value = segment_group.log if for_log else segment_group.text
             else:
-                raise ValueError(f"unknown tool input type '{tool_input.type}'")
+                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
             result[parameter_name] = parameter_value
 
         return result
@@ -181,7 +187,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                     tool_file = session.scalar(stmt)
                     if tool_file is None:
-                        raise ValueError(f"tool file {tool_file_id} not exists")
+                        raise ToolFileError(f"Tool file {tool_file_id} does not exist")
 
                 result.append(
                     File(
@@ -203,7 +209,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                     tool_file = session.scalar(stmt)
                     if tool_file is None:
-                        raise ValueError(f"tool file {tool_file_id} not exists")
+                        raise ToolFileError(f"Tool file {tool_file_id} does not exist")
                 result.append(
                     File(
                         tenant_id=self.tenant_id,
@@ -224,7 +230,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                     tool_file = session.scalar(stmt)
                     if tool_file is None:
-                        raise ValueError(f"tool file {tool_file_id} not exists")
+                        raise ToolFileError(f"Tool file {tool_file_id} does not exist")
                 if "." in url:
                     extension = "." + url.split("/")[-1].split(".")[1]
                 else: