Browse Source

refactor(knowledge-retrieval): improve error handling with custom exceptions (#10385)

-LAN- 5 months ago
parent
commit
25785d8c3f

+ 18 - 0
api/core/workflow/nodes/knowledge_retrieval/exc.py

@@ -0,0 +1,18 @@
+class KnowledgeRetrievalNodeError(ValueError):
+    """Base class for KnowledgeRetrievalNode errors."""
+
+
+class ModelNotExistError(KnowledgeRetrievalNodeError):
+    """Raised when the model does not exist."""
+
+
+class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError):
+    """Raised when the model credentials are not initialized."""
+
+
+class ModelNotSupportedError(KnowledgeRetrievalNodeError):
+    """Raised when the model is not supported."""
+
+
+class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
+    """Raised when the model provider quota is exceeded."""

+ 17 - 10
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -8,7 +8,6 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.agent_entities import PlanningStrategy
 from core.entities.model_entities import ModelStatus
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -18,11 +17,19 @@ from core.variables import StringSegment
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment
 from models.workflow import WorkflowNodeExecutionStatus
 
+from .entities import KnowledgeRetrievalNodeData
+from .exc import (
+    KnowledgeRetrievalNodeError,
+    ModelCredentialsNotInitializedError,
+    ModelNotExistError,
+    ModelNotSupportedError,
+    ModelQuotaExceededError,
+)
+
 logger = logging.getLogger(__name__)
 
 default_retrieval_model = {
@@ -61,8 +68,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
                 status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
             )
 
-        except Exception as e:
-            logger.exception("Error when running knowledge retrieval node")
+        except KnowledgeRetrievalNodeError as e:
+            logger.warning("Error when running knowledge retrieval node")
             return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
 
     def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
@@ -295,14 +302,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
         )
 
         if provider_model is None:
-            raise ValueError(f"Model {model_name} not exist.")
+            raise ModelNotExistError(f"Model {model_name} not exist.")
 
         if provider_model.status == ModelStatus.NO_CONFIGURE:
-            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+            raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
         elif provider_model.status == ModelStatus.NO_PERMISSION:
-            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+            raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
         elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+            raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
         completion_params = node_data.single_retrieval_config.model.completion_params
@@ -314,12 +321,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
         # get model mode
         model_mode = node_data.single_retrieval_config.model.mode
         if not model_mode:
-            raise ValueError("LLM mode is required.")
+            raise ModelNotExistError("LLM mode is required.")
 
         model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
 
         if not model_schema:
-            raise ValueError(f"Model {model_name} not exist.")
+            raise ModelNotExistError(f"Model {model_name} not exist.")
 
         return model_instance, ModelConfigWithCredentialsEntity(
             provider=provider_name,