model.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. import binascii
  2. from collections.abc import Generator, Sequence
  3. from typing import IO, Optional
  4. from core.model_runtime.entities.llm_entities import LLMResultChunk
  5. from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  6. from core.model_runtime.entities.model_entities import AIModelEntity
  7. from core.model_runtime.entities.rerank_entities import RerankResult
  8. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from core.plugin.entities.plugin_daemon import (
  11. PluginBasicBooleanResponse,
  12. PluginDaemonInnerError,
  13. PluginModelProviderEntity,
  14. PluginModelSchemaEntity,
  15. PluginNumTokensResponse,
  16. PluginStringResultResponse,
  17. PluginVoicesResponse,
  18. )
  19. from core.plugin.manager.base import BasePluginManager
  20. class PluginModelManager(BasePluginManager):
  21. def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
  22. """
  23. Fetch model providers for the given tenant.
  24. """
  25. response = self._request_with_plugin_daemon_response(
  26. "GET",
  27. f"plugin/{tenant_id}/management/models",
  28. list[PluginModelProviderEntity],
  29. params={"page": 1, "page_size": 256},
  30. )
  31. return response
  32. def get_model_schema(
  33. self,
  34. tenant_id: str,
  35. user_id: str,
  36. plugin_id: str,
  37. provider: str,
  38. model_type: str,
  39. model: str,
  40. credentials: dict,
  41. ) -> AIModelEntity | None:
  42. """
  43. Get model schema
  44. """
  45. response = self._request_with_plugin_daemon_response_stream(
  46. "POST",
  47. f"plugin/{tenant_id}/dispatch/model/schema",
  48. PluginModelSchemaEntity,
  49. data={
  50. "user_id": user_id,
  51. "data": {
  52. "provider": provider,
  53. "model_type": model_type,
  54. "model": model,
  55. "credentials": credentials,
  56. },
  57. },
  58. headers={
  59. "X-Plugin-ID": plugin_id,
  60. "Content-Type": "application/json",
  61. },
  62. )
  63. for resp in response:
  64. return resp.model_schema
  65. return None
  66. def validate_provider_credentials(
  67. self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
  68. ) -> bool:
  69. """
  70. validate the credentials of the provider
  71. """
  72. response = self._request_with_plugin_daemon_response_stream(
  73. "POST",
  74. f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
  75. PluginBasicBooleanResponse,
  76. data={
  77. "user_id": user_id,
  78. "data": {
  79. "provider": provider,
  80. "credentials": credentials,
  81. },
  82. },
  83. headers={
  84. "X-Plugin-ID": plugin_id,
  85. "Content-Type": "application/json",
  86. },
  87. )
  88. for resp in response:
  89. return resp.result
  90. return False
  91. def validate_model_credentials(
  92. self,
  93. tenant_id: str,
  94. user_id: str,
  95. plugin_id: str,
  96. provider: str,
  97. model_type: str,
  98. model: str,
  99. credentials: dict,
  100. ) -> bool:
  101. """
  102. validate the credentials of the provider
  103. """
  104. response = self._request_with_plugin_daemon_response_stream(
  105. "POST",
  106. f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
  107. PluginBasicBooleanResponse,
  108. data={
  109. "user_id": user_id,
  110. "data": {
  111. "provider": provider,
  112. "model_type": model_type,
  113. "model": model,
  114. "credentials": credentials,
  115. },
  116. },
  117. headers={
  118. "X-Plugin-ID": plugin_id,
  119. "Content-Type": "application/json",
  120. },
  121. )
  122. for resp in response:
  123. return resp.result
  124. return False
  125. def invoke_llm(
  126. self,
  127. tenant_id: str,
  128. user_id: str,
  129. plugin_id: str,
  130. provider: str,
  131. model: str,
  132. credentials: dict,
  133. prompt_messages: list[PromptMessage],
  134. model_parameters: Optional[dict] = None,
  135. tools: Optional[list[PromptMessageTool]] = None,
  136. stop: Optional[list[str]] = None,
  137. stream: bool = True,
  138. ) -> Generator[LLMResultChunk, None, None]:
  139. """
  140. Invoke llm
  141. """
  142. response = self._request_with_plugin_daemon_response_stream(
  143. method="POST",
  144. path=f"plugin/{tenant_id}/dispatch/llm/invoke",
  145. type=LLMResultChunk,
  146. data=jsonable_encoder(
  147. {
  148. "user_id": user_id,
  149. "data": {
  150. "provider": provider,
  151. "model_type": "llm",
  152. "model": model,
  153. "credentials": credentials,
  154. "prompt_messages": prompt_messages,
  155. "model_parameters": model_parameters,
  156. "tools": tools,
  157. "stop": stop,
  158. "stream": stream,
  159. },
  160. }
  161. ),
  162. headers={
  163. "X-Plugin-ID": plugin_id,
  164. "Content-Type": "application/json",
  165. },
  166. )
  167. try:
  168. yield from response
  169. except PluginDaemonInnerError as e:
  170. raise ValueError(e.message + str(e.code))
  171. def get_llm_num_tokens(
  172. self,
  173. tenant_id: str,
  174. user_id: str,
  175. plugin_id: str,
  176. provider: str,
  177. model_type: str,
  178. model: str,
  179. credentials: dict,
  180. prompt_messages: list[PromptMessage],
  181. tools: Optional[list[PromptMessageTool]] = None,
  182. ) -> int:
  183. """
  184. Get number of tokens for llm
  185. """
  186. response = self._request_with_plugin_daemon_response_stream(
  187. method="POST",
  188. path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
  189. type=PluginNumTokensResponse,
  190. data=jsonable_encoder(
  191. {
  192. "user_id": user_id,
  193. "data": {
  194. "provider": provider,
  195. "model_type": model_type,
  196. "model": model,
  197. "credentials": credentials,
  198. "prompt_messages": prompt_messages,
  199. "tools": tools,
  200. },
  201. }
  202. ),
  203. headers={
  204. "X-Plugin-ID": plugin_id,
  205. "Content-Type": "application/json",
  206. },
  207. )
  208. for resp in response:
  209. return resp.num_tokens
  210. return 0
  211. def invoke_text_embedding(
  212. self,
  213. tenant_id: str,
  214. user_id: str,
  215. plugin_id: str,
  216. provider: str,
  217. model: str,
  218. credentials: dict,
  219. texts: list[str],
  220. ) -> TextEmbeddingResult:
  221. """
  222. Invoke text embedding
  223. """
  224. response = self._request_with_plugin_daemon_response_stream(
  225. method="POST",
  226. path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
  227. type=TextEmbeddingResult,
  228. data=jsonable_encoder(
  229. {
  230. "user_id": user_id,
  231. "data": {
  232. "provider": provider,
  233. "model_type": "text-embedding",
  234. "model": model,
  235. "credentials": credentials,
  236. "texts": texts,
  237. },
  238. }
  239. ),
  240. headers={
  241. "X-Plugin-ID": plugin_id,
  242. "Content-Type": "application/json",
  243. },
  244. )
  245. for resp in response:
  246. return resp
  247. raise ValueError("Failed to invoke text embedding")
  248. def get_text_embedding_num_tokens(
  249. self,
  250. tenant_id: str,
  251. user_id: str,
  252. plugin_id: str,
  253. provider: str,
  254. model_type: str,
  255. model: str,
  256. credentials: dict,
  257. texts: list[str],
  258. ) -> int:
  259. """
  260. Get number of tokens for text embedding
  261. """
  262. response = self._request_with_plugin_daemon_response_stream(
  263. method="POST",
  264. path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
  265. type=PluginNumTokensResponse,
  266. data=jsonable_encoder(
  267. {
  268. "user_id": user_id,
  269. "data": {
  270. "provider": provider,
  271. "model_type": model_type,
  272. "model": model,
  273. "credentials": credentials,
  274. "texts": texts,
  275. },
  276. }
  277. ),
  278. headers={
  279. "X-Plugin-ID": plugin_id,
  280. "Content-Type": "application/json",
  281. },
  282. )
  283. for resp in response:
  284. return resp.num_tokens
  285. return 0
  286. def invoke_rerank(
  287. self,
  288. tenant_id: str,
  289. user_id: str,
  290. plugin_id: str,
  291. provider: str,
  292. model_type: str,
  293. model: str,
  294. credentials: dict,
  295. query: str,
  296. docs: list[str],
  297. score_threshold: Optional[float] = None,
  298. top_n: Optional[int] = None,
  299. ) -> RerankResult:
  300. """
  301. Invoke rerank
  302. """
  303. response = self._request_with_plugin_daemon_response_stream(
  304. method="POST",
  305. path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
  306. type=RerankResult,
  307. data=jsonable_encoder(
  308. {
  309. "user_id": user_id,
  310. "data": {
  311. "provider": provider,
  312. "model_type": model_type,
  313. "model": model,
  314. "credentials": credentials,
  315. "query": query,
  316. "docs": docs,
  317. "score_threshold": score_threshold,
  318. "top_n": top_n,
  319. },
  320. }
  321. ),
  322. headers={
  323. "X-Plugin-ID": plugin_id,
  324. "Content-Type": "application/json",
  325. },
  326. )
  327. for resp in response:
  328. return resp
  329. raise ValueError("Failed to invoke rerank")
  330. def invoke_tts(
  331. self,
  332. tenant_id: str,
  333. user_id: str,
  334. plugin_id: str,
  335. provider: str,
  336. model_type: str,
  337. model: str,
  338. credentials: dict,
  339. content_text: str,
  340. voice: str,
  341. ) -> Generator[bytes, None, None]:
  342. """
  343. Invoke tts
  344. """
  345. response = self._request_with_plugin_daemon_response_stream(
  346. method="POST",
  347. path=f"plugin/{tenant_id}/dispatch/tts/invoke",
  348. type=PluginStringResultResponse,
  349. data=jsonable_encoder(
  350. {
  351. "user_id": user_id,
  352. "data": {
  353. "provider": provider,
  354. "model_type": model_type,
  355. "model": model,
  356. "credentials": credentials,
  357. "content_text": content_text,
  358. "voice": voice,
  359. },
  360. }
  361. ),
  362. headers={
  363. "X-Plugin-ID": plugin_id,
  364. "Content-Type": "application/json",
  365. },
  366. )
  367. try:
  368. for result in response:
  369. hex_str = result.result
  370. yield binascii.unhexlify(hex_str)
  371. except PluginDaemonInnerError as e:
  372. raise ValueError(e.message + str(e.code))
  373. def get_tts_model_voices(
  374. self,
  375. tenant_id: str,
  376. user_id: str,
  377. plugin_id: str,
  378. provider: str,
  379. model_type: str,
  380. model: str,
  381. credentials: dict,
  382. language: Optional[str] = None,
  383. ) -> list[dict]:
  384. """
  385. Get tts model voices
  386. """
  387. response = self._request_with_plugin_daemon_response_stream(
  388. method="POST",
  389. path=f"plugin/{tenant_id}/dispatch/model/voices",
  390. type=PluginVoicesResponse,
  391. data=jsonable_encoder(
  392. {
  393. "user_id": user_id,
  394. "data": {
  395. "provider": provider,
  396. "model_type": model_type,
  397. "model": model,
  398. "credentials": credentials,
  399. "language": language,
  400. },
  401. }
  402. ),
  403. headers={
  404. "X-Plugin-ID": plugin_id,
  405. "Content-Type": "application/json",
  406. },
  407. )
  408. for resp in response:
  409. for voice in resp.voices:
  410. return [{"name": voice.name, "value": voice.value}]
  411. return []
  412. def invoke_speech_to_text(
  413. self,
  414. tenant_id: str,
  415. user_id: str,
  416. plugin_id: str,
  417. provider: str,
  418. model_type: str,
  419. model: str,
  420. credentials: dict,
  421. file: IO[bytes],
  422. ) -> str:
  423. """
  424. Invoke speech to text
  425. """
  426. response = self._request_with_plugin_daemon_response_stream(
  427. method="POST",
  428. path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
  429. type=PluginStringResultResponse,
  430. data=jsonable_encoder(
  431. {
  432. "user_id": user_id,
  433. "data": {
  434. "provider": provider,
  435. "model_type": model_type,
  436. "model": model,
  437. "credentials": credentials,
  438. "file": binascii.hexlify(file.read()).decode(),
  439. },
  440. }
  441. ),
  442. headers={
  443. "X-Plugin-ID": plugin_id,
  444. "Content-Type": "application/json",
  445. },
  446. )
  447. for resp in response:
  448. return resp.result
  449. raise ValueError("Failed to invoke speech to text")
  450. def invoke_moderation(
  451. self,
  452. tenant_id: str,
  453. user_id: str,
  454. plugin_id: str,
  455. provider: str,
  456. model_type: str,
  457. model: str,
  458. credentials: dict,
  459. text: str,
  460. ) -> bool:
  461. """
  462. Invoke moderation
  463. """
  464. response = self._request_with_plugin_daemon_response_stream(
  465. method="POST",
  466. path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
  467. type=PluginBasicBooleanResponse,
  468. data=jsonable_encoder(
  469. {
  470. "user_id": user_id,
  471. "data": {
  472. "provider": provider,
  473. "model_type": model_type,
  474. "model": model,
  475. "credentials": credentials,
  476. "text": text,
  477. },
  478. }
  479. ),
  480. headers={
  481. "X-Plugin-ID": plugin_id,
  482. "Content-Type": "application/json",
  483. },
  484. )
  485. for resp in response:
  486. return resp.result
  487. raise ValueError("Failed to invoke moderation")