Browse Source

fix: missing error message

Yeuoly 8 months ago
parent
commit
5dcd25a613

+ 82 - 26
api/controllers/inner_api/plugin/plugin.py

@@ -7,6 +7,7 @@ from controllers.inner_api import api
 from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.wraps import plugin_inner_api_only
 from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
+from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
 from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
 from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
 from core.plugin.encrypt import PluginEncrypter
@@ -47,11 +48,16 @@ class PluginInvokeTextEmbeddingApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeTextEmbedding)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
-        return PluginModelBackwardsInvocation.invoke_text_embedding(
-            user_id=user_id,
-            tenant=tenant_model,
-            payload=payload,
-        )
+        try:
+            return BaseBackwardsInvocationResponse(
+                data=PluginModelBackwardsInvocation.invoke_text_embedding(
+                    user_id=user_id,
+                    tenant=tenant_model,
+                    payload=payload,
+                )
+            ).model_dump()
+        except Exception as e:
+            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
 
 
 class PluginInvokeRerankApi(Resource):
@@ -60,7 +66,16 @@ class PluginInvokeRerankApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeRerank)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank):
-        pass
+        try:
+            return BaseBackwardsInvocationResponse(
+                data=PluginModelBackwardsInvocation.invoke_rerank(
+                    user_id=user_id,
+                    tenant=tenant_model,
+                    payload=payload,
+                )
+            ).model_dump()
+        except Exception as e:
+            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
 
 
 class PluginInvokeTTSApi(Resource):
@@ -69,7 +84,15 @@ class PluginInvokeTTSApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeTTS)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS):
-        pass
+        def generator():
+            response = PluginModelBackwardsInvocation.invoke_tts(
+                user_id=user_id,
+                tenant=tenant_model,
+                payload=payload,
+            )
+            return PluginModelBackwardsInvocation.convert_to_event_stream(response)
+
+        return compact_generate_response(generator())
 
 
 class PluginInvokeSpeech2TextApi(Resource):
@@ -78,7 +101,16 @@ class PluginInvokeSpeech2TextApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeSpeech2Text)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
-        pass
+        try:
+            return BaseBackwardsInvocationResponse(
+                data=PluginModelBackwardsInvocation.invoke_speech2text(
+                    user_id=user_id,
+                    tenant=tenant_model,
+                    payload=payload,
+                )
+            ).model_dump()
+        except Exception as e:
+            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
 
 
 class PluginInvokeModerationApi(Resource):
@@ -87,7 +119,16 @@ class PluginInvokeModerationApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeModeration)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration):
-        pass
+        try:
+            return BaseBackwardsInvocationResponse(
+                data=PluginModelBackwardsInvocation.invoke_moderation(
+                    user_id=user_id,
+                    tenant=tenant_model,
+                    payload=payload,
+                )
+            ).model_dump()
+        except Exception as e:
+            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
 
 
 class PluginInvokeToolApi(Resource):
@@ -118,14 +159,19 @@ class PluginInvokeParameterExtractorNodeApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeParameterExtractorNode)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
-        return PluginNodeBackwardsInvocation.invoke_parameter_extractor(
-            tenant_id=tenant_model.id,
-            user_id=user_id,
-            parameters=payload.parameters,
-            model_config=payload.model,
-            instruction=payload.instruction,
-            query=payload.query,
-        )
+        try:
+            return BaseBackwardsInvocationResponse(
+                data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
+                    tenant_id=tenant_model.id,
+                    user_id=user_id,
+                    parameters=payload.parameters,
+                    model_config=payload.model,
+                    instruction=payload.instruction,
+                    query=payload.query,
+                )
+            ).model_dump()
+        except Exception as e:
+            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
 
 
 class PluginInvokeQuestionClassifierNodeApi(Resource):
