Explorar el Código

feat: backwards invoke llm

Yeuoly hace 10 meses
padre
commit
31e8b134d1

+ 12 - 2
api/controllers/inner_api/plugin/plugin.py

@@ -1,10 +1,13 @@
 import time
+from collections.abc import Generator
+
 from flask_restful import Resource, reqparse
 
 from controllers.console.setup import setup_required
 from controllers.inner_api import api
 from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.wraps import plugin_inner_api_only
+from core.plugin.backwards_invocation.model import PluginBackwardsInvocation
 from core.plugin.entities.request import (
     RequestInvokeLLM,
     RequestInvokeModeration,
@@ -17,7 +20,6 @@ from core.plugin.entities.request import (
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from libs.helper import compact_generate_response
 from models.account import Tenant
-from services.plugin.plugin_invoke_service import PluginInvokeService
 
 
 class PluginInvokeLLMApi(Resource):
@@ -26,7 +28,15 @@ class PluginInvokeLLMApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeLLM)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
-        pass
+        def generator():
+            response = PluginBackwardsInvocation.invoke_llm(user_id, tenant_model, payload)
+            if isinstance(response, Generator):
+                for chunk in response:
+                    yield chunk.model_dump_json().encode() + b'\n\n'
+            else:
+                yield response.model_dump_json().encode() + b'\n\n'
+
+        return compact_generate_response(generator())
 
 
 class PluginInvokeTextEmbeddingApi(Resource):

+ 2 - 2
api/core/model_manager.py

@@ -7,7 +7,7 @@ from core.entities.provider_configuration import ProviderConfiguration, Provider
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
@@ -103,7 +103,7 @@ class ModelInstance:
     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                    stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-            -> Union[LLMResult, Generator]:
+            -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
         """
         Invoke large language model
 

+ 49 - 0
api/core/plugin/backwards_invocation/model.py

@@ -0,0 +1,49 @@
+from collections.abc import Generator
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from core.plugin.entities.request import RequestInvokeLLM
+from core.workflow.nodes.llm.llm_node import LLMNode
+from models.account import Tenant
+
+
+class PluginBackwardsInvocation:
+    @classmethod
+    def invoke_llm(
+        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
+    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
+        """
+        invoke llm
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        # invoke model
+        response = model_instance.invoke_llm(
+            prompt_messages=payload.prompt_messages,
+            model_parameters=payload.model_parameters,
+            tools=payload.tools,
+            stop=payload.stop,
+            stream=payload.stream or True,
+            user=user_id,
+        )
+
+        if isinstance(response, Generator):
+
+            def handle() -> Generator[LLMResultChunk, None, None]:
+                for chunk in response:
+                    if chunk.delta.usage:
+                        LLMNode.deduct_llm_quota(
+                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
+                        )
+                    yield chunk
+
+            return handle()
+        else:
+            if response.usage:
+                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+            return response

+ 56 - 3
api/core/plugin/entities/request.py

@@ -1,4 +1,17 @@
-from pydantic import BaseModel
+from typing import Any, Optional
+
+from pydantic import BaseModel, Field, field_validator
+
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageRole,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import ModelType
 
 
 class RequestInvokeTool(BaseModel):
@@ -6,37 +19,77 @@ class RequestInvokeTool(BaseModel):
     Request to invoke a tool
     """
 
-class RequestInvokeLLM(BaseModel):
+
+class BaseRequestInvokeModel(BaseModel):
+    provider: str
+    model: str
+    model_type: ModelType
+
+
+class RequestInvokeLLM(BaseRequestInvokeModel):
     """
     Request to invoke LLM
     """
 
+    model_type: ModelType = ModelType.LLM
+    mode: str
+    model_parameters: dict[str, Any] = Field(default_factory=dict)
+    prompt_messages: list[PromptMessage]
+    tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
+    stop: Optional[list[str]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+
+    @field_validator('prompt_messages', mode='before')
+    def convert_prompt_messages(cls, v):
+        if not isinstance(v, list):
+            raise ValueError('prompt_messages must be a list')
+
+        for i in range(len(v)):
+            if v[i]['role'] == PromptMessageRole.USER.value:
+                v[i] = UserPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.ASSISTANT.value:
+                v[i] = AssistantPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.SYSTEM.value:
+                v[i] = SystemPromptMessage(**v[i])
+            elif v[i]['role'] == PromptMessageRole.TOOL.value:
+                v[i] = ToolPromptMessage(**v[i])
+            else:
+                v[i] = PromptMessage(**v[i])
+
+        return v
+
+
 class RequestInvokeTextEmbedding(BaseModel):
     """
     Request to invoke text embedding
     """
 
+
 class RequestInvokeRerank(BaseModel):
     """
     Request to invoke rerank
     """
 
+
 class RequestInvokeTTS(BaseModel):
     """
     Request to invoke TTS
     """
 
+
 class RequestInvokeSpeech2Text(BaseModel):
     """
     Request to invoke speech2text
     """
 
+
 class RequestInvokeModeration(BaseModel):
     """
     Request to invoke moderation
     """
 
+
 class RequestInvokeNode(BaseModel):
     """
     Request to invoke node
-    """
+    """