import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.workflow.nodes.llm.llm_node import LLMNode
from models.account import Tenant


class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        Invoke an LLM for a plugin, returning a chunk stream or a full result.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model; default to streaming when the payload leaves it unset.
        # (the original `payload.stream or True` always evaluated to True,
        # silently ignoring an explicit stream=False)
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.model_parameters,
            tools=payload.tools,
            stop=payload.stop,
            stream=payload.stream if payload.stream is not None else True,
            user=user_id,
        )
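        # A streaming invocation yields LLMResultChunk objects; a blocking one
        # returns a single LLMResult carrying aggregate usage.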
        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                for chunk in response:
                    # deduct quota incrementally as usage arrives in the stream
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
            return response

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        Invoke a text-embedding model for a plugin.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        Invoke a rerank model for a plugin.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        Invoke a TTS model for a plugin, yielding hex-encoded audio chunks.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )
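        # TTS produces raw audio bytes; hex-encode each chunk so it stays
        # text-safe when serialized into the response.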
        def handle() -> Generator[dict, None, None]:
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        Invoke a speech-to-text model for a plugin on hex-encoded audio.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
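        # Decode the hex payload into a seekable temp file the model can read
        # back; "w+b" (rather than write-only "wb") keeps the handle readable.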
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="w+b", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        Invoke a moderation model for a plugin.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }
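
# A minimal usage sketch (hypothetical caller: the tenant row, payload values,
# and provider/model names below are illustrative, not part of this module):
#
#     response = PluginModelBackwardsInvocation.invoke_llm(
#         user_id="some-user-id",
#         tenant=tenant,
#         payload=RequestInvokeLLM(
#             provider="openai",
#             model_type="llm",
#             model="gpt-4o",
#             prompt_messages=[...],
#             model_parameters={},
#             stream=True,
#         ),
#     )
#     for chunk in response:  # streaming path yields LLMResultChunk objects
#         ...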
|