# model.py
  1. import binascii
  2. from collections.abc import Generator, Sequence
  3. from typing import IO, Optional
  4. from core.model_runtime.entities.llm_entities import LLMResultChunk
  5. from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  6. from core.model_runtime.entities.model_entities import AIModelEntity
  7. from core.model_runtime.entities.rerank_entities import RerankResult
  8. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from core.plugin.entities.plugin_daemon import (
  11. PluginBasicBooleanResponse,
  12. PluginDaemonInnerError,
  13. PluginLLMNumTokensResponse,
  14. PluginModelProviderEntity,
  15. PluginModelSchemaEntity,
  16. PluginStringResultResponse,
  17. PluginTextEmbeddingNumTokensResponse,
  18. PluginVoicesResponse,
  19. )
  20. from core.plugin.manager.base import BasePluginManager
  21. class PluginModelManager(BasePluginManager):
  22. def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
  23. """
  24. Fetch model providers for the given tenant.
  25. """
  26. response = self._request_with_plugin_daemon_response(
  27. "GET",
  28. f"plugin/{tenant_id}/management/models",
  29. list[PluginModelProviderEntity],
  30. params={"page": 1, "page_size": 256},
  31. )
  32. return response
  33. def get_model_schema(
  34. self,
  35. tenant_id: str,
  36. user_id: str,
  37. plugin_id: str,
  38. provider: str,
  39. model_type: str,
  40. model: str,
  41. credentials: dict,
  42. ) -> AIModelEntity | None:
  43. """
  44. Get model schema
  45. """
  46. response = self._request_with_plugin_daemon_response_stream(
  47. "POST",
  48. f"plugin/{tenant_id}/dispatch/model/schema",
  49. PluginModelSchemaEntity,
  50. data={
  51. "user_id": user_id,
  52. "data": {
  53. "provider": provider,
  54. "model_type": model_type,
  55. "model": model,
  56. "credentials": credentials,
  57. },
  58. },
  59. headers={
  60. "X-Plugin-ID": plugin_id,
  61. "Content-Type": "application/json",
  62. },
  63. )
  64. for resp in response:
  65. return resp.model_schema
  66. return None
  67. def validate_provider_credentials(
  68. self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
  69. ) -> bool:
  70. """
  71. validate the credentials of the provider
  72. """
  73. response = self._request_with_plugin_daemon_response_stream(
  74. "POST",
  75. f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
  76. PluginBasicBooleanResponse,
  77. data={
  78. "user_id": user_id,
  79. "data": {
  80. "provider": provider,
  81. "credentials": credentials,
  82. },
  83. },
  84. headers={
  85. "X-Plugin-ID": plugin_id,
  86. "Content-Type": "application/json",
  87. },
  88. )
  89. for resp in response:
  90. if resp.credentials and isinstance(resp.credentials, dict):
  91. credentials.update(resp.credentials)
  92. return resp.result
  93. return False
  94. def validate_model_credentials(
  95. self,
  96. tenant_id: str,
  97. user_id: str,
  98. plugin_id: str,
  99. provider: str,
  100. model_type: str,
  101. model: str,
  102. credentials: dict,
  103. ) -> bool:
  104. """
  105. validate the credentials of the provider
  106. """
  107. response = self._request_with_plugin_daemon_response_stream(
  108. "POST",
  109. f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
  110. PluginBasicBooleanResponse,
  111. data={
  112. "user_id": user_id,
  113. "data": {
  114. "provider": provider,
  115. "model_type": model_type,
  116. "model": model,
  117. "credentials": credentials,
  118. },
  119. },
  120. headers={
  121. "X-Plugin-ID": plugin_id,
  122. "Content-Type": "application/json",
  123. },
  124. )
  125. for resp in response:
  126. if resp.credentials and isinstance(resp.credentials, dict):
  127. credentials.update(resp.credentials)
  128. return resp.result
  129. return False
  130. def invoke_llm(
  131. self,
  132. tenant_id: str,
  133. user_id: str,
  134. plugin_id: str,
  135. provider: str,
  136. model: str,
  137. credentials: dict,
  138. prompt_messages: list[PromptMessage],
  139. model_parameters: Optional[dict] = None,
  140. tools: Optional[list[PromptMessageTool]] = None,
  141. stop: Optional[list[str]] = None,
  142. stream: bool = True,
  143. ) -> Generator[LLMResultChunk, None, None]:
  144. """
  145. Invoke llm
  146. """
  147. response = self._request_with_plugin_daemon_response_stream(
  148. method="POST",
  149. path=f"plugin/{tenant_id}/dispatch/llm/invoke",
  150. type=LLMResultChunk,
  151. data=jsonable_encoder(
  152. {
  153. "user_id": user_id,
  154. "data": {
  155. "provider": provider,
  156. "model_type": "llm",
  157. "model": model,
  158. "credentials": credentials,
  159. "prompt_messages": prompt_messages,
  160. "model_parameters": model_parameters,
  161. "tools": tools,
  162. "stop": stop,
  163. "stream": stream,
  164. },
  165. }
  166. ),
  167. headers={
  168. "X-Plugin-ID": plugin_id,
  169. "Content-Type": "application/json",
  170. },
  171. )
  172. try:
  173. yield from response
  174. except PluginDaemonInnerError as e:
  175. raise ValueError(e.message + str(e.code))
  176. def get_llm_num_tokens(
  177. self,
  178. tenant_id: str,
  179. user_id: str,
  180. plugin_id: str,
  181. provider: str,
  182. model_type: str,
  183. model: str,
  184. credentials: dict,
  185. prompt_messages: list[PromptMessage],
  186. tools: Optional[list[PromptMessageTool]] = None,
  187. ) -> int:
  188. """
  189. Get number of tokens for llm
  190. """
  191. response = self._request_with_plugin_daemon_response_stream(
  192. method="POST",
  193. path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
  194. type=PluginLLMNumTokensResponse,
  195. data=jsonable_encoder(
  196. {
  197. "user_id": user_id,
  198. "data": {
  199. "provider": provider,
  200. "model_type": model_type,
  201. "model": model,
  202. "credentials": credentials,
  203. "prompt_messages": prompt_messages,
  204. "tools": tools,
  205. },
  206. }
  207. ),
  208. headers={
  209. "X-Plugin-ID": plugin_id,
  210. "Content-Type": "application/json",
  211. },
  212. )
  213. for resp in response:
  214. return resp.num_tokens
  215. return 0
  216. def invoke_text_embedding(
  217. self,
  218. tenant_id: str,
  219. user_id: str,
  220. plugin_id: str,
  221. provider: str,
  222. model: str,
  223. credentials: dict,
  224. texts: list[str],
  225. input_type: str,
  226. ) -> TextEmbeddingResult:
  227. """
  228. Invoke text embedding
  229. """
  230. response = self._request_with_plugin_daemon_response_stream(
  231. method="POST",
  232. path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
  233. type=TextEmbeddingResult,
  234. data=jsonable_encoder(
  235. {
  236. "user_id": user_id,
  237. "data": {
  238. "provider": provider,
  239. "model_type": "text-embedding",
  240. "model": model,
  241. "credentials": credentials,
  242. "texts": texts,
  243. "input_type": input_type,
  244. },
  245. }
  246. ),
  247. headers={
  248. "X-Plugin-ID": plugin_id,
  249. "Content-Type": "application/json",
  250. },
  251. )
  252. for resp in response:
  253. return resp
  254. raise ValueError("Failed to invoke text embedding")
  255. def get_text_embedding_num_tokens(
  256. self,
  257. tenant_id: str,
  258. user_id: str,
  259. plugin_id: str,
  260. provider: str,
  261. model: str,
  262. credentials: dict,
  263. texts: list[str],
  264. ) -> list[int]:
  265. """
  266. Get number of tokens for text embedding
  267. """
  268. response = self._request_with_plugin_daemon_response_stream(
  269. method="POST",
  270. path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
  271. type=PluginTextEmbeddingNumTokensResponse,
  272. data=jsonable_encoder(
  273. {
  274. "user_id": user_id,
  275. "data": {
  276. "provider": provider,
  277. "model_type": "text-embedding",
  278. "model": model,
  279. "credentials": credentials,
  280. "texts": texts,
  281. },
  282. }
  283. ),
  284. headers={
  285. "X-Plugin-ID": plugin_id,
  286. "Content-Type": "application/json",
  287. },
  288. )
  289. for resp in response:
  290. return resp.num_tokens
  291. return []
  292. def invoke_rerank(
  293. self,
  294. tenant_id: str,
  295. user_id: str,
  296. plugin_id: str,
  297. provider: str,
  298. model: str,
  299. credentials: dict,
  300. query: str,
  301. docs: list[str],
  302. score_threshold: Optional[float] = None,
  303. top_n: Optional[int] = None,
  304. ) -> RerankResult:
  305. """
  306. Invoke rerank
  307. """
  308. response = self._request_with_plugin_daemon_response_stream(
  309. method="POST",
  310. path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
  311. type=RerankResult,
  312. data=jsonable_encoder(
  313. {
  314. "user_id": user_id,
  315. "data": {
  316. "provider": provider,
  317. "model_type": "rerank",
  318. "model": model,
  319. "credentials": credentials,
  320. "query": query,
  321. "docs": docs,
  322. "score_threshold": score_threshold,
  323. "top_n": top_n,
  324. },
  325. }
  326. ),
  327. headers={
  328. "X-Plugin-ID": plugin_id,
  329. "Content-Type": "application/json",
  330. },
  331. )
  332. for resp in response:
  333. return resp
  334. raise ValueError("Failed to invoke rerank")
  335. def invoke_tts(
  336. self,
  337. tenant_id: str,
  338. user_id: str,
  339. plugin_id: str,
  340. provider: str,
  341. model: str,
  342. credentials: dict,
  343. content_text: str,
  344. voice: str,
  345. ) -> Generator[bytes, None, None]:
  346. """
  347. Invoke tts
  348. """
  349. response = self._request_with_plugin_daemon_response_stream(
  350. method="POST",
  351. path=f"plugin/{tenant_id}/dispatch/tts/invoke",
  352. type=PluginStringResultResponse,
  353. data=jsonable_encoder(
  354. {
  355. "user_id": user_id,
  356. "data": {
  357. "provider": provider,
  358. "model_type": "tts",
  359. "model": model,
  360. "credentials": credentials,
  361. "tenant_id": tenant_id,
  362. "content_text": content_text,
  363. "voice": voice,
  364. },
  365. }
  366. ),
  367. headers={
  368. "X-Plugin-ID": plugin_id,
  369. "Content-Type": "application/json",
  370. },
  371. )
  372. try:
  373. for result in response:
  374. hex_str = result.result
  375. yield binascii.unhexlify(hex_str)
  376. except PluginDaemonInnerError as e:
  377. raise ValueError(e.message + str(e.code))
  378. def get_tts_model_voices(
  379. self,
  380. tenant_id: str,
  381. user_id: str,
  382. plugin_id: str,
  383. provider: str,
  384. model: str,
  385. credentials: dict,
  386. language: Optional[str] = None,
  387. ) -> list[dict]:
  388. """
  389. Get tts model voices
  390. """
  391. response = self._request_with_plugin_daemon_response_stream(
  392. method="POST",
  393. path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
  394. type=PluginVoicesResponse,
  395. data=jsonable_encoder(
  396. {
  397. "user_id": user_id,
  398. "data": {
  399. "provider": provider,
  400. "model_type": "tts",
  401. "model": model,
  402. "credentials": credentials,
  403. "language": language,
  404. },
  405. }
  406. ),
  407. headers={
  408. "X-Plugin-ID": plugin_id,
  409. "Content-Type": "application/json",
  410. },
  411. )
  412. for resp in response:
  413. voices = []
  414. for voice in resp.voices:
  415. voices.append({"name": voice.name, "value": voice.value})
  416. return voices
  417. return []
  418. def invoke_speech_to_text(
  419. self,
  420. tenant_id: str,
  421. user_id: str,
  422. plugin_id: str,
  423. provider: str,
  424. model: str,
  425. credentials: dict,
  426. file: IO[bytes],
  427. ) -> str:
  428. """
  429. Invoke speech to text
  430. """
  431. response = self._request_with_plugin_daemon_response_stream(
  432. method="POST",
  433. path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
  434. type=PluginStringResultResponse,
  435. data=jsonable_encoder(
  436. {
  437. "user_id": user_id,
  438. "data": {
  439. "provider": provider,
  440. "model_type": "speech2text",
  441. "model": model,
  442. "credentials": credentials,
  443. "file": binascii.hexlify(file.read()).decode(),
  444. },
  445. }
  446. ),
  447. headers={
  448. "X-Plugin-ID": plugin_id,
  449. "Content-Type": "application/json",
  450. },
  451. )
  452. for resp in response:
  453. return resp.result
  454. raise ValueError("Failed to invoke speech to text")
  455. def invoke_moderation(
  456. self,
  457. tenant_id: str,
  458. user_id: str,
  459. plugin_id: str,
  460. provider: str,
  461. model: str,
  462. credentials: dict,
  463. text: str,
  464. ) -> bool:
  465. """
  466. Invoke moderation
  467. """
  468. response = self._request_with_plugin_daemon_response_stream(
  469. method="POST",
  470. path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
  471. type=PluginBasicBooleanResponse,
  472. data=jsonable_encoder(
  473. {
  474. "user_id": user_id,
  475. "data": {
  476. "provider": provider,
  477. "model_type": "moderation",
  478. "model": model,
  479. "credentials": credentials,
  480. "text": text,
  481. },
  482. }
  483. ),
  484. headers={
  485. "X-Plugin-ID": plugin_id,
  486. "Content-Type": "application/json",
  487. },
  488. )
  489. for resp in response:
  490. return resp.result
  491. raise ValueError("Failed to invoke moderation")