Bladeren bron

refactor: text-embedding interfaces to returns list[int]

Yeuoly 6 maanden geleden
bovenliggende
commit
cfa7c89dfe

+ 1 - 1
api/core/app/entities/app_invoke_entities.py

@@ -183,7 +183,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
         """
 
         node_id: str
-        inputs: dict
+        inputs: Mapping
 
     single_iteration_run: Optional[SingleIterationRunEntity] = None
 

+ 1 - 1
api/core/model_manager.py

@@ -219,7 +219,7 @@ class ModelInstance:
             input_type=input_type,
         )
 
-    def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
+    def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
         """
         Get number of tokens for text embedding
 

+ 1 - 1
api/core/model_runtime/model_providers/__base/text_embedding_model.py

@@ -52,7 +52,7 @@ class TextEmbeddingModel(AIModel):
         except Exception as e:
             raise self._transform_invoke_error(e)
 
-    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]:
         """
         Get number of tokens for given prompt messages
 

+ 1 - 1
api/core/plugin/entities/plugin_daemon.py

@@ -76,7 +76,7 @@ class PluginNumTokensResponse(BaseModel):
     Response for number of tokens.
     """
 
-    num_tokens: int = Field(description="The number of tokens.")
+    num_tokens: list[int] = Field(description="The number of tokens.")
 
 
 class PluginStringResultResponse(BaseModel):

+ 37 - 14
api/core/plugin/manager/base.py

@@ -17,6 +17,14 @@ from core.model_runtime.errors.invoke import (
     InvokeServerUnavailableError,
 )
 from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
+from core.plugin.manager.exc import (
+    PluginDaemonBadRequestError,
+    PluginDaemonInternalServerError,
+    PluginDaemonNotFoundError,
+    PluginDaemonUnauthorizedError,
+    PluginPermissionDeniedError,
+    PluginUniqueIdentifierError,
+)
 
 plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_API_URL
 plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY
@@ -190,17 +198,32 @@ class BasePluginManager:
         """
         args = args or {}
 
-        if error_type == PluginDaemonInnerError.__name__:
-            raise PluginDaemonInnerError(code=-500, message=message)
-        elif error_type == InvokeRateLimitError.__name__:
-            raise InvokeRateLimitError(description=args.get("description"))
-        elif error_type == InvokeAuthorizationError.__name__:
-            raise InvokeAuthorizationError(description=args.get("description"))
-        elif error_type == InvokeBadRequestError.__name__:
-            raise InvokeBadRequestError(description=args.get("description"))
-        elif error_type == InvokeConnectionError.__name__:
-            raise InvokeConnectionError(description=args.get("description"))
-        elif error_type == InvokeServerUnavailableError.__name__:
-            raise InvokeServerUnavailableError(description=args.get("description"))
-        else:
-            raise ValueError(f"got unknown error from plugin daemon: {error_type}, message: {message}, args: {args}")
+        match error_type:
+            case PluginDaemonInnerError.__name__:
+                raise PluginDaemonInnerError(code=-500, message=message)
+            case InvokeRateLimitError.__name__:
+                raise InvokeRateLimitError(description=args.get("description"))
+            case InvokeAuthorizationError.__name__:
+                raise InvokeAuthorizationError(description=args.get("description"))
+            case InvokeBadRequestError.__name__:
+                raise InvokeBadRequestError(description=args.get("description"))
+            case InvokeConnectionError.__name__:
+                raise InvokeConnectionError(description=args.get("description"))
+            case InvokeServerUnavailableError.__name__:
+                raise InvokeServerUnavailableError(description=args.get("description"))
+            case PluginDaemonInternalServerError.__name__:
+                raise PluginDaemonInternalServerError(description=message)
+            case PluginDaemonBadRequestError.__name__:
+                raise PluginDaemonBadRequestError(description=message)
+            case PluginDaemonNotFoundError.__name__:
+                raise PluginDaemonNotFoundError(description=message)
+            case PluginUniqueIdentifierError.__name__:
+                raise PluginUniqueIdentifierError(description=message)
+            case PluginDaemonUnauthorizedError.__name__:
+                raise PluginDaemonUnauthorizedError(description=message)
+            case PluginPermissionDeniedError.__name__:
+                raise PluginPermissionDeniedError(description=message)
+            case _:
+                raise ValueError(
+                    f"got unknown error from plugin daemon: {error_type}, message: {message}, args: {args}"
+                )

+ 33 - 0
api/core/plugin/manager/exc.py

@@ -0,0 +1,33 @@
+class PluginDaemonError(Exception):
+    """Base class for all plugin daemon errors."""
+
+    def __init__(self, description: str) -> None:
+        self.description = description
+
+
+class PluginDaemonInternalServerError(PluginDaemonError):
+    description: str = "Internal Server Error"
+
+
+class PluginDaemonBadRequestError(PluginDaemonError):
+    description: str = "Bad Request"
+
+
+class PluginDaemonNotFoundError(PluginDaemonError):
+    description: str = "Not Found"
+
+
+class PluginUniqueIdentifierError(PluginDaemonError):
+    description: str = "Unique Identifier Error"
+
+
+class PluginNotFoundError(PluginDaemonError):
+    description: str = "Plugin Not Found"
+
+
+class PluginDaemonUnauthorizedError(PluginDaemonError):
+    description: str = "Unauthorized"
+
+
+class PluginPermissionDeniedError(PluginDaemonError):
+    description: str = "Permission Denied"

+ 2 - 2
api/core/plugin/manager/model.py

@@ -277,7 +277,7 @@ class PluginModelManager(BasePluginManager):
         model: str,
         credentials: dict,
         texts: list[str],
-    ) -> int:
+    ) -> list[int]:
         """
         Get number of tokens for text embedding
         """
@@ -306,7 +306,7 @@ class PluginModelManager(BasePluginManager):
         for resp in response:
             return resp.num_tokens
 
-        return 0
+        return []
 
     def invoke_rerank(
         self,