model.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. import binascii
  2. from collections.abc import Generator, Sequence
  3. from typing import IO, Optional
  4. from core.model_runtime.entities.llm_entities import LLMResultChunk
  5. from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  6. from core.model_runtime.entities.model_entities import AIModelEntity
  7. from core.model_runtime.entities.rerank_entities import RerankResult
  8. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from core.plugin.entities.plugin_daemon import (
  11. PluginBasicBooleanResponse,
  12. PluginDaemonInnerError,
  13. PluginLLMNumTokensResponse,
  14. PluginModelProviderEntity,
  15. PluginModelSchemaEntity,
  16. PluginStringResultResponse,
  17. PluginTextEmbeddingNumTokensResponse,
  18. PluginVoicesResponse,
  19. )
  20. from core.plugin.manager.base import BasePluginManager
  21. class PluginModelManager(BasePluginManager):
  def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
      """
      List the model providers the plugin daemon reports for a tenant.

      A single GET is issued for page 1 with a page size of 256; no further
      pages are requested here.

      :param tenant_id: tenant whose providers are listed; interpolated into the request path
      :return: the value produced by the daemon request helper for ``list[PluginModelProviderEntity]``
      """
      endpoint = f"plugin/{tenant_id}/management/models"
      paging = {"page": 1, "page_size": 256}
      return self._request_with_plugin_daemon_response(
          "GET",
          endpoint,
          list[PluginModelProviderEntity],
          params=paging,
      )
  def get_model_schema(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model_type: str,
      model: str,
      credentials: dict,
  ) -> AIModelEntity | None:
      """
      Request the schema of a single model from the plugin daemon.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model_type: model type placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :return: ``model_schema`` of the first streamed response, or ``None`` when the stream yields nothing
      """
      payload = {
          "user_id": user_id,
          "data": {
              "provider": provider,
              "model_type": model_type,
              "model": model,
              "credentials": credentials,
          },
      }
      request_headers = {
          "X-Plugin-ID": plugin_id,
          "Content-Type": "application/json",
      }
      stream = self._request_with_plugin_daemon_response_stream(
          "POST",
          f"plugin/{tenant_id}/dispatch/model/schema",
          PluginModelSchemaEntity,
          data=payload,
          headers=request_headers,
      )

      # Only the first streamed item is consumed; the rest of the stream is ignored.
      items = iter(stream)
      try:
          first = next(items)
      except StopIteration:
          return None
      return first.model_schema
  def validate_provider_credentials(
      self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
  ) -> bool:
      """
      Ask the plugin daemon to validate provider-level credentials.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param credentials: credentials placed in the request data
      :return: ``result`` of the first streamed response, or ``False`` when the stream yields nothing
      """
      body = {
          "user_id": user_id,
          "data": {
              "provider": provider,
              "credentials": credentials,
          },
      }
      stream = self._request_with_plugin_daemon_response_stream(
          "POST",
          f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
          PluginBasicBooleanResponse,
          data=body,
          headers={
              "X-Plugin-ID": plugin_id,
              "Content-Type": "application/json",
          },
      )

      # The first streamed item decides the outcome; an empty stream counts as invalid.
      answers = iter(stream)
      try:
          first = next(answers)
      except StopIteration:
          return False
      return first.result
  def validate_model_credentials(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model_type: str,
      model: str,
      credentials: dict,
  ) -> bool:
      """
      Ask the plugin daemon to validate the credentials of a specific model.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model_type: model type placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :return: ``result`` of the first streamed response, or ``False`` when the stream yields nothing
      """
      model_info = {
          "provider": provider,
          "model_type": model_type,
          "model": model,
          "credentials": credentials,
      }
      request_headers = {
          "X-Plugin-ID": plugin_id,
          "Content-Type": "application/json",
      }
      stream = self._request_with_plugin_daemon_response_stream(
          "POST",
          f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
          PluginBasicBooleanResponse,
          data={"user_id": user_id, "data": model_info},
          headers=request_headers,
      )

      # The first streamed item decides the outcome; an empty stream counts as invalid.
      for verdict in stream:
          return verdict.result
      return False
  def invoke_llm(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model: str,
      credentials: dict,
      prompt_messages: list[PromptMessage],
      model_parameters: Optional[dict] = None,
      tools: Optional[list[PromptMessageTool]] = None,
      stop: Optional[list[str]] = None,
      stream: bool = True,
  ) -> Generator[LLMResultChunk, None, None]:
      """
      Invoke an LLM through the plugin daemon and yield its result chunks.

      This is a generator function: the request helper is only called once the
      returned generator is first advanced.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param prompt_messages: prompt messages placed in the request data
      :param model_parameters: forwarded as ``model_parameters`` (``None`` is sent as-is)
      :param tools: forwarded as ``tools`` (``None`` is sent as-is)
      :param stop: forwarded as ``stop`` (``None`` is sent as-is)
      :param stream: forwarded as the ``stream`` field of the request data
      :return: generator yielding every item of the daemon response stream
      :raises ValueError: when the stream raises ``PluginDaemonInnerError``; the text is the
          daemon error's message immediately followed by its code, and the daemon error is
          attached as the explicit cause
      """
      response = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/llm/invoke",
          type=LLMResultChunk,
          # The whole body goes through jsonable_encoder before being handed to the request helper.
          data=jsonable_encoder(
              {
                  "user_id": user_id,
                  "data": {
                      "provider": provider,
                      "model_type": "llm",
                      "model": model,
                      "credentials": credentials,
                      "prompt_messages": prompt_messages,
                      "model_parameters": model_parameters,
                      "tools": tools,
                      "stop": stop,
                      "stream": stream,
                  },
              }
          ),
          headers={
              "X-Plugin-ID": plugin_id,
              "Content-Type": "application/json",
          },
      )

      try:
          yield from response
      except PluginDaemonInnerError as e:
          # Chain explicitly so the daemon error is kept as __cause__ of the ValueError.
          raise ValueError(e.message + str(e.code)) from e
  def get_llm_num_tokens(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model_type: str,
      model: str,
      credentials: dict,
      prompt_messages: list[PromptMessage],
      tools: Optional[list[PromptMessageTool]] = None,
  ) -> int:
      """
      Ask the plugin daemon for the token count of a set of prompt messages.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model_type: model type placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param prompt_messages: prompt messages placed in the request data
      :param tools: forwarded as ``tools`` (``None`` is sent as-is)
      :return: ``num_tokens`` of the first streamed response, or ``0`` when the stream yields nothing
      """
      request_data = {
          "provider": provider,
          "model_type": model_type,
          "model": model,
          "credentials": credentials,
          "prompt_messages": prompt_messages,
          "tools": tools,
      }
      encoded_body = jsonable_encoder({"user_id": user_id, "data": request_data})
      stream = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
          type=PluginLLMNumTokensResponse,
          data=encoded_body,
          headers={
              "X-Plugin-ID": plugin_id,
              "Content-Type": "application/json",
          },
      )

      # Only the first streamed item is consumed.
      counts = iter(stream)
      try:
          first = next(counts)
      except StopIteration:
          return 0
      return first.num_tokens
  def invoke_text_embedding(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model: str,
      credentials: dict,
      texts: list[str],
      input_type: str,
  ) -> TextEmbeddingResult:
      """
      Invoke a text-embedding model through the plugin daemon.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param texts: texts placed in the request data
      :param input_type: forwarded as ``input_type`` in the request data
      :return: the first item of the daemon response stream
      :raises ValueError: when the response stream yields nothing
      """
      request_data = {
          "provider": provider,
          "model_type": "text-embedding",
          "model": model,
          "credentials": credentials,
          "texts": texts,
          "input_type": input_type,
      }
      encoded_body = jsonable_encoder({"user_id": user_id, "data": request_data})
      request_headers = {
          "X-Plugin-ID": plugin_id,
          "Content-Type": "application/json",
      }
      stream = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
          type=TextEmbeddingResult,
          data=encoded_body,
          headers=request_headers,
      )

      # A private sentinel distinguishes "stream was empty" from any real first item.
      nothing = object()
      first = next(iter(stream), nothing)
      if first is nothing:
          raise ValueError("Failed to invoke text embedding")
      return first
  def get_text_embedding_num_tokens(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model: str,
      credentials: dict,
      texts: list[str],
  ) -> list[int]:
      """
      Ask the plugin daemon for the token counts of texts under a text-embedding model.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param texts: texts placed in the request data
      :return: ``num_tokens`` of the first streamed response, or ``[]`` when the stream yields nothing
      """
      encoded_body = jsonable_encoder(
          {
              "user_id": user_id,
              "data": {
                  "provider": provider,
                  "model_type": "text-embedding",
                  "model": model,
                  "credentials": credentials,
                  "texts": texts,
              },
          }
      )
      request_headers = {
          "X-Plugin-ID": plugin_id,
          "Content-Type": "application/json",
      }
      stream = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
          type=PluginTextEmbeddingNumTokensResponse,
          data=encoded_body,
          headers=request_headers,
      )

      # Only the first streamed item is consumed.
      counts = iter(stream)
      try:
          first = next(counts)
      except StopIteration:
          return []
      return first.num_tokens
  def invoke_rerank(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model: str,
      credentials: dict,
      query: str,
      docs: list[str],
      score_threshold: Optional[float] = None,
      top_n: Optional[int] = None,
  ) -> RerankResult:
      """
      Invoke a rerank model through the plugin daemon.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param query: query placed in the request data
      :param docs: documents placed in the request data
      :param score_threshold: forwarded as ``score_threshold`` (``None`` is sent as-is)
      :param top_n: forwarded as ``top_n`` (``None`` is sent as-is)
      :return: the first item of the daemon response stream
      :raises ValueError: when the response stream yields nothing
      """
      request_data = {
          "provider": provider,
          "model_type": "rerank",
          "model": model,
          "credentials": credentials,
          "query": query,
          "docs": docs,
          "score_threshold": score_threshold,
          "top_n": top_n,
      }
      encoded_body = jsonable_encoder({"user_id": user_id, "data": request_data})
      stream = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
          type=RerankResult,
          data=encoded_body,
          headers={
              "X-Plugin-ID": plugin_id,
              "Content-Type": "application/json",
          },
      )

      # A private sentinel distinguishes "stream was empty" from any real first item.
      nothing = object()
      first = next(iter(stream), nothing)
      if first is nothing:
          raise ValueError("Failed to invoke rerank")
      return first
  def invoke_tts(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model: str,
      credentials: dict,
      content_text: str,
      voice: str,
  ) -> Generator[bytes, None, None]:
      """
      Invoke a TTS model through the plugin daemon and yield decoded byte chunks.

      This is a generator function: the request helper is only called once the
      returned generator is first advanced. Each streamed item's ``result``
      string is hex-decoded with ``binascii.unhexlify`` before being yielded.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param content_text: forwarded as ``content_text`` in the request data
      :param voice: forwarded as ``voice`` in the request data
      :return: generator yielding one ``bytes`` object per streamed item
      :raises ValueError: when the stream raises ``PluginDaemonInnerError``; the text is the
          daemon error's message immediately followed by its code, and the daemon error is
          attached as the explicit cause
      """
      response = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/tts/invoke",
          type=PluginStringResultResponse,
          data=jsonable_encoder(
              {
                  "user_id": user_id,
                  "data": {
                      "provider": provider,
                      "model_type": "tts",
                      "model": model,
                      "credentials": credentials,
                      "content_text": content_text,
                      "voice": voice,
                  },
              }
          ),
          headers={
              "X-Plugin-ID": plugin_id,
              "Content-Type": "application/json",
          },
      )

      try:
          for result in response:
              # Each item carries its payload as a hex string; turn it back into raw bytes.
              yield binascii.unhexlify(result.result)
      except PluginDaemonInnerError as e:
          # Chain explicitly so the daemon error is kept as __cause__ of the ValueError.
          raise ValueError(e.message + str(e.code)) from e
  def get_tts_model_voices(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model: str,
      credentials: dict,
      language: Optional[str] = None,
  ) -> list[dict]:
      """
      Fetch the voices a TTS model offers from the plugin daemon.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param language: forwarded as ``language`` (``None`` is sent as-is)
      :return: one ``{"name": ..., "value": ...}`` dict per voice of the first streamed
          response, or ``[]`` when the stream yields nothing
      """
      request_data = {
          "provider": provider,
          "model_type": "tts",
          "model": model,
          "credentials": credentials,
          "language": language,
      }
      encoded_body = jsonable_encoder({"user_id": user_id, "data": request_data})
      stream = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
          type=PluginVoicesResponse,
          data=encoded_body,
          headers={
              "X-Plugin-ID": plugin_id,
              "Content-Type": "application/json",
          },
      )

      # Only the first streamed item is used; its voices are flattened into plain dicts.
      for first in stream:
          return [{"name": entry.name, "value": entry.value} for entry in first.voices]
      return []
  def invoke_speech_to_text(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model: str,
      credentials: dict,
      file: IO[bytes],
  ) -> str:
      """
      Invoke a speech-to-text model through the plugin daemon.

      The file is read in full with ``file.read()`` and sent hex-encoded in the
      ``file`` field of the request data.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param file: binary file object whose remaining content is read and sent
      :return: ``result`` of the first streamed response
      :raises ValueError: when the response stream yields nothing
      """
      audio_hex = binascii.hexlify(file.read()).decode()
      request_data = {
          "provider": provider,
          "model_type": "speech2text",
          "model": model,
          "credentials": credentials,
          "file": audio_hex,
      }
      encoded_body = jsonable_encoder({"user_id": user_id, "data": request_data})
      stream = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
          type=PluginStringResultResponse,
          data=encoded_body,
          headers={
              "X-Plugin-ID": plugin_id,
              "Content-Type": "application/json",
          },
      )

      # A private sentinel distinguishes "stream was empty" from any real first item.
      nothing = object()
      first = next(iter(stream), nothing)
      if first is nothing:
          raise ValueError("Failed to invoke speech to text")
      return first.result
  def invoke_moderation(
      self,
      tenant_id: str,
      user_id: str,
      plugin_id: str,
      provider: str,
      model: str,
      credentials: dict,
      text: str,
  ) -> bool:
      """
      Invoke a moderation model through the plugin daemon.

      :param tenant_id: tenant the request is dispatched for; part of the request path
      :param user_id: sent as ``user_id`` in the request body
      :param plugin_id: sent as the ``X-Plugin-ID`` header
      :param provider: provider name placed in the request data
      :param model: model name placed in the request data
      :param credentials: credentials placed in the request data
      :param text: text placed in the request data
      :return: ``result`` of the first streamed response
      :raises ValueError: when the response stream yields nothing
      """
      request_data = {
          "provider": provider,
          "model_type": "moderation",
          "model": model,
          "credentials": credentials,
          "text": text,
      }
      encoded_body = jsonable_encoder({"user_id": user_id, "data": request_data})
      request_headers = {
          "X-Plugin-ID": plugin_id,
          "Content-Type": "application/json",
      }
      stream = self._request_with_plugin_daemon_response_stream(
          method="POST",
          path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
          type=PluginBasicBooleanResponse,
          data=encoded_body,
          headers=request_headers,
      )

      # A private sentinel distinguishes "stream was empty" from any real first item.
      nothing = object()
      first = next(iter(stream), nothing)
      if first is nothing:
          raise ValueError("Failed to invoke moderation")
      return first.result