|
@@ -8,7 +8,6 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
|
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
|
from core.entities.agent_entities import PlanningStrategy
|
|
|
from core.entities.model_entities import ModelStatus
|
|
|
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
|
|
from core.model_manager import ModelInstance, ModelManager
|
|
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
@@ -18,11 +17,19 @@ from core.variables import StringSegment
|
|
|
from core.workflow.entities.node_entities import NodeRunResult
|
|
|
from core.workflow.nodes.base import BaseNode
|
|
|
from core.workflow.nodes.enums import NodeType
|
|
|
-from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
|
|
from extensions.ext_database import db
|
|
|
from models.dataset import Dataset, Document, DocumentSegment
|
|
|
from models.workflow import WorkflowNodeExecutionStatus
|
|
|
|
|
|
+from .entities import KnowledgeRetrievalNodeData
|
|
|
+from .exc import (
|
|
|
+ KnowledgeRetrievalNodeError,
|
|
|
+ ModelCredentialsNotInitializedError,
|
|
|
+ ModelNotExistError,
|
|
|
+ ModelNotSupportedError,
|
|
|
+ ModelQuotaExceededError,
|
|
|
+)
|
|
|
+
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
default_retrieval_model = {
|
|
@@ -61,8 +68,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
|
|
|
)
|
|
|
|
|
|
- except Exception as e:
|
|
|
- logger.exception("Error when running knowledge retrieval node")
|
|
|
+ except KnowledgeRetrievalNodeError as e:
|
|
|
+ logger.warning("Error when running knowledge retrieval node")
|
|
|
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
|
|
|
|
|
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
|
|
@@ -295,14 +302,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
|
|
)
|
|
|
|
|
|
if provider_model is None:
|
|
|
- raise ValueError(f"Model {model_name} not exist.")
|
|
|
+ raise ModelNotExistError(f"Model {model_name} not exist.")
|
|
|
|
|
|
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
|
|
- raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
|
|
+ raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
|
|
|
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
|
|
- raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
|
|
+ raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
|
|
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
|
|
- raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
|
|
+ raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
|
|
|
|
|
# model config
|
|
|
completion_params = node_data.single_retrieval_config.model.completion_params
|
|
@@ -314,12 +321,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
|
|
|
# get model mode
|
|
|
model_mode = node_data.single_retrieval_config.model.mode
|
|
|
if not model_mode:
|
|
|
- raise ValueError("LLM mode is required.")
|
|
|
+ raise ModelNotExistError("LLM mode is required.")
|
|
|
|
|
|
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
|
|
|
|
|
if not model_schema:
|
|
|
- raise ValueError(f"Model {model_name} not exist.")
|
|
|
+ raise ModelNotExistError(f"Model {model_name} not exist.")
|
|
|
|
|
|
return model_instance, ModelConfigWithCredentialsEntity(
|
|
|
provider=provider_name,
|