Browse Source

refactor(parameter_extractor): implement custom error classes (#10260)

-LAN- 5 months ago
parent
commit
7a98dab6a4

+ 50 - 0
api/core/workflow/nodes/parameter_extractor/exc.py

@@ -0,0 +1,50 @@
+class ParameterExtractorNodeError(ValueError):
+    """Base error for ParameterExtractorNode."""
+
+
+class InvalidModelTypeError(ParameterExtractorNodeError):
+    """Raised when the model is not a Large Language Model."""
+
+
+class ModelSchemaNotFoundError(ParameterExtractorNodeError):
+    """Raised when the model schema is not found."""
+
+
+class InvalidInvokeResultError(ParameterExtractorNodeError):
+    """Raised when the invoke result is invalid."""
+
+
+class InvalidTextContentTypeError(ParameterExtractorNodeError):
+    """Raised when the text content type is invalid."""
+
+
+class InvalidNumberOfParametersError(ParameterExtractorNodeError):
+    """Raised when the number of parameters is invalid."""
+
+
+class RequiredParameterMissingError(ParameterExtractorNodeError):
+    """Raised when a required parameter is missing."""
+
+
+class InvalidSelectValueError(ParameterExtractorNodeError):
+    """Raised when a select value is invalid."""
+
+
+class InvalidNumberValueError(ParameterExtractorNodeError):
+    """Raised when a number value is invalid."""
+
+
+class InvalidBoolValueError(ParameterExtractorNodeError):
+    """Raised when a bool value is invalid."""
+
+
+class InvalidStringValueError(ParameterExtractorNodeError):
+    """Raised when a string value is invalid."""
+
+
+class InvalidArrayValueError(ParameterExtractorNodeError):
+    """Raised when an array value is invalid."""
+
+
+class InvalidModelModeError(ParameterExtractorNodeError):
+    """Raised when the model mode is invalid."""

+ 36 - 21
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -32,6 +32,21 @@ from extensions.ext_database import db
 from models.workflow import WorkflowNodeExecutionStatus
 
 from .entities import ParameterExtractorNodeData
+from .exc import (
+    InvalidArrayValueError,
+    InvalidBoolValueError,
+    InvalidInvokeResultError,
+    InvalidModelModeError,
+    InvalidModelTypeError,
+    InvalidNumberOfParametersError,
+    InvalidNumberValueError,
+    InvalidSelectValueError,
+    InvalidStringValueError,
+    InvalidTextContentTypeError,
+    ModelSchemaNotFoundError,
+    ParameterExtractorNodeError,
+    RequiredParameterMissingError,
+)
 from .prompts import (
     CHAT_EXAMPLE,
     CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
@@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode):
 
         model_instance, model_config = self._fetch_model_config(node_data.model)
         if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
-            raise ValueError("Model is not a Large Language Model")
+            raise InvalidModelTypeError("Model is not a Large Language Model")
 
         llm_model = model_instance.model_type_instance
         model_schema = llm_model.get_model_schema(
@@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode):
             credentials=model_config.credentials,
         )
         if not model_schema:
-            raise ValueError("Model schema not found")
+            raise ModelSchemaNotFoundError("Model schema not found")
 
         # fetch memory
         memory = self._fetch_memory(
@@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode):
             process_data["usage"] = jsonable_encoder(usage)
             process_data["tool_call"] = jsonable_encoder(tool_call)
             process_data["llm_text"] = text
-        except Exception as e:
+        except ParameterExtractorNodeError as e:
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 inputs=inputs,
@@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode):
 
         try:
             result = self._validate_result(data=node_data, result=result or {})
-        except Exception as e:
+        except ParameterExtractorNodeError as e:
             error = str(e)
 
         # transform result into standard format
@@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode):
 
         # handle invoke result
         if not isinstance(invoke_result, LLMResult):
-            raise ValueError(f"Invalid invoke result: {invoke_result}")
+            raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
 
         text = invoke_result.message.content
         if not isinstance(text, str):
-            raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
+            raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
 
         usage = invoke_result.usage
         tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
@@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode):
                 files=files,
             )
         else:
-            raise ValueError(f"Invalid model mode: {model_mode}")
+            raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
 
     def _generate_prompt_engineering_completion_prompt(
         self,
@@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode):
         Validate result.
         """
         if len(data.parameters) != len(result):
-            raise ValueError("Invalid number of parameters")
+            raise InvalidNumberOfParametersError("Invalid number of parameters")
 
         for parameter in data.parameters:
             if parameter.required and parameter.name not in result:
-                raise ValueError(f"Parameter {parameter.name} is required")
+                raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
 
             if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
-                raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
+                raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
 
             if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
-                raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
+                raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
 
             if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
-                raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
+                raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
 
             if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
-                raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
+                raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
 
             if parameter.type.startswith("array"):
                 parameters = result.get(parameter.name)
                 if not isinstance(parameters, list):
-                    raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
+                    raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
                 nested_type = parameter.type[6:-1]
                 for item in parameters:
                     if nested_type == "number" and not isinstance(item, int | float):
-                        raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
+                        raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
                     if nested_type == "string" and not isinstance(item, str):
-                        raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
+                        raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
                     if nested_type == "object" and not isinstance(item, dict):
-                        raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
+                        raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
         return result
 
     def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
@@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode):
             user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
             return [system_prompt_messages, user_prompt_message]
         else:
-            raise ValueError(f"Model mode {model_mode} not support.")
+            raise InvalidModelModeError(f"Model mode {model_mode} not support.")
 
     def _get_prompt_engineering_prompt_template(
         self,
@@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode):
                 .replace("}γγγ", "")
             )
         else:
-            raise ValueError(f"Model mode {model_mode} not support.")
+            raise InvalidModelModeError(f"Model mode {model_mode} not support.")
 
     def _calculate_rest_token(
         self,
@@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode):
 
         model_instance, model_config = self._fetch_model_config(node_data.model)
         if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
-            raise ValueError("Model is not a Large Language Model")
+            raise InvalidModelTypeError("Model is not a Large Language Model")
 
         llm_model = model_instance.model_type_instance
         model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
         if not model_schema:
-            raise ValueError("Model schema not found")
+            raise ModelSchemaNotFoundError("Model schema not found")
 
         if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
             prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)