
feat: backwards invoke model

Yeuoly 9 months ago
commit 0ad9dbea63

+ 5 - 1
api/controllers/inner_api/plugin/plugin.py

@@ -47,7 +47,11 @@ class PluginInvokeTextEmbeddingApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeTextEmbedding)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
-        pass
+        return PluginModelBackwardsInvocation.invoke_text_embedding(
+            user_id=user_id,
+            tenant=tenant_model,
+            payload=payload,
+        )
 
 
 class PluginInvokeRerankApi(Resource):
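
Only the text-embedding resource is wired up in this hunk; the sibling resources in plugin.py (rerank, TTS, speech-to-text, moderation) would presumably delegate to the matching PluginModelBackwardsInvocation classmethods added below in the same way. A minimal sketch for the rerank resource, assuming it keeps the decorators shown above and reuses the module's existing imports:

# sketch only: assumes plugin.py already has Resource, get_tenant, plugin_data,
# RequestInvokeRerank and PluginModelBackwardsInvocation in scope
class PluginInvokeRerankApi(Resource):
    @get_tenant
    @plugin_data(payload_type=RequestInvokeRerank)
    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank):
        # delegate to the backwards-invocation layer, mirroring the text-embedding endpoint
        return PluginModelBackwardsInvocation.invoke_rerank(
            user_id=user_id,
            tenant=tenant_model,
            payload=payload,
        )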

+ 3 - 1
api/core/model_manager.py

@@ -310,7 +310,9 @@ class ModelInstance:
             user=user,
         )
 
-    def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
+    def invoke_tts(
+        self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None
+    ) -> Generator[bytes, None, None]:
         """
         Invoke large language tts model
 

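The model_manager.py hunk only corrects the annotation: invoke_tts already yields raw audio chunks, so callers receive a Generator[bytes, None, None] rather than a str. A hedged sketch of draining that stream into a single bytes object (the helper name and its arguments are illustrative, not part of the codebase):

from io import BytesIO

def collect_tts_audio(model_instance, text: str, tenant_id: str, voice: str) -> bytes:
    # invoke_tts streams raw audio fragments; concatenate them into one buffer
    buffer = BytesIO()
    for chunk in model_instance.invoke_tts(content_text=text, tenant_id=tenant_id, voice=voice):
        buffer.write(chunk)
    return buffer.getvalue()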
+ 128 - 3
api/core/plugin/backwards_invocation/model.py

@@ -1,9 +1,18 @@
+import tempfile
+from binascii import hexlify
 from collections.abc import Generator
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
-from core.plugin.entities.request import RequestInvokeLLM
+from core.plugin.entities.request import (
+    RequestInvokeLLM,
+    RequestInvokeModeration,
+    RequestInvokeRerank,
+    RequestInvokeSpeech2Text,
+    RequestInvokeTextEmbedding,
+    RequestInvokeTTS,
+)
 from core.workflow.nodes.llm.llm_node import LLMNode
 from models.account import Tenant
 
@@ -48,5 +57,121 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             if response.usage:
                 LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
             return response
-    
-    
+
+    @classmethod
+    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
+        """
+        invoke text embedding
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        # invoke model
+        response = model_instance.invoke_text_embedding(
+            texts=payload.texts,
+            user=user_id,
+        )
+
+        return response
+
+    @classmethod
+    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
+        """
+        invoke rerank
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        # invoke model
+        response = model_instance.invoke_rerank(
+            query=payload.query,
+            docs=payload.docs,
+            score_threshold=payload.score_threshold,
+            top_n=payload.top_n,
+            user=user_id,
+        )
+
+        return response
+
+    @classmethod
+    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
+        """
+        invoke tts
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        # invoke model
+        response = model_instance.invoke_tts(
+            content_text=payload.content_text,
+            tenant_id=tenant.id,
+            voice=payload.voice,
+            user=user_id,
+        )
+
+        def handle() -> Generator[dict, None, None]:
+            for chunk in response:
+                yield {"result": hexlify(chunk).decode("utf-8")}
+
+        return handle()
+
+    @classmethod
+    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
+        """
+        invoke speech2text
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        # invoke model
+        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
+            temp.write(payload.file)
+            temp.flush()
+            temp.seek(0)
+
+            response = model_instance.invoke_speech2text(
+                file=temp,
+                user=user_id,
+            )
+
+        return {
+            "result": response,
+        }
+
+    @classmethod
+    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
+        """
+        invoke moderation
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        # invoke model
+        response = model_instance.invoke_moderation(
+            text=payload.text,
+            user=user_id,
+        )
+
+        return {
+            "result": response,
+        }

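Binary audio crosses the plugin boundary as hex text in both directions: RequestInvokeSpeech2Text.convert_file parses a hex string into bytes, and invoke_tts streams each audio chunk back as {"result": hexlify(chunk)}. A round-trip sketch of that wire format from the caller's side (the payload dict, provider and model names are illustrative, not part of the API):

from binascii import hexlify, unhexlify
from pathlib import Path

# caller side: hex-encode an audio file before sending a speech2text request
audio_bytes = Path("sample.mp3").read_bytes()          # illustrative input file
speech2text_payload = {
    "provider": "openai",                              # illustrative provider
    "model": "whisper-1",                              # illustrative speech2text model
    "file": hexlify(audio_bytes).decode(),             # convert_file() parses this back into bytes
}

# caller side: decode the streamed TTS chunks back into raw audio
def decode_tts_chunks(chunks) -> bytes:
    # each chunk is a {"result": "<hex>"} dict, as produced by invoke_tts's handle()
    return b"".join(unhexlify(chunk["result"]) for chunk in chunks)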
+ 33 - 5
api/core/plugin/entities/request.py

@@ -74,35 +74,63 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
         return v
 
 
-class RequestInvokeTextEmbedding(BaseModel):
+class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
     """
     Request to invoke text embedding
     """
 
+    model_type: ModelType = ModelType.TEXT_EMBEDDING
+    texts: list[str]
 
-class RequestInvokeRerank(BaseModel):
+
+class RequestInvokeRerank(BaseRequestInvokeModel):
     """
     Request to invoke rerank
     """
 
+    model_type: ModelType = ModelType.RERANK
+    query: str
+    docs: list[str]
+    score_threshold: float
+    top_n: int
+
 
-class RequestInvokeTTS(BaseModel):
+class RequestInvokeTTS(BaseRequestInvokeModel):
     """
     Request to invoke TTS
     """
 
+    model_type: ModelType = ModelType.TTS
+    content_text: str
+    voice: str
+
 
-class RequestInvokeSpeech2Text(BaseModel):
+class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
     """
     Request to invoke speech2text
     """
 
+    model_type: ModelType = ModelType.SPEECH2TEXT
+    file: bytes
 
-class RequestInvokeModeration(BaseModel):
+    @field_validator("file", mode="before")
+    @classmethod
+    def convert_file(cls, v):
+        # hex string to bytes
+        if isinstance(v, str):
+            return bytes.fromhex(v)
+        else:
+            raise ValueError("file must be a hex string")
+
+
+class RequestInvokeModeration(BaseRequestInvokeModel):
     """
     Request to invoke moderation
     """
 
+    model_type: ModelType = ModelType.MODERATION
+    text: str
+
 
 class RequestInvokeParameterExtractorNode(BaseModel):
     """