provider_configuration.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict
  9. from constants import HIDDEN_VALUE
  10. from core.entities import DEFAULT_PLUGIN_ID
  11. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  12. from core.entities.provider_entities import (
  13. CustomConfiguration,
  14. ModelSettings,
  15. SystemConfiguration,
  16. SystemConfigurationStatus,
  17. )
  18. from core.helper import encrypter
  19. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  20. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  21. from core.model_runtime.entities.provider_entities import (
  22. ConfigurateMethod,
  23. CredentialFormSchema,
  24. FormType,
  25. ProviderEntity,
  26. )
  27. from core.model_runtime.model_providers.__base.ai_model import AIModel
  28. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  29. from extensions.ext_database import db
  30. from models.provider import (
  31. LoadBalancingModelConfig,
  32. Provider,
  33. ProviderModel,
  34. ProviderModelSetting,
  35. ProviderType,
  36. TenantPreferredModelProvider,
  37. )
  logger = logging.getLogger(__name__)
  # Module-level cache: provider name -> a separate list holding the configurate methods
  # the provider declared the first time it was seen (filled by ProviderConfiguration.__init__
  # or by _get_system_provider_models). It is populated before __init__ may append
  # PREDEFINED_MODEL to the provider, so it reflects the provider's original declaration.
  original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  40. class ProviderConfiguration(BaseModel):
  41. """
  42. Model class for provider configuration.
  43. """
  44. tenant_id: str
  45. provider: ProviderEntity
  46. preferred_provider_type: ProviderType
  47. using_provider_type: ProviderType
  48. system_configuration: SystemConfiguration
  49. custom_configuration: CustomConfiguration
  50. model_settings: list[ModelSettings]
  51. # pydantic configs
  52. model_config = ConfigDict(protected_namespaces=())
  def __init__(self, **data):
      """
      Build the provider configuration.

      Side effects:
      - records the provider's declared configurate methods in the module-level
        ``original_provider_configurate_methods`` cache the first time the provider is seen;
      - for a provider that originally declares only CUSTOMIZABLE_MODEL, appends
        PREDEFINED_MODEL to ``self.provider.configurate_methods`` when any quota
        configuration restricts models (and it is not already present).
      """
      super().__init__(**data)
      provider_name = self.provider.provider

      # Snapshot the declared methods before anything below can append to them.
      if provider_name not in original_provider_configurate_methods:
          original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

      if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
          return

      has_restricted_models = any(
          len(quota.restrict_models) > 0 for quota in self.system_configuration.quota_configurations
      )
      if has_restricted_models and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
          self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
      """
      Resolve the credentials to use for a model under the current provider type.

      :param model_type: model type
      :param model: model name
      :return: for SYSTEM, a copy of the system credentials (plus ``base_model_name`` when
          a matching restricted model of the current quota defines one); otherwise the
          model's custom credentials, falling back to the provider's custom credentials
          (may be None)
      :raises ValueError: if a model setting for this model has it disabled
      """
      # A model setting can switch the model off entirely.
      for setting in self.model_settings or []:
          if setting.model_type == model_type and setting.model == model and not setting.enabled:
              raise ValueError(f"Model {model} is disabled.")

      if self.using_provider_type == ProviderType.SYSTEM:
          current_quota_type = self.system_configuration.current_quota_type
          restrict_models = []
          # Keep the restricted models of the (last) quota matching the current quota type.
          for quota in self.system_configuration.quota_configurations:
              if quota.quota_type == current_quota_type:
                  restrict_models = quota.restrict_models

          system_credentials = self.system_configuration.credentials
          resolved = system_credentials.copy() if system_credentials else {}
          for restricted in restrict_models or []:
              if restricted.model_type == model_type and restricted.model == model and restricted.base_model_name:
                  resolved["base_model_name"] = restricted.base_model_name
          return resolved

      # Custom provider: a model-level configuration wins over the provider-level one.
      resolved = next(
          (
              cfg.credentials
              for cfg in self.custom_configuration.models or []
              if cfg.model_type == model_type and cfg.model == model
          ),
          None,
      )
      if not resolved and self.custom_configuration.provider:
          resolved = self.custom_configuration.provider.credentials
      return resolved
  def get_system_configuration_status(self) -> SystemConfigurationStatus:
      """
      Report the status of the system (hosted) configuration.

      :return: UNSUPPORTED when the system configuration is disabled or no quota
          configuration matches the current quota type; otherwise ACTIVE if that quota
          is valid, else QUOTA_EXCEEDED
      """
      system_configuration = self.system_configuration
      if system_configuration.enabled is False:
          return SystemConfigurationStatus.UNSUPPORTED

      # The first quota configuration matching the current quota type decides the status.
      for quota in system_configuration.quota_configurations:
          if quota.quota_type == system_configuration.current_quota_type:
              if quota.is_valid:
                  return SystemConfigurationStatus.ACTIVE
              return SystemConfigurationStatus.QUOTA_EXCEEDED

      return SystemConfigurationStatus.UNSUPPORTED
  def is_custom_configuration_available(self) -> bool:
      """
      Tell whether any custom configuration exists.

      :return: True when custom provider credentials are set or at least one custom
          model configuration exists
      """
      custom = self.custom_configuration
      if custom.provider is not None:
          return True
      return len(custom.models) > 0
  def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
      """
      Return the provider-level custom credentials.

      :param obfuscated: when True, secret-input fields of the provider credential
          schema are obfuscated in a copy of the credentials
      :return: the credentials (or an obfuscated copy), or None when no custom provider
          configuration exists
      """
      custom_provider = self.custom_configuration.provider
      if custom_provider is None:
          return None

      if obfuscated:
          schema = self.provider.provider_credential_schema
          form_schemas = schema.credential_form_schemas if schema else []
          return self.obfuscated_credentials(
              credentials=custom_provider.credentials,
              credential_form_schemas=form_schemas,
          )

      return custom_provider.credentials
  def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
      """
      Validate custom provider credentials and return them with secret fields encrypted.

      Secret fields submitted as HIDDEN_VALUE are replaced by the decrypted value stored
      on the tenant's existing custom Provider record (when that record holds the key),
      so unchanged secrets can be re-validated without being re-entered.

      :param credentials: provider credentials as submitted; the caller's dict is not modified
      :return: tuple of (existing custom Provider record or None, validated credentials
          with secret-input fields encrypted)
      """
      # Work on a copy: the HIDDEN_VALUE substitution below writes decrypted secrets,
      # which must not leak back into the caller's dict.
      credentials = dict(credentials)

      # get provider
      provider_record = (
          db.session.query(Provider)
          .filter(
              Provider.tenant_id == self.tenant_id,
              Provider.provider_name == self.provider.provider,
              Provider.provider_type == ProviderType.CUSTOM.value,
          )
          .first()
      )

      # Names of the fields declared as secret inputs in the provider credential schema
      provider_credential_secret_variables = self.extract_secret_variables(
          self.provider.provider_credential_schema.credential_form_schemas
          if self.provider.provider_credential_schema
          else []
      )

      if provider_record:
          try:
              if provider_record.encrypted_config:
                  if not provider_record.encrypted_config.startswith("{"):
                      # Stored config that is not a JSON object is treated as a bare
                      # openai_api_key value.
                      original_credentials = {"openai_api_key": provider_record.encrypted_config}
                  else:
                      original_credentials = json.loads(provider_record.encrypted_config)
              else:
                  original_credentials = {}
          except JSONDecodeError:
              original_credentials = {}

          # Restore secrets the client sent back as the HIDDEN_VALUE placeholder.
          for key, value in credentials.items():
              if key in provider_credential_secret_variables:
                  if value == HIDDEN_VALUE and key in original_credentials:
                      credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

      model_provider_factory = ModelProviderFactory(self.tenant_id)
      credentials = model_provider_factory.provider_credentials_validate(
          provider=self.provider.provider, credentials=credentials
      )

      # Encrypt secret-input fields in the returned credentials.
      for key, value in credentials.items():
          if key in provider_credential_secret_variables:
              credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

      return provider_record, credentials
  def add_or_update_custom_credentials(self, credentials: dict) -> None:
      """
      Validate and persist custom provider credentials.

      Creates the tenant's custom Provider record if it does not exist, otherwise
      updates it; then invalidates the cached provider credentials and switches the
      preferred provider type to CUSTOM (see switch_preferred_provider_type for the
      cases where that switch is skipped).

      :param credentials: provider credentials as submitted
      :return:
      """
      provider_record, encrypted_credentials = self.custom_credentials_validate(credentials)
      serialized_config = json.dumps(encrypted_credentials)

      if provider_record is None:
          provider_record = Provider()
          provider_record.tenant_id = self.tenant_id
          provider_record.provider_name = self.provider.provider
          provider_record.provider_type = ProviderType.CUSTOM.value
          provider_record.encrypted_config = serialized_config
          provider_record.is_valid = True
          db.session.add(provider_record)
      else:
          provider_record.encrypted_config = serialized_config
          provider_record.is_valid = True
          provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
      db.session.commit()

      # Drop the cached credentials for this provider record so the new values are used.
      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_record.id,
          cache_type=ProviderCredentialsCacheType.PROVIDER,
      ).delete()

      # Saving custom credentials makes the custom provider the preferred one.
      self.switch_preferred_provider_type(ProviderType.CUSTOM)
  def delete_custom_credentials(self) -> None:
      """
      Delete the tenant's custom provider credentials, if any.

      When a custom Provider record exists, the preferred provider type is first switched
      back to SYSTEM (switch_preferred_provider_type may skip this), the record is deleted
      and committed, and the cached provider credentials for it are invalidated.

      :return:
      """
      # get provider
      provider_record = (
          db.session.query(Provider)
          .filter(
              Provider.tenant_id == self.tenant_id,
              Provider.provider_name == self.provider.provider,
              Provider.provider_type == ProviderType.CUSTOM.value,
          )
          .first()
      )
      if not provider_record:
          return

      # Capture the primary key while the instance is still persistent. The commit inside
      # switch_preferred_provider_type can expire it (SQLAlchemy's default
      # expire_on_commit), and after delete + commit the instance is deleted/detached, so
      # reading attributes from it at that point is unreliable.
      provider_record_id = provider_record.id

      self.switch_preferred_provider_type(ProviderType.SYSTEM)

      db.session.delete(provider_record)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_record_id,
          cache_type=ProviderCredentialsCacheType.PROVIDER,
      ).delete()
  def get_custom_model_credentials(
      self, model_type: ModelType, model: str, obfuscated: bool = False
  ) -> Optional[dict]:
      """
      Return the custom credentials configured for a single model.

      :param model_type: model type
      :param model: model name
      :param obfuscated: when True, secret-input fields of the model credential schema
          are obfuscated in a copy of the credentials
      :return: the credentials of the first matching custom model configuration (or an
          obfuscated copy), or None when no configuration matches
      """
      matched = next(
          (
              cfg
              for cfg in self.custom_configuration.models or []
              if cfg.model_type == model_type and cfg.model == model
          ),
          None,
      )
      if matched is None:
          return None

      if not obfuscated:
          return matched.credentials

      schema = self.provider.model_credential_schema
      return self.obfuscated_credentials(
          credentials=matched.credentials,
          credential_form_schemas=schema.credential_form_schemas if schema else [],
      )
  def custom_model_credentials_validate(
      self, model_type: ModelType, model: str, credentials: dict
  ) -> tuple[ProviderModel | None, dict]:
      """
      Validate custom model credentials and return them with secret fields encrypted.

      Secret fields submitted as HIDDEN_VALUE are replaced by the decrypted value stored
      on the existing ProviderModel record (when that record holds the key), so unchanged
      secrets can be re-validated without being re-entered.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials as submitted; the caller's dict is not modified
      :return: tuple of (existing ProviderModel record or None, validated credentials
          with secret-input fields encrypted)
      """
      # Work on a copy: the HIDDEN_VALUE substitution below writes decrypted secrets,
      # which must not leak back into the caller's dict.
      credentials = dict(credentials)

      # get provider model
      provider_model_record = (
          db.session.query(ProviderModel)
          .filter(
              ProviderModel.tenant_id == self.tenant_id,
              ProviderModel.provider_name == self.provider.provider,
              ProviderModel.model_name == model,
              ProviderModel.model_type == model_type.to_origin_model_type(),
          )
          .first()
      )

      # Names of the fields declared as secret inputs in the model credential schema
      provider_credential_secret_variables = self.extract_secret_variables(
          self.provider.model_credential_schema.credential_form_schemas
          if self.provider.model_credential_schema
          else []
      )

      if provider_model_record:
          try:
              original_credentials = (
                  json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
              )
          except JSONDecodeError:
              original_credentials = {}

          # Restore secrets the client sent back as the HIDDEN_VALUE placeholder.
          for key, value in credentials.items():
              if key in provider_credential_secret_variables:
                  if value == HIDDEN_VALUE and key in original_credentials:
                      credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

      model_provider_factory = ModelProviderFactory(self.tenant_id)
      credentials = model_provider_factory.model_credentials_validate(
          provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
      )

      # Encrypt secret-input fields in the returned credentials.
      for key, value in credentials.items():
          if key in provider_credential_secret_variables:
              credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

      return provider_model_record, credentials
  def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
      """
      Validate and persist custom credentials for a single model.

      Creates the ProviderModel record if it does not exist, otherwise updates it, then
      invalidates the cached model credentials. Unlike the provider-level variant, this
      does not change the preferred provider type.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials as submitted
      :return:
      """
      provider_model_record, encrypted_credentials = self.custom_model_credentials_validate(
          model_type, model, credentials
      )
      serialized_config = json.dumps(encrypted_credentials)

      if provider_model_record is None:
          provider_model_record = ProviderModel()
          provider_model_record.tenant_id = self.tenant_id
          provider_model_record.provider_name = self.provider.provider
          provider_model_record.model_name = model
          provider_model_record.model_type = model_type.to_origin_model_type()
          provider_model_record.encrypted_config = serialized_config
          provider_model_record.is_valid = True
          db.session.add(provider_model_record)
      else:
          provider_model_record.encrypted_config = serialized_config
          provider_model_record.is_valid = True
          provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
      db.session.commit()

      # Drop the cached credentials for this model record so the new values are used.
      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_model_record.id,
          cache_type=ProviderCredentialsCacheType.MODEL,
      ).delete()
  def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
      """
      Delete the custom credentials stored for a single model, if any.

      When a matching ProviderModel record exists it is deleted and committed, and the
      cached model credentials for it are invalidated.

      :param model_type: model type
      :param model: model name
      :return:
      """
      # get provider model
      provider_model_record = (
          db.session.query(ProviderModel)
          .filter(
              ProviderModel.tenant_id == self.tenant_id,
              ProviderModel.provider_name == self.provider.provider,
              ProviderModel.model_name == model,
              ProviderModel.model_type == model_type.to_origin_model_type(),
          )
          .first()
      )
      if not provider_model_record:
          return

      # Capture the primary key while the instance is still persistent; after
      # delete + commit it is deleted/detached and should not be read from.
      provider_model_record_id = provider_model_record.id

      db.session.delete(provider_model_record)
      db.session.commit()

      ProviderCredentialsCache(
          tenant_id=self.tenant_id,
          identity_id=provider_model_record_id,
          cache_type=ProviderCredentialsCacheType.MODEL,
      ).delete()
  def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable a model, creating its ProviderModelSetting row when none exists.

      :param model_type: model type
      :param model: model name
      :return: the committed ProviderModelSetting
      """
      model_setting = self.get_provider_model_setting(model_type, model)

      if model_setting is None:
          model_setting = ProviderModelSetting()
          model_setting.tenant_id = self.tenant_id
          model_setting.provider_name = self.provider.provider
          model_setting.model_type = model_type.to_origin_model_type()
          model_setting.model_name = model
          model_setting.enabled = True
          db.session.add(model_setting)
      else:
          model_setting.enabled = True
          model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

      db.session.commit()
      return model_setting
  def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable a model, creating its ProviderModelSetting row when none exists.

      :param model_type: model type
      :param model: model name
      :return: the committed ProviderModelSetting
      """
      model_setting = self.get_provider_model_setting(model_type, model)

      if model_setting is None:
          model_setting = ProviderModelSetting()
          model_setting.tenant_id = self.tenant_id
          model_setting.provider_name = self.provider.provider
          model_setting.model_type = model_type.to_origin_model_type()
          model_setting.model_name = model
          model_setting.enabled = False
          db.session.add(model_setting)
      else:
          model_setting.enabled = False
          model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

      db.session.commit()
      return model_setting
  def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
      """
      Look up the ProviderModelSetting row for a model of this tenant/provider.

      :param model_type: model type
      :param model: model name
      :return: the first matching row, or None
      """
      setting_query = db.session.query(ProviderModelSetting).filter(
          ProviderModelSetting.tenant_id == self.tenant_id,
          ProviderModelSetting.provider_name == self.provider.provider,
          ProviderModelSetting.model_type == model_type.to_origin_model_type(),
          ProviderModelSetting.model_name == model,
      )
      return setting_query.first()
  def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Turn on load balancing for a model, creating its ProviderModelSetting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the committed ProviderModelSetting
      :raises ValueError: if the model has at most one LoadBalancingModelConfig
      """
      origin_model_type = model_type.to_origin_model_type()

      config_count = (
          db.session.query(LoadBalancingModelConfig)
          .filter(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name == self.provider.provider,
              LoadBalancingModelConfig.model_type == origin_model_type,
              LoadBalancingModelConfig.model_name == model,
          )
          .count()
      )
      # Load balancing needs at least two configurations to balance between.
      if config_count <= 1:
          raise ValueError("Model load balancing configuration must be more than 1.")

      model_setting = self.get_provider_model_setting(model_type, model)
      if model_setting is None:
          model_setting = ProviderModelSetting()
          model_setting.tenant_id = self.tenant_id
          model_setting.provider_name = self.provider.provider
          model_setting.model_type = origin_model_type
          model_setting.model_name = model
          model_setting.load_balancing_enabled = True
          db.session.add(model_setting)
      else:
          model_setting.load_balancing_enabled = True
          model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

      db.session.commit()
      return model_setting
  def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Turn off load balancing for a model, creating its ProviderModelSetting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the committed ProviderModelSetting
      """
      model_setting = self.get_provider_model_setting(model_type, model)

      if model_setting is None:
          model_setting = ProviderModelSetting()
          model_setting.tenant_id = self.tenant_id
          model_setting.provider_name = self.provider.provider
          model_setting.model_type = model_type.to_origin_model_type()
          model_setting.model_name = model
          model_setting.load_balancing_enabled = False
          db.session.add(model_setting)
      else:
          model_setting.load_balancing_enabled = False
          model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

      db.session.commit()
      return model_setting
  def get_model_type_instance(self, model_type: ModelType) -> AIModel:
      """
      Return this provider's model instance for the given model type.

      :param model_type: model type
      :return: the AIModel instance produced by ModelProviderFactory
      """
      factory = ModelProviderFactory(self.tenant_id)
      model_instance = factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
      return model_instance
  def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
      """
      Fetch the schema of a model of this provider via ModelProviderFactory.

      :param model_type: model type
      :param model: model name
      :param credentials: credentials passed through to the factory
      :return: the model schema, or None
      """
      factory = ModelProviderFactory(self.tenant_id)
      schema = factory.get_model_schema(
          provider=self.provider.provider,
          model_type=model_type,
          model=model,
          credentials=credentials,
      )
      return schema
  def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
      """
      Persist the tenant's preferred provider type for this provider.

      Does nothing when ``provider_type`` already equals ``self.preferred_provider_type``,
      or when switching to SYSTEM while the system configuration is disabled. Otherwise
      upserts the TenantPreferredModelProvider row and commits. The in-memory
      ``preferred_provider_type`` attribute is not updated.

      :param provider_type: the provider type to prefer
      :return:
      """
      already_preferred = provider_type == self.preferred_provider_type
      system_unavailable = provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled
      if already_preferred or system_unavailable:
          return

      new_value = provider_type.value
      preferred = (
          db.session.query(TenantPreferredModelProvider)
          .filter(
              TenantPreferredModelProvider.tenant_id == self.tenant_id,
              TenantPreferredModelProvider.provider_name == self.provider.provider,
          )
          .first()
      )

      if preferred is None:
          preferred = TenantPreferredModelProvider()
          preferred.tenant_id = self.tenant_id
          preferred.provider_name = self.provider.provider
          preferred.preferred_provider_type = new_value
          db.session.add(preferred)
      else:
          preferred.preferred_provider_type = new_value

      db.session.commit()
  def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
      """
      Collect the variable names of all secret-input form fields.

      :param credential_form_schemas: credential form schemas to inspect
      :return: variable names whose form type is FormType.SECRET_INPUT, in schema order
      """
      return [
          form_schema.variable
          for form_schema in credential_form_schemas
          if form_schema.type == FormType.SECRET_INPUT
      ]
  def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
      """
      Return a copy of ``credentials`` with secret-input fields obfuscated.

      :param credentials: credentials to obfuscate; not modified
      :param credential_form_schemas: credential form schemas declaring which fields are secret
      :return: a new dict where each secret field is passed through encrypter.obfuscated_token
      """
      secret_variables = self.extract_secret_variables(credential_form_schemas)
      return {
          key: encrypter.obfuscated_token(value) if key in secret_variables else value
          for key, value in credentials.items()
      }
  def get_provider_model(
      self, model_type: ModelType, model: str, only_active: bool = False
  ) -> Optional[ModelWithProviderEntity]:
      """
      Find a single provider model by name.

      :param model_type: model type
      :param model: model name
      :param only_active: consider active models only
      :return: the first entry from get_provider_models whose name matches, or None
      """
      candidates = self.get_provider_models(model_type, only_active)
      return next((candidate for candidate in candidates if candidate.model == model), None)
  def get_provider_models(
      self, model_type: Optional[ModelType] = None, only_active: bool = False
  ) -> list[ModelWithProviderEntity]:
      """
      List this provider's models for the current provider type.

      :param model_type: restrict to one model type; all supported types when omitted
      :param only_active: keep only models whose status is ACTIVE
      :return: models sorted by ``model_type.value``
      """
      factory = ModelProviderFactory(self.tenant_id)
      provider_schema = factory.get_provider_schema(self.provider.provider)
      model_types = [model_type] if model_type else provider_schema.supported_model_types

      # model type -> model name -> setting
      model_setting_map = defaultdict(dict)
      for setting in self.model_settings:
          model_setting_map[setting.model_type][setting.model] = setting

      if self.using_provider_type == ProviderType.SYSTEM:
          collect = self._get_system_provider_models
      else:
          collect = self._get_custom_provider_models
      provider_models = collect(
          model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
      )

      if only_active:
          provider_models = [entry for entry in provider_models if entry.status == ModelStatus.ACTIVE]

      return sorted(provider_models, key=lambda entry: entry.model_type.value)
  659. def _get_system_provider_models(
  660. self,
  661. model_types: Sequence[ModelType],
  662. provider_schema: ProviderEntity,
  663. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  664. ) -> list[ModelWithProviderEntity]:
  665. """
  666. Get system provider models.
  667. :param model_types: model types
  668. :param provider_schema: provider schema
  669. :param model_setting_map: model setting map
  670. :return:
  671. """
  672. provider_models = []
  673. for model_type in model_types:
  674. for m in provider_schema.models:
  675. if m.model_type != model_type:
  676. continue
  677. status = ModelStatus.ACTIVE
  678. if m.model in model_setting_map:
  679. model_setting = model_setting_map[m.model_type][m.model]
  680. if model_setting.enabled is False:
  681. status = ModelStatus.DISABLED
  682. provider_models.append(
  683. ModelWithProviderEntity(
  684. model=m.model,
  685. label=m.label,
  686. model_type=m.model_type,
  687. features=m.features,
  688. fetch_from=m.fetch_from,
  689. model_properties=m.model_properties,
  690. deprecated=m.deprecated,
  691. provider=SimpleModelProviderEntity(self.provider),
  692. status=status,
  693. )
  694. )
  695. if self.provider.provider not in original_provider_configurate_methods:
  696. original_provider_configurate_methods[self.provider.provider] = []
  697. for configurate_method in provider_schema.configurate_methods:
  698. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  699. should_use_custom_model = False
  700. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  701. should_use_custom_model = True
  702. for quota_configuration in self.system_configuration.quota_configurations:
  703. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  704. continue
  705. restrict_models = quota_configuration.restrict_models
  706. if len(restrict_models) == 0:
  707. break
  708. if should_use_custom_model:
  709. if original_provider_configurate_methods[self.provider.provider] == [
  710. ConfigurateMethod.CUSTOMIZABLE_MODEL
  711. ]:
  712. # only customizable model
  713. for restrict_model in restrict_models:
  714. copy_credentials = (
  715. self.system_configuration.credentials.copy()
  716. if self.system_configuration.credentials
  717. else {}
  718. )
  719. if restrict_model.base_model_name:
  720. copy_credentials["base_model_name"] = restrict_model.base_model_name
  721. try:
  722. custom_model_schema = self.get_model_schema(
  723. model_type=restrict_model.model_type,
  724. model=restrict_model.model,
  725. credentials=copy_credentials,
  726. )
  727. except Exception as ex:
  728. logger.warning(f"get custom model schema failed, {ex}")
  729. continue
  730. if not custom_model_schema:
  731. continue
  732. if custom_model_schema.model_type not in model_types:
  733. continue
  734. status = ModelStatus.ACTIVE
  735. if (
  736. custom_model_schema.model_type in model_setting_map
  737. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  738. ):
  739. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  740. if model_setting.enabled is False:
  741. status = ModelStatus.DISABLED
  742. provider_models.append(
  743. ModelWithProviderEntity(
  744. model=custom_model_schema.model,
  745. label=custom_model_schema.label,
  746. model_type=custom_model_schema.model_type,
  747. features=custom_model_schema.features,
  748. fetch_from=FetchFrom.PREDEFINED_MODEL,
  749. model_properties=custom_model_schema.model_properties,
  750. deprecated=custom_model_schema.deprecated,
  751. provider=SimpleModelProviderEntity(self.provider),
  752. status=status,
  753. )
  754. )
  755. # if llm name not in restricted llm list, remove it
  756. restrict_model_names = [rm.model for rm in restrict_models]
  757. for m in provider_models:
  758. if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
  759. m.status = ModelStatus.NO_PERMISSION
  760. elif not quota_configuration.is_valid:
  761. m.status = ModelStatus.QUOTA_EXCEEDED
  762. return provider_models
  763. def _get_custom_provider_models(
  764. self,
  765. model_types: Sequence[ModelType],
  766. provider_schema: ProviderEntity,
  767. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  768. ) -> list[ModelWithProviderEntity]:
  769. """
  770. Get custom provider models.
  771. :param model_types: model types
  772. :param provider_schema: provider schema
  773. :param model_setting_map: model setting map
  774. :return:
  775. """
  776. provider_models = []
  777. credentials = None
  778. if self.custom_configuration.provider:
  779. credentials = self.custom_configuration.provider.credentials
  780. for model_type in model_types:
  781. if model_type not in self.provider.supported_model_types:
  782. continue
  783. for m in provider_schema.models:
  784. if m.model_type != model_type:
  785. continue
  786. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  787. load_balancing_enabled = False
  788. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  789. model_setting = model_setting_map[m.model_type][m.model]
  790. if model_setting.enabled is False:
  791. status = ModelStatus.DISABLED
  792. if len(model_setting.load_balancing_configs) > 1:
  793. load_balancing_enabled = True
  794. provider_models.append(
  795. ModelWithProviderEntity(
  796. model=m.model,
  797. label=m.label,
  798. model_type=m.model_type,
  799. features=m.features,
  800. fetch_from=m.fetch_from,
  801. model_properties=m.model_properties,
  802. deprecated=m.deprecated,
  803. provider=SimpleModelProviderEntity(self.provider),
  804. status=status,
  805. load_balancing_enabled=load_balancing_enabled,
  806. )
  807. )
  808. # custom models
  809. for model_configuration in self.custom_configuration.models:
  810. if model_configuration.model_type not in model_types:
  811. continue
  812. try:
  813. custom_model_schema = self.get_model_schema(
  814. model_type=model_configuration.model_type,
  815. model=model_configuration.model,
  816. credentials=model_configuration.credentials,
  817. )
  818. except Exception as ex:
  819. logger.warning(f"get custom model schema failed, {ex}")
  820. continue
  821. if not custom_model_schema:
  822. continue
  823. status = ModelStatus.ACTIVE
  824. load_balancing_enabled = False
  825. if (
  826. custom_model_schema.model_type in model_setting_map
  827. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  828. ):
  829. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  830. if model_setting.enabled is False:
  831. status = ModelStatus.DISABLED
  832. if len(model_setting.load_balancing_configs) > 1:
  833. load_balancing_enabled = True
  834. provider_models.append(
  835. ModelWithProviderEntity(
  836. model=custom_model_schema.model,
  837. label=custom_model_schema.label,
  838. model_type=custom_model_schema.model_type,
  839. features=custom_model_schema.features,
  840. fetch_from=custom_model_schema.fetch_from,
  841. model_properties=custom_model_schema.model_properties,
  842. deprecated=custom_model_schema.deprecated,
  843. provider=SimpleModelProviderEntity(self.provider),
  844. status=status,
  845. load_balancing_enabled=load_balancing_enabled,
  846. )
  847. )
  848. return provider_models
  849. class ProviderConfigurations(BaseModel):
  850. """
  851. Model class for provider configuration dict.
  852. """
  853. tenant_id: str
  854. configurations: dict[str, ProviderConfiguration] = {}
  855. def __init__(self, tenant_id: str):
  856. super().__init__(tenant_id=tenant_id)
  857. def get_models(
  858. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  859. ) -> list[ModelWithProviderEntity]:
  860. """
  861. Get available models.
  862. If preferred provider type is `system`:
  863. Get the current **system mode** if provider supported,
  864. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  865. If there is no model configured in custom mode, it is treated as no_configure.
  866. system > custom > no_configure
  867. If preferred provider type is `custom`:
  868. If custom credentials are configured, it is treated as custom mode.
  869. Otherwise, get the current **system mode** if supported,
  870. If all system modes are not available (no quota), it is treated as no_configure.
  871. custom > system > no_configure
  872. If real mode is `system`, use system credentials to get models,
  873. paid quotas > provider free quotas > system free quotas
  874. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  875. If real mode is `custom`, use workspace custom credentials to get models,
  876. include pre-defined models, custom models(manual append).
  877. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  878. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  879. model status marked as `active` is available.
  880. :param provider: provider name
  881. :param model_type: model type
  882. :param only_active: only active models
  883. :return:
  884. """
  885. all_models = []
  886. for provider_configuration in self.values():
  887. if provider and provider_configuration.provider.provider != provider:
  888. continue
  889. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  890. return all_models
  891. def to_list(self) -> list[ProviderConfiguration]:
  892. """
  893. Convert to list.
  894. :return:
  895. """
  896. return list(self.values())
  897. def __getitem__(self, key):
  898. if "/" not in key:
  899. key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
  900. return self.configurations[key]
  901. def __setitem__(self, key, value):
  902. self.configurations[key] = value
  903. def __iter__(self):
  904. return iter(self.configurations)
  905. def values(self) -> Iterator[ProviderConfiguration]:
  906. return iter(self.configurations.values())
  907. def get(self, key, default=None):
  908. if "/" not in key:
  909. key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
  910. return self.configurations.get(key, default)
  911. class ProviderModelBundle(BaseModel):
  912. """
  913. Provider model bundle.
  914. """
  915. configuration: ProviderConfiguration
  916. model_type_instance: AIModel
  917. # pydantic configs
  918. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())