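"""
Backwards invocation of model capabilities on behalf of plugins.

Each handler resolves a tenant-scoped model instance via ModelManager,
forwards the plugin's request payload (LLM, text embedding, rerank, TTS,
speech-to-text, moderation, or summary), and deducts LLM quota where
usage is reported.
"""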
import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTTS,
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm.node import LLMNode
from models.account import Tenant


class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
    @classmethod
    def invoke_llm(
        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM
    ) -> Generator[LLMResultChunk, None, None] | LLMResult:
        """
        invoke llm
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model, defaulting to streaming when the payload leaves the mode unset
        # (the previous `payload.stream or True` always evaluated to True, so an
        # explicit stream=False was never honored)
        response = model_instance.invoke_llm(
            prompt_messages=payload.prompt_messages,
            model_parameters=payload.completion_params,
            tools=payload.tools,
            stop=payload.stop,
            stream=payload.stream if payload.stream is not None else True,
            user=user_id,
        )
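
        # LLM usage is metered: streaming responses deduct quota per
        # usage-bearing chunk; blocking responses deduct once from the result.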
        if isinstance(response, Generator):

            def handle() -> Generator[LLMResultChunk, None, None]:
                for chunk in response:
                    if chunk.delta.usage:
                        LLMNode.deduct_llm_quota(
                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                        )
                    yield chunk

            return handle()
        else:
            if response.usage:
                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
            return response

    @classmethod
    def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
        """
        invoke text embedding
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_text_embedding(
            texts=payload.texts,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_rerank(cls, user_id: str, tenant: Tenant, payload: RequestInvokeRerank):
        """
        invoke rerank
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_rerank(
            query=payload.query,
            docs=payload.docs,
            score_threshold=payload.score_threshold,
            top_n=payload.top_n,
            user=user_id,
        )

        return response

    @classmethod
    def invoke_tts(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTTS):
        """
        invoke tts
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_tts(
            content_text=payload.content_text,
            tenant_id=tenant.id,
            voice=payload.voice,
            user=user_id,
        )
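
        # The TTS model yields raw audio bytes; hex-encode each chunk so it can
        # be carried in a JSON-serializable dict back to the plugin.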
        def handle() -> Generator[dict, None, None]:
            for chunk in response:
                yield {"result": hexlify(chunk).decode("utf-8")}

        return handle()

    @classmethod
    def invoke_speech2text(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSpeech2Text):
        """
        invoke speech2text
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
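        # payload.file arrives hex-encoded; decode it into a temporary .mp3
        # file, since invoke_speech2text expects a file-like object.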
        with tempfile.NamedTemporaryFile(suffix=".mp3", mode="wb", delete=True) as temp:
            temp.write(unhexlify(payload.file))
            temp.flush()
            temp.seek(0)

            response = model_instance.invoke_speech2text(
                file=temp,
                user=user_id,
            )

            return {
                "result": response,
            }

    @classmethod
    def invoke_moderation(cls, user_id: str, tenant: Tenant, payload: RequestInvokeModeration):
        """
        invoke moderation
        """
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant.id,
            provider=payload.provider,
            model_type=payload.model_type,
            model=payload.model,
        )

        # invoke model
        response = model_instance.invoke_moderation(
            text=payload.text,
            user=user_id,
        )

        return {
            "result": response,
        }

    @classmethod
    def get_system_model_max_tokens(cls, tenant_id: str) -> int:
        """
        get system model max tokens
        """
        return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id)

    @classmethod
    def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
        """
        get prompt tokens
        """
        return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages)

    @classmethod
    def invoke_system_model(
        cls,
        user_id: str,
        tenant: Tenant,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        invoke system model
        """
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=tenant.id,
            tool_type=ToolProviderType.PLUGIN,
            tool_name="plugin",
            prompt_messages=prompt_messages,
        )

    @classmethod
    def invoke_summary(cls, user_id: str, tenant: Tenant, payload: RequestInvokeSummary):
        """
        invoke summary
        """
        max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id)
        content = payload.text

        SUMMARY_PROMPT = """You are a professional language researcher with a keen interest in language.
You can quickly identify the main points of a webpage and reproduce them in your own words
while retaining the original meaning and keeping the key points.

However, the text you received is too long and may be only a part of the full text.
Please summarize the text you received.

Here is the extra instruction you need to follow:
<extra_instruction>
{payload.instruction}
</extra_instruction>
"""

        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=content)],
            )
            < max_tokens * 0.6
        ):
            # content already fits comfortably within the context; no summarization needed
            return content
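
        # Map-reduce summarization: split over-long lines into bounded pieces,
        # merge lines into chunks that fit the context budget, summarize each
        # chunk, then recurse if the joined summaries are still too large.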
        def get_prompt_tokens(content: str) -> int:
            return cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )

        def summarize(content: str) -> str:
            summary = cls.invoke_system_model(
                user_id=user_id,
                tenant=tenant,
                prompt_messages=[
                    SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)),
                    UserPromptMessage(content=content),
                ],
            )
            assert isinstance(summary.message.content, str)
            return summary.message.content
        lines = content.split("\n")
        new_lines: list[str] = []

        # split long lines into multiple lines
        for line in lines:
            if not line.strip():
                continue
            if len(line) < max_tokens * 0.5:
                # cheap character-length check avoids token counting for short lines
                new_lines.append(line)
            elif get_prompt_tokens(line) > max_tokens * 0.7:
                while get_prompt_tokens(line) > max_tokens * 0.7:
                    new_lines.append(line[: int(max_tokens * 0.5)])
                    line = line[int(max_tokens * 0.5) :]
                new_lines.append(line)
            else:
                new_lines.append(line)
        # merge lines into messages within the token budget (an elif chain here
        # prevents a line from being appended twice when the length fast path hits)
        messages: list[str] = []
        for line in new_lines:
            if not messages:
                messages.append(line)
            elif len(messages[-1]) + len(line) < max_tokens * 0.5:
                # combined character length is clearly small; merge without token counting
                messages[-1] += line
            elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
                # current message is full; start a new one
                messages.append(line)
            else:
                messages[-1] += line
        summaries = []
        for message in messages:
            summary = summarize(message)
            summaries.append(summary)

        result = "\n".join(summaries)

        if (
            cls.get_prompt_tokens(
                tenant_id=tenant.id,
                prompt_messages=[UserPromptMessage(content=result)],
            )
            > max_tokens * 0.7
        ):
            return cls.invoke_summary(
                user_id=user_id,
                tenant=tenant,
                payload=RequestInvokeSummary(text=result, instruction=payload.instruction),
            )

        return result