model.py

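"""Backwards invocation of model capabilities for plugins.

Relays LLM, text embedding, rerank, TTS, speech-to-text, and moderation
requests from plugins back into the tenant's configured model runtime.
"""
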
import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.workflow.nodes.llm.llm_node import LLMNode
from models.account import Tenant


class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        Invoke an LLM on behalf of a plugin.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model; default to streaming when the payload does not say otherwise
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.model_parameters,
            tools=payload.tools,
            stop=payload.stop,
            stream=payload.stream if payload.stream is not None else True,
            user=user_id,
        )

        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                # deduct quota as usage is reported, then pass each chunk through
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
            return response
    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        Invoke a text embedding model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response
    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        Invoke a rerank model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response
    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        Invoke a TTS model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )

        def handle() -> Generator[dict, None, None]:
            # hex-encode each audio chunk so it can travel as JSON-safe text
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()
    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        Invoke a speech-to-text model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model: decode the hex-encoded audio into a temporary file, opened
        # read/write ("w+b") so the runtime can read it back after seek(0)
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="w+b", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }
    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        Invoke a moderation model.
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }
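

# Usage sketch for a hypothetical caller. `user_id`, `tenant`, and `payload`
# would come from a decoded plugin request, which this module does not
# construct itself:
#
#     result = PluginModelBackwardsInvocation.invoke_llm(user_id, tenant, payload)
#     if isinstance(result, Generator):
#         for chunk in result:
#             ...  # each chunk is an LLMResultChunk; quota is deducted per chunk
#     else:
#         ...  # a single LLMResult whose usage has already been deducted
#
# A TTS consumer would reverse the hex encoding applied above to recover the
# raw audio bytes from each streamed chunk:
#
#     for item in PluginModelBackwardsInvocation.invoke_tts(user_id, tenant, payload):
#         audio_bytes = unhexlify(item["result"])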