model.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. import binascii
  2. from collections.abc import Generator, Sequence
  3. from typing import IO, Optional
  4. from core.model_runtime.entities.llm_entities import LLMResultChunk
  5. from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  6. from core.model_runtime.entities.model_entities import AIModelEntity
  7. from core.model_runtime.entities.rerank_entities import RerankResult
  8. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from core.plugin.entities.plugin_daemon import (
  11. PluginBasicBooleanResponse,
  12. PluginDaemonInnerError,
  13. PluginModelProviderEntity,
  14. PluginModelSchemaEntity,
  15. PluginNumTokensResponse,
  16. PluginStringResultResponse,
  17. PluginVoicesResponse,
  18. )
  19. from core.plugin.manager.base import BasePluginManager
  20. class PluginModelManager(BasePluginManager):
  21. def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
  22. """
  23. Fetch model providers for the given tenant.
  24. """
  25. response = self._request_with_plugin_daemon_response(
  26. "GET",
  27. f"plugin/{tenant_id}/management/models",
  28. list[PluginModelProviderEntity],
  29. params={"page": 1, "page_size": 256},
  30. )
  31. return response
  32. def get_model_schema(
  33. self,
  34. tenant_id: str,
  35. user_id: str,
  36. plugin_id: str,
  37. provider: str,
  38. model_type: str,
  39. model: str,
  40. credentials: dict,
  41. ) -> AIModelEntity | None:
  42. """
  43. Get model schema
  44. """
  45. response = self._request_with_plugin_daemon_response_stream(
  46. "POST",
  47. f"plugin/{tenant_id}/dispatch/model/schema",
  48. PluginModelSchemaEntity,
  49. data={
  50. "user_id": user_id,
  51. "data": {
  52. "provider": provider,
  53. "model_type": model_type,
  54. "model": model,
  55. "credentials": credentials,
  56. },
  57. },
  58. headers={
  59. "X-Plugin-ID": plugin_id,
  60. "Content-Type": "application/json",
  61. },
  62. )
  63. for resp in response:
  64. return resp.model_schema
  65. return None
  66. def validate_provider_credentials(
  67. self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
  68. ) -> bool:
  69. """
  70. validate the credentials of the provider
  71. """
  72. response = self._request_with_plugin_daemon_response_stream(
  73. "POST",
  74. f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
  75. PluginBasicBooleanResponse,
  76. data={
  77. "user_id": user_id,
  78. "data": {
  79. "provider": provider,
  80. "credentials": credentials,
  81. },
  82. },
  83. headers={
  84. "X-Plugin-ID": plugin_id,
  85. "Content-Type": "application/json",
  86. },
  87. )
  88. for resp in response:
  89. return resp.result
  90. return False
  91. def validate_model_credentials(
  92. self,
  93. tenant_id: str,
  94. user_id: str,
  95. plugin_id: str,
  96. provider: str,
  97. model_type: str,
  98. model: str,
  99. credentials: dict,
  100. ) -> bool:
  101. """
  102. validate the credentials of the provider
  103. """
  104. response = self._request_with_plugin_daemon_response_stream(
  105. "POST",
  106. f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
  107. PluginBasicBooleanResponse,
  108. data={
  109. "user_id": user_id,
  110. "data": {
  111. "provider": provider,
  112. "model_type": model_type,
  113. "model": model,
  114. "credentials": credentials,
  115. },
  116. },
  117. headers={
  118. "X-Plugin-ID": plugin_id,
  119. "Content-Type": "application/json",
  120. },
  121. )
  122. for resp in response:
  123. return resp.result
  124. return False
  125. def invoke_llm(
  126. self,
  127. tenant_id: str,
  128. user_id: str,
  129. plugin_id: str,
  130. provider: str,
  131. model: str,
  132. credentials: dict,
  133. prompt_messages: list[PromptMessage],
  134. model_parameters: Optional[dict] = None,
  135. tools: Optional[list[PromptMessageTool]] = None,
  136. stop: Optional[list[str]] = None,
  137. stream: bool = True,
  138. ) -> Generator[LLMResultChunk, None, None]:
  139. """
  140. Invoke llm
  141. """
  142. response = self._request_with_plugin_daemon_response_stream(
  143. method="POST",
  144. path=f"plugin/{tenant_id}/dispatch/llm/invoke",
  145. type=LLMResultChunk,
  146. data=jsonable_encoder(
  147. {
  148. "user_id": user_id,
  149. "data": {
  150. "provider": provider,
  151. "model_type": "llm",
  152. "model": model,
  153. "credentials": credentials,
  154. "prompt_messages": prompt_messages,
  155. "model_parameters": model_parameters,
  156. "tools": tools,
  157. "stop": stop,
  158. "stream": stream,
  159. },
  160. }
  161. ),
  162. headers={
  163. "X-Plugin-ID": plugin_id,
  164. "Content-Type": "application/json",
  165. },
  166. )
  167. try:
  168. yield from response
  169. except PluginDaemonInnerError as e:
  170. raise ValueError(e.message + str(e.code))
  171. def get_llm_num_tokens(
  172. self,
  173. tenant_id: str,
  174. user_id: str,
  175. plugin_id: str,
  176. provider: str,
  177. model_type: str,
  178. model: str,
  179. credentials: dict,
  180. prompt_messages: list[PromptMessage],
  181. tools: Optional[list[PromptMessageTool]] = None,
  182. ) -> int:
  183. """
  184. Get number of tokens for llm
  185. """
  186. response = self._request_with_plugin_daemon_response_stream(
  187. method="POST",
  188. path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
  189. type=PluginNumTokensResponse,
  190. data=jsonable_encoder(
  191. {
  192. "user_id": user_id,
  193. "data": {
  194. "provider": provider,
  195. "model_type": model_type,
  196. "model": model,
  197. "credentials": credentials,
  198. "prompt_messages": prompt_messages,
  199. "tools": tools,
  200. },
  201. }
  202. ),
  203. headers={
  204. "X-Plugin-ID": plugin_id,
  205. "Content-Type": "application/json",
  206. },
  207. )
  208. for resp in response:
  209. return resp.num_tokens
  210. return 0
  211. def invoke_text_embedding(
  212. self,
  213. tenant_id: str,
  214. user_id: str,
  215. plugin_id: str,
  216. provider: str,
  217. model: str,
  218. credentials: dict,
  219. texts: list[str],
  220. input_type: str,
  221. ) -> TextEmbeddingResult:
  222. """
  223. Invoke text embedding
  224. """
  225. response = self._request_with_plugin_daemon_response_stream(
  226. method="POST",
  227. path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
  228. type=TextEmbeddingResult,
  229. data=jsonable_encoder(
  230. {
  231. "user_id": user_id,
  232. "data": {
  233. "provider": provider,
  234. "model_type": "text-embedding",
  235. "model": model,
  236. "credentials": credentials,
  237. "texts": texts,
  238. "input_type": input_type,
  239. },
  240. }
  241. ),
  242. headers={
  243. "X-Plugin-ID": plugin_id,
  244. "Content-Type": "application/json",
  245. },
  246. )
  247. for resp in response:
  248. return resp
  249. raise ValueError("Failed to invoke text embedding")
  250. def get_text_embedding_num_tokens(
  251. self,
  252. tenant_id: str,
  253. user_id: str,
  254. plugin_id: str,
  255. provider: str,
  256. model: str,
  257. credentials: dict,
  258. texts: list[str],
  259. ) -> int:
  260. """
  261. Get number of tokens for text embedding
  262. """
  263. response = self._request_with_plugin_daemon_response_stream(
  264. method="POST",
  265. path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
  266. type=PluginNumTokensResponse,
  267. data=jsonable_encoder(
  268. {
  269. "user_id": user_id,
  270. "data": {
  271. "provider": provider,
  272. "model_type": "text-embedding",
  273. "model": model,
  274. "credentials": credentials,
  275. "texts": texts,
  276. },
  277. }
  278. ),
  279. headers={
  280. "X-Plugin-ID": plugin_id,
  281. "Content-Type": "application/json",
  282. },
  283. )
  284. for resp in response:
  285. return resp.num_tokens
  286. return 0
  287. def invoke_rerank(
  288. self,
  289. tenant_id: str,
  290. user_id: str,
  291. plugin_id: str,
  292. provider: str,
  293. model: str,
  294. credentials: dict,
  295. query: str,
  296. docs: list[str],
  297. score_threshold: Optional[float] = None,
  298. top_n: Optional[int] = None,
  299. ) -> RerankResult:
  300. """
  301. Invoke rerank
  302. """
  303. response = self._request_with_plugin_daemon_response_stream(
  304. method="POST",
  305. path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
  306. type=RerankResult,
  307. data=jsonable_encoder(
  308. {
  309. "user_id": user_id,
  310. "data": {
  311. "provider": provider,
  312. "model_type": "rerank",
  313. "model": model,
  314. "credentials": credentials,
  315. "query": query,
  316. "docs": docs,
  317. "score_threshold": score_threshold,
  318. "top_n": top_n,
  319. },
  320. }
  321. ),
  322. headers={
  323. "X-Plugin-ID": plugin_id,
  324. "Content-Type": "application/json",
  325. },
  326. )
  327. for resp in response:
  328. return resp
  329. raise ValueError("Failed to invoke rerank")
  330. def invoke_tts(
  331. self,
  332. tenant_id: str,
  333. user_id: str,
  334. plugin_id: str,
  335. provider: str,
  336. model: str,
  337. credentials: dict,
  338. content_text: str,
  339. voice: str,
  340. ) -> Generator[bytes, None, None]:
  341. """
  342. Invoke tts
  343. """
  344. response = self._request_with_plugin_daemon_response_stream(
  345. method="POST",
  346. path=f"plugin/{tenant_id}/dispatch/tts/invoke",
  347. type=PluginStringResultResponse,
  348. data=jsonable_encoder(
  349. {
  350. "user_id": user_id,
  351. "data": {
  352. "provider": provider,
  353. "model_type": "tts",
  354. "model": model,
  355. "credentials": credentials,
  356. "content_text": content_text,
  357. "voice": voice,
  358. },
  359. }
  360. ),
  361. headers={
  362. "X-Plugin-ID": plugin_id,
  363. "Content-Type": "application/json",
  364. },
  365. )
  366. try:
  367. for result in response:
  368. hex_str = result.result
  369. yield binascii.unhexlify(hex_str)
  370. except PluginDaemonInnerError as e:
  371. raise ValueError(e.message + str(e.code))
  372. def get_tts_model_voices(
  373. self,
  374. tenant_id: str,
  375. user_id: str,
  376. plugin_id: str,
  377. provider: str,
  378. model: str,
  379. credentials: dict,
  380. language: Optional[str] = None,
  381. ) -> list[dict]:
  382. """
  383. Get tts model voices
  384. """
  385. response = self._request_with_plugin_daemon_response_stream(
  386. method="POST",
  387. path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
  388. type=PluginVoicesResponse,
  389. data=jsonable_encoder(
  390. {
  391. "user_id": user_id,
  392. "data": {
  393. "provider": provider,
  394. "model_type": "tts",
  395. "model": model,
  396. "credentials": credentials,
  397. "language": language,
  398. },
  399. }
  400. ),
  401. headers={
  402. "X-Plugin-ID": plugin_id,
  403. "Content-Type": "application/json",
  404. },
  405. )
  406. for resp in response:
  407. voices = []
  408. for voice in resp.voices:
  409. voices.append({"name": voice.name, "value": voice.value})
  410. return voices
  411. return []
  412. def invoke_speech_to_text(
  413. self,
  414. tenant_id: str,
  415. user_id: str,
  416. plugin_id: str,
  417. provider: str,
  418. model: str,
  419. credentials: dict,
  420. file: IO[bytes],
  421. ) -> str:
  422. """
  423. Invoke speech to text
  424. """
  425. response = self._request_with_plugin_daemon_response_stream(
  426. method="POST",
  427. path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
  428. type=PluginStringResultResponse,
  429. data=jsonable_encoder(
  430. {
  431. "user_id": user_id,
  432. "data": {
  433. "provider": provider,
  434. "model_type": "speech2text",
  435. "model": model,
  436. "credentials": credentials,
  437. "file": binascii.hexlify(file.read()).decode(),
  438. },
  439. }
  440. ),
  441. headers={
  442. "X-Plugin-ID": plugin_id,
  443. "Content-Type": "application/json",
  444. },
  445. )
  446. for resp in response:
  447. return resp.result
  448. raise ValueError("Failed to invoke speech to text")
  449. def invoke_moderation(
  450. self,
  451. tenant_id: str,
  452. user_id: str,
  453. plugin_id: str,
  454. provider: str,
  455. model: str,
  456. credentials: dict,
  457. text: str,
  458. ) -> bool:
  459. """
  460. Invoke moderation
  461. """
  462. response = self._request_with_plugin_daemon_response_stream(
  463. method="POST",
  464. path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
  465. type=PluginBasicBooleanResponse,
  466. data=jsonable_encoder(
  467. {
  468. "user_id": user_id,
  469. "data": {
  470. "provider": provider,
  471. "model_type": "moderation",
  472. "model": model,
  473. "credentials": credentials,
  474. "text": text,
  475. },
  476. }
  477. ),
  478. headers={
  479. "X-Plugin-ID": plugin_id,
  480. "Content-Type": "application/json",
  481. },
  482. )
  483. for resp in response:
  484. return resp.result
  485. raise ValueError("Failed to invoke moderation")