
feat: add llm blocking invoke (#15732)

Novice committed 1 month ago
commit 04a0ae3aa9
1 changed file with 17 additions and 3 deletions:
  1. api/core/plugin/backwards_invocation/model.py (+17, -3)

api/core/plugin/backwards_invocation/model.py (+17, -3)

@@ -3,7 +3,7 @@ from binascii import hexlify, unhexlify
 from collections.abc import Generator

 from core.model_manager import ModelManager
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
     SystemPromptMessage,
@@ -46,7 +46,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             model_parameters=payload.completion_params,
             tools=payload.tools,
             stop=payload.stop,
-            stream=payload.stream or True,
+            stream=True if payload.stream is None else payload.stream,
             user=user_id,
         )

@@ -64,7 +64,21 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
         else:
             if response.usage:
                 LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
-            return response
+
+            def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
+                yield LLMResultChunk(
+                    model=response.model,
+                    prompt_messages=response.prompt_messages,
+                    system_fingerprint=response.system_fingerprint,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=response.message,
+                        usage=response.usage,
+                        finish_reason="",
+                    ),
+                )
+
+            return handle_non_streaming(response)

     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
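
For reference, the effect of this change: a blocking (non-streaming) LLM invocation now comes back in the same shape as a streaming one, with the single LLMResult wrapped in a generator that yields exactly one LLMResultChunk. It also fixes the stream default: the old expression payload.stream or True always evaluated to True, even when a plugin explicitly passed stream=False, whereas the new expression only defaults to streaming when payload.stream is unset. Below is a minimal, self-contained sketch of the wrapping idea; the dataclasses are simplified stand-ins, not the real entities from core.model_runtime.entities.llm_entities, and their field sets are illustrative only.

# Sketch: wrap a blocking LLM result in a one-chunk generator, mirroring the
# handle_non_streaming helper added in this commit. All classes below are
# simplified stand-ins for Dify's real LLMResult / LLMResultChunk /
# LLMResultChunkDelta entities; the real classes carry richer message and
# usage objects than the plain strings and dicts used here.
from collections.abc import Generator
from dataclasses import dataclass
from typing import Optional


@dataclass
class LLMResultChunkDelta:
    index: int
    message: str
    usage: Optional[dict] = None
    finish_reason: str = ""


@dataclass
class LLMResultChunk:
    model: str
    prompt_messages: list[str]
    system_fingerprint: Optional[str]
    delta: LLMResultChunkDelta


@dataclass
class LLMResult:
    model: str
    prompt_messages: list[str]
    message: str
    usage: Optional[dict] = None
    system_fingerprint: Optional[str] = None


def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
    # A blocking invocation produces one complete LLMResult; yielding it as a
    # single chunk lets callers iterate over blocking and streaming responses
    # through the same code path.
    yield LLMResultChunk(
        model=response.model,
        prompt_messages=response.prompt_messages,
        system_fingerprint=response.system_fingerprint,
        delta=LLMResultChunkDelta(
            index=0,
            message=response.message,
            usage=response.usage,
            finish_reason="",
        ),
    )


if __name__ == "__main__":
    result = LLMResult(model="demo-model", prompt_messages=["Hello"], message="Hi there!")
    for chunk in handle_non_streaming(result):
        print(chunk.delta.message)  # prints "Hi there!"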