provider_configuration.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict
  9. from constants import HIDDEN_VALUE
  10. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  11. from core.entities.provider_entities import (
  12. CustomConfiguration,
  13. ModelSettings,
  14. SystemConfiguration,
  15. SystemConfigurationStatus,
  16. )
  17. from core.helper import encrypter
  18. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  19. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  20. from core.model_runtime.entities.provider_entities import (
  21. ConfigurateMethod,
  22. CredentialFormSchema,
  23. FormType,
  24. ProviderEntity,
  25. )
  26. from core.model_runtime.model_providers.__base.ai_model import AIModel
  27. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  28. from extensions.ext_database import db
  29. from models.provider import (
  30. LoadBalancingModelConfig,
  31. Provider,
  32. ProviderModel,
  33. ProviderModelSetting,
  34. ProviderType,
  35. TenantPreferredModelProvider,
  36. )
  37. logger = logging.getLogger(__name__)
  38. original_provider_configurate_methods = {}
class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.

    Holds, for one tenant and one model provider, the provider schema entity,
    the system (hosted/quota-based) configuration, the tenant's custom
    credentials and the per-model settings. The methods below read these
    fields and persist changes through ``db.session``.
    """

    # tenant that owns this configuration; every DB query below is scoped to it
    tenant_id: str
    # provider schema entity (name, configurate methods, credential form schemas)
    provider: ProviderEntity
    # provider type the tenant prefers; switch_preferred_provider_type short-circuits on it
    preferred_provider_type: ProviderType
    # provider type actually in use; SYSTEM selects system/quota credentials
    using_provider_type: ProviderType
    # system configuration: enabled flag, quota configurations, shared credentials
    system_configuration: SystemConfiguration
    # tenant-supplied provider-level and model-level credentials
    custom_configuration: CustomConfiguration
    # per-model settings (e.g. whether a model is enabled)
    model_settings: list[ModelSettings]

    # pydantic configs
    # NOTE(review): presumably cleared so fields prefixed with "model_" (model_settings)
    # do not collide with pydantic's protected "model_" namespace — confirm.
    model_config = ConfigDict(protected_namespaces=())
  52. def __init__(self, **data):
  53. super().__init__(**data)
  54. if self.provider.provider not in original_provider_configurate_methods:
  55. original_provider_configurate_methods[self.provider.provider] = []
  56. for configurate_method in self.provider.configurate_methods:
  57. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  58. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  59. if (
  60. any(
  61. len(quota_configuration.restrict_models) > 0
  62. for quota_configuration in self.system_configuration.quota_configurations
  63. )
  64. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  65. ):
  66. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  67. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  68. """
  69. Get current credentials.
  70. :param model_type: model type
  71. :param model: model name
  72. :return:
  73. """
  74. if self.model_settings:
  75. # check if model is disabled by admin
  76. for model_setting in self.model_settings:
  77. if model_setting.model_type == model_type and model_setting.model == model:
  78. if not model_setting.enabled:
  79. raise ValueError(f"Model {model} is disabled.")
  80. if self.using_provider_type == ProviderType.SYSTEM:
  81. restrict_models = []
  82. for quota_configuration in self.system_configuration.quota_configurations:
  83. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  84. continue
  85. restrict_models = quota_configuration.restrict_models
  86. copy_credentials = (
  87. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  88. )
  89. if restrict_models:
  90. for restrict_model in restrict_models:
  91. if (
  92. restrict_model.model_type == model_type
  93. and restrict_model.model == model
  94. and restrict_model.base_model_name
  95. ):
  96. copy_credentials["base_model_name"] = restrict_model.base_model_name
  97. return copy_credentials
  98. else:
  99. credentials = None
  100. if self.custom_configuration.models:
  101. for model_configuration in self.custom_configuration.models:
  102. if model_configuration.model_type == model_type and model_configuration.model == model:
  103. credentials = model_configuration.credentials
  104. break
  105. if not credentials and self.custom_configuration.provider:
  106. credentials = self.custom_configuration.provider.credentials
  107. return credentials
  108. def get_system_configuration_status(self) -> SystemConfigurationStatus:
  109. """
  110. Get system configuration status.
  111. :return:
  112. """
  113. if self.system_configuration.enabled is False:
  114. return SystemConfigurationStatus.UNSUPPORTED
  115. current_quota_type = self.system_configuration.current_quota_type
  116. current_quota_configuration = next(
  117. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  118. )
  119. if not current_quota_configuration:
  120. return SystemConfigurationStatus.UNSUPPORTED
  121. return (
  122. SystemConfigurationStatus.ACTIVE
  123. if current_quota_configuration.is_valid
  124. else SystemConfigurationStatus.QUOTA_EXCEEDED
  125. )
  126. def is_custom_configuration_available(self) -> bool:
  127. """
  128. Check custom configuration available.
  129. :return:
  130. """
  131. return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
  132. def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
  133. """
  134. Get custom credentials.
  135. :param obfuscated: obfuscated secret data in credentials
  136. :return:
  137. """
  138. if self.custom_configuration.provider is None:
  139. return None
  140. credentials = self.custom_configuration.provider.credentials
  141. if not obfuscated:
  142. return credentials
  143. # Obfuscate credentials
  144. return self.obfuscated_credentials(
  145. credentials=credentials,
  146. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  147. if self.provider.provider_credential_schema
  148. else [],
  149. )
  150. def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
  151. """
  152. Validate custom credentials.
  153. :param credentials: provider credentials
  154. :return:
  155. """
  156. # get provider
  157. provider_record = (
  158. db.session.query(Provider)
  159. .filter(
  160. Provider.tenant_id == self.tenant_id,
  161. Provider.provider_name == self.provider.provider,
  162. Provider.provider_type == ProviderType.CUSTOM.value,
  163. )
  164. .first()
  165. )
  166. # Get provider credential secret variables
  167. provider_credential_secret_variables = self.extract_secret_variables(
  168. self.provider.provider_credential_schema.credential_form_schemas
  169. if self.provider.provider_credential_schema
  170. else []
  171. )
  172. if provider_record:
  173. try:
  174. # fix origin data
  175. if provider_record.encrypted_config:
  176. if not provider_record.encrypted_config.startswith("{"):
  177. original_credentials = {"openai_api_key": provider_record.encrypted_config}
  178. else:
  179. original_credentials = json.loads(provider_record.encrypted_config)
  180. else:
  181. original_credentials = {}
  182. except JSONDecodeError:
  183. original_credentials = {}
  184. # encrypt credentials
  185. for key, value in credentials.items():
  186. if key in provider_credential_secret_variables:
  187. # if send [__HIDDEN__] in secret input, it will be same as original value
  188. if value == HIDDEN_VALUE and key in original_credentials:
  189. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  190. model_provider_factory = ModelProviderFactory(self.tenant_id)
  191. credentials = model_provider_factory.provider_credentials_validate(
  192. provider=self.provider.provider, credentials=credentials
  193. )
  194. for key, value in credentials.items():
  195. if key in provider_credential_secret_variables:
  196. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  197. return provider_record, credentials
  198. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  199. """
  200. Add or update custom provider credentials.
  201. :param credentials:
  202. :return:
  203. """
  204. # validate custom provider config
  205. provider_record, credentials = self.custom_credentials_validate(credentials)
  206. # save provider
  207. # Note: Do not switch the preferred provider, which allows users to use quotas first
  208. if provider_record:
  209. provider_record.encrypted_config = json.dumps(credentials)
  210. provider_record.is_valid = True
  211. provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  212. db.session.commit()
  213. else:
  214. provider_record = Provider()
  215. provider_record.tenant_id = self.tenant_id
  216. provider_record.provider_name = self.provider.provider
  217. provider_record.provider_type = ProviderType.CUSTOM.value
  218. provider_record.encrypted_config = json.dumps(credentials)
  219. provider_record.is_valid = True
  220. db.session.add(provider_record)
  221. db.session.commit()
  222. provider_model_credentials_cache = ProviderCredentialsCache(
  223. tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
  224. )
  225. provider_model_credentials_cache.delete()
  226. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  227. def delete_custom_credentials(self) -> None:
  228. """
  229. Delete custom provider credentials.
  230. :return:
  231. """
  232. # get provider
  233. provider_record = (
  234. db.session.query(Provider)
  235. .filter(
  236. Provider.tenant_id == self.tenant_id,
  237. Provider.provider_name == self.provider.provider,
  238. Provider.provider_type == ProviderType.CUSTOM.value,
  239. )
  240. .first()
  241. )
  242. # delete provider
  243. if provider_record:
  244. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  245. db.session.delete(provider_record)
  246. db.session.commit()
  247. provider_model_credentials_cache = ProviderCredentialsCache(
  248. tenant_id=self.tenant_id,
  249. identity_id=provider_record.id,
  250. cache_type=ProviderCredentialsCacheType.PROVIDER,
  251. )
  252. provider_model_credentials_cache.delete()
  253. def get_custom_model_credentials(
  254. self, model_type: ModelType, model: str, obfuscated: bool = False
  255. ) -> Optional[dict]:
  256. """
  257. Get custom model credentials.
  258. :param model_type: model type
  259. :param model: model name
  260. :param obfuscated: obfuscated secret data in credentials
  261. :return:
  262. """
  263. if not self.custom_configuration.models:
  264. return None
  265. for model_configuration in self.custom_configuration.models:
  266. if model_configuration.model_type == model_type and model_configuration.model == model:
  267. credentials = model_configuration.credentials
  268. if not obfuscated:
  269. return credentials
  270. # Obfuscate credentials
  271. return self.obfuscated_credentials(
  272. credentials=credentials,
  273. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  274. if self.provider.model_credential_schema
  275. else [],
  276. )
  277. return None
  278. def custom_model_credentials_validate(
  279. self, model_type: ModelType, model: str, credentials: dict
  280. ) -> tuple[ProviderModel | None, dict]:
  281. """
  282. Validate custom model credentials.
  283. :param model_type: model type
  284. :param model: model name
  285. :param credentials: model credentials
  286. :return:
  287. """
  288. # get provider model
  289. provider_model_record = (
  290. db.session.query(ProviderModel)
  291. .filter(
  292. ProviderModel.tenant_id == self.tenant_id,
  293. ProviderModel.provider_name == self.provider.provider,
  294. ProviderModel.model_name == model,
  295. ProviderModel.model_type == model_type.to_origin_model_type(),
  296. )
  297. .first()
  298. )
  299. # Get provider credential secret variables
  300. provider_credential_secret_variables = self.extract_secret_variables(
  301. self.provider.model_credential_schema.credential_form_schemas
  302. if self.provider.model_credential_schema
  303. else []
  304. )
  305. if provider_model_record:
  306. try:
  307. original_credentials = (
  308. json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  309. )
  310. except JSONDecodeError:
  311. original_credentials = {}
  312. # decrypt credentials
  313. for key, value in credentials.items():
  314. if key in provider_credential_secret_variables:
  315. # if send [__HIDDEN__] in secret input, it will be same as original value
  316. if value == HIDDEN_VALUE and key in original_credentials:
  317. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  318. model_provider_factory = ModelProviderFactory(self.tenant_id)
  319. credentials = model_provider_factory.model_credentials_validate(
  320. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  321. )
  322. for key, value in credentials.items():
  323. if key in provider_credential_secret_variables:
  324. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  325. return provider_model_record, credentials
  326. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  327. """
  328. Add or update custom model credentials.
  329. :param model_type: model type
  330. :param model: model name
  331. :param credentials: model credentials
  332. :return:
  333. """
  334. # validate custom model config
  335. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  336. # save provider model
  337. # Note: Do not switch the preferred provider, which allows users to use quotas first
  338. if provider_model_record:
  339. provider_model_record.encrypted_config = json.dumps(credentials)
  340. provider_model_record.is_valid = True
  341. provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  342. db.session.commit()
  343. else:
  344. provider_model_record = ProviderModel()
  345. provider_model_record.tenant_id = self.tenant_id
  346. provider_model_record.provider_name = self.provider.provider
  347. provider_model_record.model_name = model
  348. provider_model_record.model_type = model_type.to_origin_model_type()
  349. provider_model_record.encrypted_config = json.dumps(credentials)
  350. provider_model_record.is_valid = True
  351. db.session.add(provider_model_record)
  352. db.session.commit()
  353. provider_model_credentials_cache = ProviderCredentialsCache(
  354. tenant_id=self.tenant_id,
  355. identity_id=provider_model_record.id,
  356. cache_type=ProviderCredentialsCacheType.MODEL,
  357. )
  358. provider_model_credentials_cache.delete()
  359. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  360. """
  361. Delete custom model credentials.
  362. :param model_type: model type
  363. :param model: model name
  364. :return:
  365. """
  366. # get provider model
  367. provider_model_record = (
  368. db.session.query(ProviderModel)
  369. .filter(
  370. ProviderModel.tenant_id == self.tenant_id,
  371. ProviderModel.provider_name == self.provider.provider,
  372. ProviderModel.model_name == model,
  373. ProviderModel.model_type == model_type.to_origin_model_type(),
  374. )
  375. .first()
  376. )
  377. # delete provider model
  378. if provider_model_record:
  379. db.session.delete(provider_model_record)
  380. db.session.commit()
  381. provider_model_credentials_cache = ProviderCredentialsCache(
  382. tenant_id=self.tenant_id,
  383. identity_id=provider_model_record.id,
  384. cache_type=ProviderCredentialsCacheType.MODEL,
  385. )
  386. provider_model_credentials_cache.delete()
  387. def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  388. """
  389. Enable model.
  390. :param model_type: model type
  391. :param model: model name
  392. :return:
  393. """
  394. model_setting = (
  395. db.session.query(ProviderModelSetting)
  396. .filter(
  397. ProviderModelSetting.tenant_id == self.tenant_id,
  398. ProviderModelSetting.provider_name == self.provider.provider,
  399. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  400. ProviderModelSetting.model_name == model,
  401. )
  402. .first()
  403. )
  404. if model_setting:
  405. model_setting.enabled = True
  406. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  407. db.session.commit()
  408. else:
  409. model_setting = ProviderModelSetting()
  410. model_setting.tenant_id = self.tenant_id
  411. model_setting.provider_name = self.provider.provider
  412. model_setting.model_type = model_type.to_origin_model_type()
  413. model_setting.model_name = model
  414. model_setting.enabled = True
  415. db.session.add(model_setting)
  416. db.session.commit()
  417. return model_setting
  418. def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  419. """
  420. Disable model.
  421. :param model_type: model type
  422. :param model: model name
  423. :return:
  424. """
  425. model_setting = (
  426. db.session.query(ProviderModelSetting)
  427. .filter(
  428. ProviderModelSetting.tenant_id == self.tenant_id,
  429. ProviderModelSetting.provider_name == self.provider.provider,
  430. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  431. ProviderModelSetting.model_name == model,
  432. )
  433. .first()
  434. )
  435. if model_setting:
  436. model_setting.enabled = False
  437. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  438. db.session.commit()
  439. else:
  440. model_setting = ProviderModelSetting()
  441. model_setting.tenant_id = self.tenant_id
  442. model_setting.provider_name = self.provider.provider
  443. model_setting.model_type = model_type.to_origin_model_type()
  444. model_setting.model_name = model
  445. model_setting.enabled = False
  446. db.session.add(model_setting)
  447. db.session.commit()
  448. return model_setting
  449. def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
  450. """
  451. Get provider model setting.
  452. :param model_type: model type
  453. :param model: model name
  454. :return:
  455. """
  456. return (
  457. db.session.query(ProviderModelSetting)
  458. .filter(
  459. ProviderModelSetting.tenant_id == self.tenant_id,
  460. ProviderModelSetting.provider_name == self.provider.provider,
  461. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  462. ProviderModelSetting.model_name == model,
  463. )
  464. .first()
  465. )
  466. def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  467. """
  468. Enable model load balancing.
  469. :param model_type: model type
  470. :param model: model name
  471. :return:
  472. """
  473. load_balancing_config_count = (
  474. db.session.query(LoadBalancingModelConfig)
  475. .filter(
  476. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  477. LoadBalancingModelConfig.provider_name == self.provider.provider,
  478. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  479. LoadBalancingModelConfig.model_name == model,
  480. )
  481. .count()
  482. )
  483. if load_balancing_config_count <= 1:
  484. raise ValueError("Model load balancing configuration must be more than 1.")
  485. model_setting = (
  486. db.session.query(ProviderModelSetting)
  487. .filter(
  488. ProviderModelSetting.tenant_id == self.tenant_id,
  489. ProviderModelSetting.provider_name == self.provider.provider,
  490. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  491. ProviderModelSetting.model_name == model,
  492. )
  493. .first()
  494. )
  495. if model_setting:
  496. model_setting.load_balancing_enabled = True
  497. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  498. db.session.commit()
  499. else:
  500. model_setting = ProviderModelSetting()
  501. model_setting.tenant_id = self.tenant_id
  502. model_setting.provider_name = self.provider.provider
  503. model_setting.model_type = model_type.to_origin_model_type()
  504. model_setting.model_name = model
  505. model_setting.load_balancing_enabled = True
  506. db.session.add(model_setting)
  507. db.session.commit()
  508. return model_setting
  509. def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  510. """
  511. Disable model load balancing.
  512. :param model_type: model type
  513. :param model: model name
  514. :return:
  515. """
  516. model_setting = (
  517. db.session.query(ProviderModelSetting)
  518. .filter(
  519. ProviderModelSetting.tenant_id == self.tenant_id,
  520. ProviderModelSetting.provider_name == self.provider.provider,
  521. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  522. ProviderModelSetting.model_name == model,
  523. )
  524. .first()
  525. )
  526. if model_setting:
  527. model_setting.load_balancing_enabled = False
  528. model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  529. db.session.commit()
  530. else:
  531. model_setting = ProviderModelSetting()
  532. model_setting.tenant_id = self.tenant_id
  533. model_setting.provider_name = self.provider.provider
  534. model_setting.model_type = model_type.to_origin_model_type()
  535. model_setting.model_name = model
  536. model_setting.load_balancing_enabled = False
  537. db.session.add(model_setting)
  538. db.session.commit()
  539. return model_setting
  540. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  541. """
  542. Get current model type instance.
  543. :param model_type: model type
  544. :return:
  545. """
  546. model_provider_factory = ModelProviderFactory(self.tenant_id)
  547. # Get model instance of LLM
  548. return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
  549. def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
  550. """
  551. Get model schema
  552. """
  553. model_provider_factory = ModelProviderFactory(self.tenant_id)
  554. return model_provider_factory.get_model_schema(
  555. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  556. )
  557. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  558. """
  559. Switch preferred provider type.
  560. :param provider_type:
  561. :return:
  562. """
  563. if provider_type == self.preferred_provider_type:
  564. return
  565. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  566. return
  567. # get preferred provider
  568. preferred_model_provider = (
  569. db.session.query(TenantPreferredModelProvider)
  570. .filter(
  571. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  572. TenantPreferredModelProvider.provider_name == self.provider.provider,
  573. )
  574. .first()
  575. )
  576. if preferred_model_provider:
  577. preferred_model_provider.preferred_provider_type = provider_type.value
  578. else:
  579. preferred_model_provider = TenantPreferredModelProvider()
  580. preferred_model_provider.tenant_id = self.tenant_id
  581. preferred_model_provider.provider_name = self.provider.provider
  582. preferred_model_provider.preferred_provider_type = provider_type.value
  583. db.session.add(preferred_model_provider)
  584. db.session.commit()
  585. def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  586. """
  587. Extract secret input form variables.
  588. :param credential_form_schemas:
  589. :return:
  590. """
  591. secret_input_form_variables = []
  592. for credential_form_schema in credential_form_schemas:
  593. if credential_form_schema.type == FormType.SECRET_INPUT:
  594. secret_input_form_variables.append(credential_form_schema.variable)
  595. return secret_input_form_variables
  596. def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  597. """
  598. Obfuscated credentials.
  599. :param credentials: credentials
  600. :param credential_form_schemas: credential form schemas
  601. :return:
  602. """
  603. # Get provider credential secret variables
  604. credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
  605. # Obfuscate provider credentials
  606. copy_credentials = credentials.copy()
  607. for key, value in copy_credentials.items():
  608. if key in credential_secret_variables:
  609. copy_credentials[key] = encrypter.obfuscated_token(value)
  610. return copy_credentials
  611. def get_provider_model(
  612. self, model_type: ModelType, model: str, only_active: bool = False
  613. ) -> Optional[ModelWithProviderEntity]:
  614. """
  615. Get provider model.
  616. :param model_type: model type
  617. :param model: model name
  618. :param only_active: return active model only
  619. :return:
  620. """
  621. provider_models = self.get_provider_models(model_type, only_active)
  622. for provider_model in provider_models:
  623. if provider_model.model == model:
  624. return provider_model
  625. return None
  626. def get_provider_models(
  627. self, model_type: Optional[ModelType] = None, only_active: bool = False
  628. ) -> list[ModelWithProviderEntity]:
  629. """
  630. Get provider models.
  631. :param model_type: model type
  632. :param only_active: only active models
  633. :return:
  634. """
  635. model_provider_factory = ModelProviderFactory(self.tenant_id)
  636. provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
  637. model_types = []
  638. if model_type:
  639. model_types.append(model_type)
  640. else:
  641. model_types = provider_schema.supported_model_types
  642. # Group model settings by model type and model
  643. model_setting_map = defaultdict(dict)
  644. for model_setting in self.model_settings:
  645. model_setting_map[model_setting.model_type][model_setting.model] = model_setting
  646. if self.using_provider_type == ProviderType.SYSTEM:
  647. provider_models = self._get_system_provider_models(
  648. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  649. )
  650. else:
  651. provider_models = self._get_custom_provider_models(
  652. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  653. )
  654. if only_active:
  655. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  656. # resort provider_models
  657. return sorted(provider_models, key=lambda x: x.model_type.value)
  658. def _get_system_provider_models(
  659. self,
  660. model_types: Sequence[ModelType],
  661. provider_schema: ProviderEntity,
  662. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  663. ) -> list[ModelWithProviderEntity]:
  664. """
  665. Get system provider models.
  666. :param model_types: model types
  667. :param provider_schema: provider schema
  668. :param model_setting_map: model setting map
  669. :return:
  670. """
  671. provider_models = []
  672. for model_type in model_types:
  673. for m in provider_schema.models:
  674. if m.model_type != model_type:
  675. continue
  676. status = ModelStatus.ACTIVE
  677. if m.model in model_setting_map:
  678. model_setting = model_setting_map[m.model_type][m.model]
  679. if model_setting.enabled is False:
  680. status = ModelStatus.DISABLED
  681. provider_models.append(
  682. ModelWithProviderEntity(
  683. model=m.model,
  684. label=m.label,
  685. model_type=m.model_type,
  686. features=m.features,
  687. fetch_from=m.fetch_from,
  688. model_properties=m.model_properties,
  689. deprecated=m.deprecated,
  690. provider=SimpleModelProviderEntity(self.provider),
  691. status=status,
  692. )
  693. )
  694. if self.provider.provider not in original_provider_configurate_methods:
  695. original_provider_configurate_methods[self.provider.provider] = []
  696. for configurate_method in provider_schema.configurate_methods:
  697. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  698. should_use_custom_model = False
  699. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  700. should_use_custom_model = True
  701. for quota_configuration in self.system_configuration.quota_configurations:
  702. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  703. continue
  704. restrict_models = quota_configuration.restrict_models
  705. if len(restrict_models) == 0:
  706. break
  707. if should_use_custom_model:
  708. if original_provider_configurate_methods[self.provider.provider] == [
  709. ConfigurateMethod.CUSTOMIZABLE_MODEL
  710. ]:
  711. # only customizable model
  712. for restrict_model in restrict_models:
  713. copy_credentials = (
  714. self.system_configuration.credentials.copy()
  715. if self.system_configuration.credentials
  716. else {}
  717. )
  718. if restrict_model.base_model_name:
  719. copy_credentials["base_model_name"] = restrict_model.base_model_name
  720. try:
  721. custom_model_schema = self.get_model_schema(
  722. model_type=restrict_model.model_type,
  723. model=restrict_model.model,
  724. credentials=copy_credentials,
  725. )
  726. except Exception as ex:
  727. logger.warning(f"get custom model schema failed, {ex}")
  728. continue
  729. if not custom_model_schema:
  730. continue
  731. if custom_model_schema.model_type not in model_types:
  732. continue
  733. status = ModelStatus.ACTIVE
  734. if (
  735. custom_model_schema.model_type in model_setting_map
  736. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  737. ):
  738. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  739. if model_setting.enabled is False:
  740. status = ModelStatus.DISABLED
  741. provider_models.append(
  742. ModelWithProviderEntity(
  743. model=custom_model_schema.model,
  744. label=custom_model_schema.label,
  745. model_type=custom_model_schema.model_type,
  746. features=custom_model_schema.features,
  747. fetch_from=FetchFrom.PREDEFINED_MODEL,
  748. model_properties=custom_model_schema.model_properties,
  749. deprecated=custom_model_schema.deprecated,
  750. provider=SimpleModelProviderEntity(self.provider),
  751. status=status,
  752. )
  753. )
  754. # if llm name not in restricted llm list, remove it
  755. restrict_model_names = [rm.model for rm in restrict_models]
  756. for m in provider_models:
  757. if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
  758. m.status = ModelStatus.NO_PERMISSION
  759. elif not quota_configuration.is_valid:
  760. m.status = ModelStatus.QUOTA_EXCEEDED
  761. return provider_models
  762. def _get_custom_provider_models(
  763. self,
  764. model_types: Sequence[ModelType],
  765. provider_schema: ProviderEntity,
  766. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  767. ) -> list[ModelWithProviderEntity]:
  768. """
  769. Get custom provider models.
  770. :param model_types: model types
  771. :param provider_schema: provider schema
  772. :param model_setting_map: model setting map
  773. :return:
  774. """
  775. provider_models = []
  776. credentials = None
  777. if self.custom_configuration.provider:
  778. credentials = self.custom_configuration.provider.credentials
  779. for model_type in model_types:
  780. if model_type not in self.provider.supported_model_types:
  781. continue
  782. for m in provider_schema.models:
  783. if m.model_type != model_type:
  784. continue
  785. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  786. load_balancing_enabled = False
  787. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  788. model_setting = model_setting_map[m.model_type][m.model]
  789. if model_setting.enabled is False:
  790. status = ModelStatus.DISABLED
  791. if len(model_setting.load_balancing_configs) > 1:
  792. load_balancing_enabled = True
  793. provider_models.append(
  794. ModelWithProviderEntity(
  795. model=m.model,
  796. label=m.label,
  797. model_type=m.model_type,
  798. features=m.features,
  799. fetch_from=m.fetch_from,
  800. model_properties=m.model_properties,
  801. deprecated=m.deprecated,
  802. provider=SimpleModelProviderEntity(self.provider),
  803. status=status,
  804. load_balancing_enabled=load_balancing_enabled,
  805. )
  806. )
  807. # custom models
  808. for model_configuration in self.custom_configuration.models:
  809. if model_configuration.model_type not in model_types:
  810. continue
  811. try:
  812. custom_model_schema = self.get_model_schema(
  813. model_type=model_configuration.model_type,
  814. model=model_configuration.model,
  815. credentials=model_configuration.credentials,
  816. )
  817. except Exception as ex:
  818. logger.warning(f"get custom model schema failed, {ex}")
  819. continue
  820. if not custom_model_schema:
  821. continue
  822. status = ModelStatus.ACTIVE
  823. load_balancing_enabled = False
  824. if (
  825. custom_model_schema.model_type in model_setting_map
  826. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  827. ):
  828. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  829. if model_setting.enabled is False:
  830. status = ModelStatus.DISABLED
  831. if len(model_setting.load_balancing_configs) > 1:
  832. load_balancing_enabled = True
  833. provider_models.append(
  834. ModelWithProviderEntity(
  835. model=custom_model_schema.model,
  836. label=custom_model_schema.label,
  837. model_type=custom_model_schema.model_type,
  838. features=custom_model_schema.features,
  839. fetch_from=custom_model_schema.fetch_from,
  840. model_properties=custom_model_schema.model_properties,
  841. deprecated=custom_model_schema.deprecated,
  842. provider=SimpleModelProviderEntity(self.provider),
  843. status=status,
  844. load_balancing_enabled=load_balancing_enabled,
  845. )
  846. )
  847. return provider_models
  848. class ProviderConfigurations(BaseModel):
  849. """
  850. Model class for provider configuration dict.
  851. """
  852. tenant_id: str
  853. configurations: dict[str, ProviderConfiguration] = {}
  854. def __init__(self, tenant_id: str):
  855. super().__init__(tenant_id=tenant_id)
  856. def get_models(
  857. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  858. ) -> list[ModelWithProviderEntity]:
  859. """
  860. Get available models.
  861. If preferred provider type is `system`:
  862. Get the current **system mode** if provider supported,
  863. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  864. If there is no model configured in custom mode, it is treated as no_configure.
  865. system > custom > no_configure
  866. If preferred provider type is `custom`:
  867. If custom credentials are configured, it is treated as custom mode.
  868. Otherwise, get the current **system mode** if supported,
  869. If all system modes are not available (no quota), it is treated as no_configure.
  870. custom > system > no_configure
  871. If real mode is `system`, use system credentials to get models,
  872. paid quotas > provider free quotas > system free quotas
  873. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  874. If real mode is `custom`, use workspace custom credentials to get models,
  875. include pre-defined models, custom models(manual append).
  876. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  877. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  878. model status marked as `active` is available.
  879. :param provider: provider name
  880. :param model_type: model type
  881. :param only_active: only active models
  882. :return:
  883. """
  884. all_models = []
  885. for provider_configuration in self.values():
  886. if provider and provider_configuration.provider.provider != provider:
  887. continue
  888. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  889. return all_models
  890. def to_list(self) -> list[ProviderConfiguration]:
  891. """
  892. Convert to list.
  893. :return:
  894. """
  895. return list(self.values())
  896. def __getitem__(self, key):
  897. return self.configurations[key]
  898. def __setitem__(self, key, value):
  899. self.configurations[key] = value
  900. def __iter__(self):
  901. return iter(self.configurations)
  902. def values(self) -> Iterator[ProviderConfiguration]:
  903. return iter(self.configurations.values())
  904. def get(self, key, default=None):
  905. return self.configurations.get(key, default)
  906. class ProviderModelBundle(BaseModel):
  907. """
  908. Provider model bundle.
  909. """
  910. configuration: ProviderConfiguration
  911. model_type_instance: AIModel
  912. # pydantic configs
  913. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())