Przeglądaj źródła

refactor(iteration): introduce specific exceptions for iteration errors (#10366)

-LAN- 5 miesięcy temu
rodzic
commit
f8c958a409

+ 22 - 0
api/core/workflow/nodes/iteration/exc.py

@@ -0,0 +1,22 @@
+class IterationNodeError(ValueError):
+    """Base class for iteration node errors."""
+
+
+class IteratorVariableNotFoundError(IterationNodeError):
+    """Raised when the iterator variable is not found."""
+
+
+class InvalidIteratorValueError(IterationNodeError):
+    """Raised when the iterator value is invalid."""
+
+
+class StartNodeIdNotFoundError(IterationNodeError):
+    """Raised when the start node ID is not found."""
+
+
+class IterationGraphNotFoundError(IterationNodeError):
+    """Raised when the iteration graph is not found."""
+
+
+class IterationIndexNotFoundError(IterationNodeError):
+    """Raised when the iteration index is not found."""

+ 19 - 10
api/core/workflow/nodes/iteration/iteration_node.py

@@ -38,6 +38,15 @@ from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
 from models.workflow import WorkflowNodeExecutionStatus
 
+from .exc import (
+    InvalidIteratorValueError,
+    IterationGraphNotFoundError,
+    IterationIndexNotFoundError,
+    IterationNodeError,
+    IteratorVariableNotFoundError,
+    StartNodeIdNotFoundError,
+)
+
 if TYPE_CHECKING:
     from core.workflow.graph_engine.graph_engine import GraphEngine
 logger = logging.getLogger(__name__)
@@ -69,7 +78,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
 
         if not iterator_list_segment:
-            raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
+            raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
 
         if len(iterator_list_segment.value) == 0:
             yield RunCompletedEvent(
@@ -83,14 +92,14 @@ class IterationNode(BaseNode[IterationNodeData]):
         iterator_list_value = iterator_list_segment.to_object()
 
         if not isinstance(iterator_list_value, list):
-            raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
+            raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
 
         inputs = {"iterator_selector": iterator_list_value}
 
         graph_config = self.graph_config
 
         if not self.node_data.start_node_id:
-            raise ValueError(f"field start_node_id in iteration {self.node_id} not found")
+            raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
 
         root_node_id = self.node_data.start_node_id
 
@@ -98,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
 
         if not iteration_graph:
-            raise ValueError("iteration graph not found")
+            raise IterationGraphNotFoundError("iteration graph not found")
 
         variable_pool = self.graph_runtime_state.variable_pool
 
@@ -222,9 +231,9 @@ class IterationNode(BaseNode[IterationNodeData]):
                     status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)}
                 )
             )
-        except Exception as e:
+        except IterationNodeError as e:
             # iteration run failed
-            logger.exception("Iteration run failed")
+            logger.warning("Iteration run failed")
             yield IterationRunFailedEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,
@@ -272,7 +281,7 @@ class IterationNode(BaseNode[IterationNodeData]):
         iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
 
         if not iteration_graph:
-            raise ValueError("iteration graph not found")
+            raise IterationGraphNotFoundError("iteration graph not found")
 
         for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
             if sub_node_config.get("data", {}).get("iteration_id") != node_id:
@@ -357,7 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]):
             next_index = int(current_index) + 1
 
             if current_index is None:
-                raise ValueError(f"iteration {self.node_id} current index not found")
+                raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
             for event in rst:
                 if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
                     event.in_iteration_id = self.node_id
@@ -484,8 +493,8 @@ class IterationNode(BaseNode[IterationNodeData]):
                 pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
             )
 
-        except Exception as e:
-            logger.exception(f"Iteration run failed:{str(e)}")
+        except IterationNodeError as e:
+            logger.warning(f"Iteration run failed:{str(e)}")
             yield IterationRunFailedEvent(
                 iteration_id=self.id,
                 iteration_node_id=self.node_id,