refactor(workflow): introduce specific error handling for LLM nodes (#10221)

-LAN- committed 5 months ago (commit 38bca6731c)
2 changed files with 47 additions and 12 deletions
  1. api/core/workflow/nodes/llm/exc.py (+26 -0)
  2. api/core/workflow/nodes/llm/node.py (+21 -12)

+ 26 - 0
api/core/workflow/nodes/llm/exc.py

@@ -0,0 +1,26 @@
+class LLMNodeError(ValueError):
+    """Base class for LLM Node errors."""
+
+
+class VariableNotFoundError(LLMNodeError):
+    """Raised when a required variable is not found."""
+
+
+class InvalidContextStructureError(LLMNodeError):
+    """Raised when the context structure is invalid."""
+
+
+class InvalidVariableTypeError(LLMNodeError):
+    """Raised when the variable type is invalid."""
+
+
+class ModelNotExistError(LLMNodeError):
+    """Raised when the specified model does not exist."""
+
+
+class LLMModeRequiredError(LLMNodeError):
+    """Raised when LLM mode is required but not provided."""
+
+
+class NoPromptFoundError(LLMNodeError):
+    """Raised when no prompt is found in the LLM configuration."""

+ 21 - 12
api/core/workflow/nodes/llm/node.py

@@ -56,6 +56,15 @@ from .entities import (
     LLMNodeData,
     ModelConfig,
 )
+from .exc import (
+    InvalidContextStructureError,
+    InvalidVariableTypeError,
+    LLMModeRequiredError,
+    LLMNodeError,
+    ModelNotExistError,
+    NoPromptFoundError,
+    VariableNotFoundError,
+)
 
 if TYPE_CHECKING:
     from core.file.models import File
@@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             if self.node_data.memory:
                 query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
                 if not query:
-                    raise ValueError("Query not found")
+                    raise VariableNotFoundError("Query not found")
                 query = query.text
             else:
                 query = None
@@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                     usage = event.usage
                     finish_reason = event.finish_reason
                     break
-        except Exception as e:
+        except LLMNodeError as e:
             yield RunCompletedEvent(
                 run_result=NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
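
Note that narrowing "except Exception" to "except LLMNodeError" changes failure routing: exceptions outside the new hierarchy (for example ProviderTokenNotInitError, which this diff leaves untouched) now propagate out of the generator instead of being wrapped in a FAILED NodeRunResult. A simplified sketch of the two paths, using stand-in classes and a reduced run function rather than the real node API:

    class LLMNodeError(ValueError):
        """Stand-in for the base class added in exc.py."""

    class ProviderTokenNotInitError(Exception):
        """Simplified stand-in for the provider error used later in node.py."""

    def run(step) -> str:
        # Mirrors the narrowed handler: only LLMNodeError becomes a FAILED result.
        try:
            step()
            return "SUCCEEDED"
        except LLMNodeError as e:
            return f"FAILED: {e}"

    def missing_variable():
        raise LLMNodeError("Query not found")

    def missing_credentials():
        raise ProviderTokenNotInitError("credentials not initialized")

    print(run(missing_variable))        # FAILED: Query not found
    try:
        run(missing_credentials)
    except ProviderTokenNotInitError as e:
        print(f"propagated: {e}")       # no longer converted into a FAILED result
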
@@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             variable_name = variable_selector.variable
             variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
             if variable is None:
-                raise ValueError(f"Variable {variable_selector.variable} not found")
+                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
 
             def parse_dict(input_dict: Mapping[str, Any]) -> str:
                 """
@@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         for variable_selector in variable_selectors:
             variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
             if variable is None:
-                raise ValueError(f"Variable {variable_selector.variable} not found")
+                raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
             if isinstance(variable, NoneSegment):
                 inputs[variable_selector.variable] = ""
             inputs[variable_selector.variable] = variable.to_object()
@@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             for variable_selector in query_variable_selectors:
                 variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
                 if variable is None:
-                    raise ValueError(f"Variable {variable_selector.variable} not found")
+                    raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
                 if isinstance(variable, NoneSegment):
                     continue
                 inputs[variable_selector.variable] = variable.to_object()
@@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             return variable.value
         elif isinstance(variable, NoneSegment | ArrayAnySegment):
             return []
-        raise ValueError(f"Invalid variable type: {type(variable)}")
+        raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
 
     def _fetch_context(self, node_data: LLMNodeData):
         if not node_data.context.enabled:
@@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                         context_str += item + "\n"
                     else:
                         if "content" not in item:
-                            raise ValueError(f"Invalid context structure: {item}")
+                            raise InvalidContextStructureError(f"Invalid context structure: {item}")
 
                         context_str += item["content"] + "\n"
 
@@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         )
 
         if provider_model is None:
-            raise ValueError(f"Model {model_name} not exist.")
+            raise ModelNotExistError(f"Model {model_name} does not exist.")
 
         if provider_model.status == ModelStatus.NO_CONFIGURE:
             raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
@@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]):
         # get model mode
         model_mode = node_data_model.mode
         if not model_mode:
-            raise ValueError("LLM mode is required.")
+            raise LLMModeRequiredError("LLM mode is required.")
 
         model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
 
         if not model_schema:
-            raise ValueError(f"Model {model_name} not exist.")
+            raise ModelNotExistError(f"Model {model_name} does not exist.")
 
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,
@@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]):
             filtered_prompt_messages.append(prompt_message)
 
         if not filtered_prompt_messages:
-            raise ValueError(
+            raise NoPromptFoundError(
                 "No prompt found in the LLM configuration. "
                 "Please ensure a prompt is properly configured before proceeding."
             )
@@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 variable_template_parser = VariableTemplateParser(template=prompt_template.text)
                 variable_selectors = variable_template_parser.extract_variable_selectors()
         else:
-            raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
+            raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
 
         variable_mapping = {}
         for variable_selector in variable_selectors:
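
The specific classes also make these failure modes assertable in tests without matching on message strings. A sketch of how a unit test might exercise the raise-on-missing-variable pattern above (pytest assumed; fetch_variable is a hypothetical reduction of the variable-pool lookups, not a function in node.py):

    import pytest

    # Redeclared from exc.py so the test file is self-contained.
    class LLMNodeError(ValueError):
        """Base class for LLM Node errors."""

    class VariableNotFoundError(LLMNodeError):
        """Raised when a required variable is not found."""

    def fetch_variable(pool: dict, name: str):
        # Hypothetical reduction of the "if variable is None: raise" pattern.
        value = pool.get(name)
        if value is None:
            raise VariableNotFoundError(f"Variable {name} not found")
        return value

    def test_missing_variable_raises_specific_error():
        with pytest.raises(VariableNotFoundError):
            fetch_variable({}, "query")

    def test_specific_error_is_still_a_value_error():
        # Guards the backward-compatibility contract of the new hierarchy.
        with pytest.raises(ValueError):
            fetch_variable({}, "query")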