@@ -134,14 +180,19 @@ class PluginInvokeQuestionClassifierNodeApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
-        return PluginNodeBackwardsInvocation.invoke_question_classifier(
-            tenant_id=tenant_model.id,
-            user_id=user_id,
-            query=payload.query,
-            model_config=payload.model,
-            classes=payload.classes,
-            instruction=payload.instruction,
-        )
+        try:
+            return BaseBackwardsInvocationResponse(
+                data=PluginNodeBackwardsInvocation.invoke_question_classifier(
+                    tenant_id=tenant_model.id,
+                    user_id=user_id,
+                    query=payload.query,
+                    model_config=payload.model,
+                    classes=payload.classes,
+                    instruction=payload.instruction,
+                )
+            ).model_dump()
+        except Exception as e:
+            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
 
 
 class PluginInvokeAppApi(Resource):
@@ -173,7 +224,12 @@ class PluginInvokeEncryptApi(Resource):
         """
         encrypt or decrypt data
         """
-        return PluginEncrypter.invoke_encrypt(tenant_model, payload)
+        try:
+            return BaseBackwardsInvocationResponse(
+                data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
+            ).model_dump()
+        except Exception as e:
+            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
 
 
 api.add_resource(PluginInvokeLLMApi, "/invoke/llm")

+ 23 - 9
api/core/plugin/backwards_invocation/base.py

@@ -1,5 +1,6 @@
 import json
 from collections.abc import Generator
+from typing import Generic, Optional, TypeVar
 
 from pydantic import BaseModel
 
@@ -8,15 +9,28 @@ class BaseBackwardsInvocation:
     @classmethod
     def convert_to_event_stream(cls, response: Generator[BaseModel | dict | str, None, None] | BaseModel | dict):
         if isinstance(response, Generator):
-            for chunk in response:
-                if isinstance(chunk, BaseModel):
-                    yield chunk.model_dump_json().encode() + b'\n\n'
-                elif isinstance(chunk, str):
-                    yield f"event: {chunk}\n\n".encode()
-                else:
-                    yield json.dumps(chunk).encode() + b'\n\n'
+            try:
+                for chunk in response:
+                    if isinstance(chunk, BaseModel):
+                        yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode() + b"\n\n"
+
+                    elif isinstance(chunk, str):
+                        yield f"event: {chunk}\n\n".encode()
+                    else:
+                        yield json.dumps(chunk).encode() + b"\n\n"
+            except Exception as e:
+                error_message = BaseBackwardsInvocationResponse(error=str(e)).model_dump_json()
+                yield f"{error_message}\n\n".encode()
         else:
             if isinstance(response, BaseModel):
-                yield response.model_dump_json().encode() + b'\n\n'
+                yield response.model_dump_json().encode() + b"\n\n"
             else:
-                yield json.dumps(response).encode() + b'\n\n'
+                yield json.dumps(response).encode() + b"\n\n"
+
+
+T = TypeVar("T", bound=BaseModel | dict | str | bool | int)
+
+
+class BaseBackwardsInvocationResponse(BaseModel, Generic[T]):
+    data: Optional[T] = None
+    error: str = ""

+ 5 - 14
api/core/plugin/encrypt/__init__.py

@@ -8,7 +8,7 @@ from models.account import Tenant
 
 class PluginEncrypter:
     @classmethod
-    def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> Mapping[str, Any]:
+    def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
         encrypter = ProviderConfigEncrypter(
             tenant_id=tenant.id,
             config=payload.data,
@@ -16,16 +16,7 @@ class PluginEncrypter:
             provider_identity=payload.identity,
         )
 
-        try:
-            if payload.opt == "encrypt":
-                return {
-                    "data": encrypter.encrypt(payload.data),
-                }
-            else:
-                return {
-                    "data": encrypter.decrypt(payload.data),
-                }
-        except Exception as e:
-            return {
-                "error": str(e),
-            }
+        if payload.opt == "encrypt":
+            return encrypter.encrypt(payload.data)
+        else:
+            return encrypter.decrypt(payload.data)