provider_configuration.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104
  1. import datetime
  2. import json
  3. import logging
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict, Field
  9. from sqlalchemy import or_
  10. from constants import HIDDEN_VALUE
  11. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  12. from core.entities.provider_entities import (
  13. CustomConfiguration,
  14. ModelSettings,
  15. SystemConfiguration,
  16. SystemConfigurationStatus,
  17. )
  18. from core.helper import encrypter
  19. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  20. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  21. from core.model_runtime.entities.provider_entities import (
  22. ConfigurateMethod,
  23. CredentialFormSchema,
  24. FormType,
  25. ProviderEntity,
  26. )
  27. from core.model_runtime.model_providers.__base.ai_model import AIModel
  28. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  29. from core.plugin.entities.plugin import ModelProviderID
  30. from extensions.ext_database import db
  31. from models.provider import (
  32. LoadBalancingModelConfig,
  33. Provider,
  34. ProviderModel,
  35. ProviderModelSetting,
  36. ProviderType,
  37. TenantPreferredModelProvider,
  38. )
  39. logger = logging.getLogger(__name__)
  40. original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.

    Holds, for one tenant and one model provider, the provider schema, the system
    (quota-based) configuration, the tenant's custom configuration and the per-model
    settings, together with helpers that read, validate and persist credentials and
    model settings.
    """

    # tenant that owns this configuration
    tenant_id: str
    # provider schema; __init__ may append PREDEFINED_MODEL to its configurate_methods
    provider: ProviderEntity
    # provider type the tenant prefers (persisted by switch_preferred_provider_type)
    preferred_provider_type: ProviderType
    # provider type in use; selects system vs. custom credentials and model lists
    using_provider_type: ProviderType
    # system (quota) configuration: enabled flag, quota configurations, credentials
    system_configuration: SystemConfiguration
    # tenant-supplied provider-level and model-level credentials
    custom_configuration: CustomConfiguration
    # per-model settings (enabled flag etc.) consulted when resolving credentials/models
    model_settings: list[ModelSettings]

    # pydantic configs
    # protected_namespaces=() lets fields such as `model_settings` use the "model_" prefix
    model_config = ConfigDict(protected_namespaces=())
  54. def __init__(self, **data):
  55. super().__init__(**data)
  56. if self.provider.provider not in original_provider_configurate_methods:
  57. original_provider_configurate_methods[self.provider.provider] = []
  58. for configurate_method in self.provider.configurate_methods:
  59. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  60. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  61. if (
  62. any(
  63. len(quota_configuration.restrict_models) > 0
  64. for quota_configuration in self.system_configuration.quota_configurations
  65. )
  66. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  67. ):
  68. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  69. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  70. """
  71. Get current credentials.
  72. :param model_type: model type
  73. :param model: model name
  74. :return:
  75. """
  76. if self.model_settings:
  77. # check if model is disabled by admin
  78. for model_setting in self.model_settings:
  79. if model_setting.model_type == model_type and model_setting.model == model:
  80. if not model_setting.enabled:
  81. raise ValueError(f"Model {model} is disabled.")
  82. if self.using_provider_type == ProviderType.SYSTEM:
  83. restrict_models = []
  84. for quota_configuration in self.system_configuration.quota_configurations:
  85. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  86. continue
  87. restrict_models = quota_configuration.restrict_models
  88. copy_credentials = (
  89. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  90. )
  91. if restrict_models:
  92. for restrict_model in restrict_models:
  93. if (
  94. restrict_model.model_type == model_type
  95. and restrict_model.model == model
  96. and restrict_model.base_model_name
  97. ):
  98. copy_credentials["base_model_name"] = restrict_model.base_model_name
  99. return copy_credentials
  100. else:
  101. credentials = None
  102. if self.custom_configuration.models:
  103. for model_configuration in self.custom_configuration.models:
  104. if model_configuration.model_type == model_type and model_configuration.model == model:
  105. credentials = model_configuration.credentials
  106. break
  107. if not credentials and self.custom_configuration.provider:
  108. credentials = self.custom_configuration.provider.credentials
  109. return credentials
  110. def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
  111. """
  112. Get system configuration status.
  113. :return:
  114. """
  115. if self.system_configuration.enabled is False:
  116. return SystemConfigurationStatus.UNSUPPORTED
  117. current_quota_type = self.system_configuration.current_quota_type
  118. current_quota_configuration = next(
  119. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  120. )
  121. if current_quota_configuration is None:
  122. return None
  123. if not current_quota_configuration:
  124. return SystemConfigurationStatus.UNSUPPORTED
  125. return (
  126. SystemConfigurationStatus.ACTIVE
  127. if current_quota_configuration.is_valid
  128. else SystemConfigurationStatus.QUOTA_EXCEEDED
  129. )
  130. def is_custom_configuration_available(self) -> bool:
  131. """
  132. Check custom configuration available.
  133. :return:
  134. """
  135. return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
  136. def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
  137. """
  138. Get custom credentials.
  139. :param obfuscated: obfuscated secret data in credentials
  140. :return:
  141. """
  142. if self.custom_configuration.provider is None:
  143. return None
  144. credentials = self.custom_configuration.provider.credentials
  145. if not obfuscated:
  146. return credentials
  147. # Obfuscate credentials
  148. return self.obfuscated_credentials(
  149. credentials=credentials,
  150. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  151. if self.provider.provider_credential_schema
  152. else [],
  153. )
  154. def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
  155. """
  156. Validate custom credentials.
  157. :param credentials: provider credentials
  158. :return:
  159. """
  160. # get provider
  161. model_provider_id = ModelProviderID(self.provider.provider)
  162. if model_provider_id.is_langgenius():
  163. provider_record = (
  164. db.session.query(Provider)
  165. .filter(
  166. Provider.tenant_id == self.tenant_id,
  167. Provider.provider_type == ProviderType.CUSTOM.value,
  168. or_(
  169. Provider.provider_name == model_provider_id.provider_name,
  170. Provider.provider_name == self.provider.provider,
  171. ),
  172. )
  173. .first()
  174. )
  175. else:
  176. provider_record = (
  177. db.session.query(Provider)
  178. .filter(
  179. Provider.tenant_id == self.tenant_id,
  180. Provider.provider_type == ProviderType.CUSTOM.value,
  181. Provider.provider_name == self.provider.provider,
  182. )
  183. .first()
  184. )
  185. # Get provider credential secret variables
  186. provider_credential_secret_variables = self.extract_secret_variables(
  187. self.provider.provider_credential_schema.credential_form_schemas
  188. if self.provider.provider_credential_schema
  189. else []
  190. )
  191. if provider_record:
  192. try:
  193. # fix origin data
  194. if provider_record.encrypted_config:
  195. if not provider_record.encrypted_config.startswith("{"):
  196. original_credentials = {"openai_api_key": provider_record.encrypted_config}
  197. else:
  198. original_credentials = json.loads(provider_record.encrypted_config)
  199. else:
  200. original_credentials = {}
  201. except JSONDecodeError:
  202. original_credentials = {}
  203. # encrypt credentials
  204. for key, value in credentials.items():
  205. if key in provider_credential_secret_variables:
  206. # if send [__HIDDEN__] in secret input, it will be same as original value
  207. if value == HIDDEN_VALUE and key in original_credentials:
  208. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  209. model_provider_factory = ModelProviderFactory(self.tenant_id)
  210. credentials = model_provider_factory.provider_credentials_validate(
  211. provider=self.provider.provider, credentials=credentials
  212. )
  213. for key, value in credentials.items():
  214. if key in provider_credential_secret_variables:
  215. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  216. return provider_record, credentials
  217. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  218. """
  219. Add or update custom provider credentials.
  220. :param credentials:
  221. :return:
  222. """
  223. # validate custom provider config
  224. provider_record, credentials = self.custom_credentials_validate(credentials)
  225. # save provider
  226. # Note: Do not switch the preferred provider, which allows users to use quotas first
  227. if provider_record:
  228. provider_record.encrypted_config = json.dumps(credentials)
  229. provider_record.is_valid = True
  230. provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  231. db.session.commit()
  232. else:
  233. provider_record = Provider()
  234. provider_record.tenant_id = self.tenant_id
  235. provider_record.provider_name = self.provider.provider
  236. provider_record.provider_type = ProviderType.CUSTOM.value
  237. provider_record.encrypted_config = json.dumps(credentials)
  238. provider_record.is_valid = True
  239. db.session.add(provider_record)
  240. db.session.commit()
  241. provider_model_credentials_cache = ProviderCredentialsCache(
  242. tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
  243. )
  244. provider_model_credentials_cache.delete()
  245. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  246. def delete_custom_credentials(self) -> None:
  247. """
  248. Delete custom provider credentials.
  249. :return:
  250. """
  251. # get provider
  252. provider_record = (
  253. db.session.query(Provider)
  254. .filter(
  255. Provider.tenant_id == self.tenant_id,
  256. or_(
  257. Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
  258. Provider.provider_name == self.provider.provider,
  259. ),
  260. Provider.provider_type == ProviderType.CUSTOM.value,
  261. )
  262. .first()
  263. )
  264. # delete provider
  265. if provider_record:
  266. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  267. db.session.delete(provider_record)
  268. db.session.commit()
  269. provider_model_credentials_cache = ProviderCredentialsCache(
  270. tenant_id=self.tenant_id,
  271. identity_id=provider_record.id,
  272. cache_type=ProviderCredentialsCacheType.PROVIDER,
  273. )
  274. provider_model_credentials_cache.delete()
  275. def get_custom_model_credentials(
  276. self, model_type: ModelType, model: str, obfuscated: bool = False
  277. ) -> Optional[dict]:
  278. """
  279. Get custom model credentials.
  280. :param model_type: model type
  281. :param model: model name
  282. :param obfuscated: obfuscated secret data in credentials
  283. :return:
  284. """
  285. if not self.custom_configuration.models:
  286. return None
  287. for model_configuration in self.custom_configuration.models:
  288. if model_configuration.model_type == model_type and model_configuration.model == model:
  289. credentials = model_configuration.credentials
  290. if not obfuscated:
  291. return credentials
  292. # Obfuscate credentials
  293. return self.obfuscated_credentials(
  294. credentials=credentials,
  295. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  296. if self.provider.model_credential_schema
  297. else [],
  298. )
  299. return None
  300. def custom_model_credentials_validate(
  301. self, model_type: ModelType, model: str, credentials: dict
  302. ) -> tuple[ProviderModel | None, dict]:
  303. """
  304. Validate custom model credentials.
  305. :param model_type: model type
  306. :param model: model name
  307. :param credentials: model credentials
  308. :return:
  309. """
  310. # get provider model
  311. provider_model_record = (
  312. db.session.query(ProviderModel)
  313. .filter(
  314. ProviderModel.tenant_id == self.tenant_id,
  315. ProviderModel.provider_name == self.provider.provider,
  316. ProviderModel.model_name == model,
  317. ProviderModel.model_type == model_type.to_origin_model_type(),
  318. )
  319. .first()
  320. )
  321. # Get provider credential secret variables
  322. provider_credential_secret_variables = self.extract_secret_variables(
  323. self.provider.model_credential_schema.credential_form_schemas
  324. if self.provider.model_credential_schema
  325. else []
  326. )
  327. if provider_model_record:
  328. try:
  329. original_credentials = (
  330. json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  331. )
  332. except JSONDecodeError:
  333. original_credentials = {}
  334. # decrypt credentials
  335. for key, value in credentials.items():
  336. if key in provider_credential_secret_variables:
  337. # if send [__HIDDEN__] in secret input, it will be same as original value
  338. if value == HIDDEN_VALUE and key in original_credentials:
  339. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  340. model_provider_factory = ModelProviderFactory(self.tenant_id)
  341. credentials = model_provider_factory.model_credentials_validate(
  342. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  343. )
  344. for key, value in credentials.items():
  345. if key in provider_credential_secret_variables:
  346. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  347. return provider_model_record, credentials
  348. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  349. """
  350. Add or update custom model credentials.
  351. :param model_type: model type
  352. :param model: model name
  353. :param credentials: model credentials
  354. :return:
  355. """
  356. # validate custom model config
  357. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  358. # save provider model
  359. # Note: Do not switch the preferred provider, which allows users to use quotas first
  360. if provider_model_record:
  361. provider_model_record.encrypted_config = json.dumps(credentials)
  362. provider_model_record.is_valid = True
  363. provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  364. db.session.commit()
  365. else:
  366. provider_model_record = ProviderModel()
  367. provider_model_record.tenant_id = self.tenant_id
  368. provider_model_record.provider_name = self.provider.provider
  369. provider_model_record.model_name = model
  370. provider_model_record.model_type = model_type.to_origin_model_type()
  371. provider_model_record.encrypted_config = json.dumps(credentials)
  372. provider_model_record.is_valid = True
  373. db.session.add(provider_model_record)
  374. db.session.commit()
  375. provider_model_credentials_cache = ProviderCredentialsCache(
  376. tenant_id=self.tenant_id,
  377. identity_id=provider_model_record.id,
  378. cache_type=ProviderCredentialsCacheType.MODEL,
  379. )
  380. provider_model_credentials_cache.delete()
  381. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  382. """
  383. Delete custom model credentials.
  384. :param model_type: model type
  385. :param model: model name
  386. :return:
  387. """
  388. # get provider model
  389. provider_model_record = (
  390. db.session.query(ProviderModel)
  391. .filter(
  392. ProviderModel.tenant_id == self.tenant_id,
  393. ProviderModel.provider_name == self.provider.provider,
  394. ProviderModel.model_name == model,
  395. ProviderModel.model_type == model_type.to_origin_model_type(),
  396. )
  397. .first()
  398. )
  399. # delete provider model
  400. if provider_model_record:
  401. db.session.delete(provider_model_record)
  402. db.session.commit()
  403. provider_model_credentials_cache = ProviderCredentialsCache(
  404. tenant_id=self.tenant_id,
  405. identity_id=provider_model_record.id,
  406. cache_type=ProviderCredentialsCacheType.MODEL,
  407. )
  408. provider_model_credentials_cache.delete()
  409. def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  410. """
  411. Enable model.
  412. :param model_type: model type
  413. :param model: model name
  414. :return:
  415. """
  416. model_setting = (
  417. db.session.query(ProviderModelSetting)
  418. .filter(
  419. ProviderModelSetting.tenant_id == self.tenant_id,
  420. ProviderModelSetting.provider_name == self.provider.provider,
  421. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  422. ProviderModelSetting.model_name == model,
  423. )
  424. .first()
  425. )
  426. if model_setting:
  427. model_setting.enabled = True
  428. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  429. db.session.commit()
  430. else:
  431. model_setting = ProviderModelSetting()
  432. model_setting.tenant_id = self.tenant_id
  433. model_setting.provider_name = self.provider.provider
  434. model_setting.model_type = model_type.to_origin_model_type()
  435. model_setting.model_name = model
  436. model_setting.enabled = True
  437. db.session.add(model_setting)
  438. db.session.commit()
  439. return model_setting
  440. def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  441. """
  442. Disable model.
  443. :param model_type: model type
  444. :param model: model name
  445. :return:
  446. """
  447. model_setting = (
  448. db.session.query(ProviderModelSetting)
  449. .filter(
  450. ProviderModelSetting.tenant_id == self.tenant_id,
  451. ProviderModelSetting.provider_name == self.provider.provider,
  452. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  453. ProviderModelSetting.model_name == model,
  454. )
  455. .first()
  456. )
  457. if model_setting:
  458. model_setting.enabled = False
  459. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  460. db.session.commit()
  461. else:
  462. model_setting = ProviderModelSetting()
  463. model_setting.tenant_id = self.tenant_id
  464. model_setting.provider_name = self.provider.provider
  465. model_setting.model_type = model_type.to_origin_model_type()
  466. model_setting.model_name = model
  467. model_setting.enabled = False
  468. db.session.add(model_setting)
  469. db.session.commit()
  470. return model_setting
  471. def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
  472. """
  473. Get provider model setting.
  474. :param model_type: model type
  475. :param model: model name
  476. :return:
  477. """
  478. return (
  479. db.session.query(ProviderModelSetting)
  480. .filter(
  481. ProviderModelSetting.tenant_id == self.tenant_id,
  482. ProviderModelSetting.provider_name == self.provider.provider,
  483. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  484. ProviderModelSetting.model_name == model,
  485. )
  486. .first()
  487. )
  488. def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  489. """
  490. Enable model load balancing.
  491. :param model_type: model type
  492. :param model: model name
  493. :return:
  494. """
  495. load_balancing_config_count = (
  496. db.session.query(LoadBalancingModelConfig)
  497. .filter(
  498. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  499. LoadBalancingModelConfig.provider_name == self.provider.provider,
  500. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  501. LoadBalancingModelConfig.model_name == model,
  502. )
  503. .count()
  504. )
  505. if load_balancing_config_count <= 1:
  506. raise ValueError("Model load balancing configuration must be more than 1.")
  507. model_setting = (
  508. db.session.query(ProviderModelSetting)
  509. .filter(
  510. ProviderModelSetting.tenant_id == self.tenant_id,
  511. ProviderModelSetting.provider_name == self.provider.provider,
  512. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  513. ProviderModelSetting.model_name == model,
  514. )
  515. .first()
  516. )
  517. if model_setting:
  518. model_setting.load_balancing_enabled = True
  519. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  520. db.session.commit()
  521. else:
  522. model_setting = ProviderModelSetting()
  523. model_setting.tenant_id = self.tenant_id
  524. model_setting.provider_name = self.provider.provider
  525. model_setting.model_type = model_type.to_origin_model_type()
  526. model_setting.model_name = model
  527. model_setting.load_balancing_enabled = True
  528. db.session.add(model_setting)
  529. db.session.commit()
  530. return model_setting
  531. def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  532. """
  533. Disable model load balancing.
  534. :param model_type: model type
  535. :param model: model name
  536. :return:
  537. """
  538. model_setting = (
  539. db.session.query(ProviderModelSetting)
  540. .filter(
  541. ProviderModelSetting.tenant_id == self.tenant_id,
  542. ProviderModelSetting.provider_name == self.provider.provider,
  543. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  544. ProviderModelSetting.model_name == model,
  545. )
  546. .first()
  547. )
  548. if model_setting:
  549. model_setting.load_balancing_enabled = False
  550. model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  551. db.session.commit()
  552. else:
  553. model_setting = ProviderModelSetting()
  554. model_setting.tenant_id = self.tenant_id
  555. model_setting.provider_name = self.provider.provider
  556. model_setting.model_type = model_type.to_origin_model_type()
  557. model_setting.model_name = model
  558. model_setting.load_balancing_enabled = False
  559. db.session.add(model_setting)
  560. db.session.commit()
  561. return model_setting
  562. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  563. """
  564. Get current model type instance.
  565. :param model_type: model type
  566. :return:
  567. """
  568. model_provider_factory = ModelProviderFactory(self.tenant_id)
  569. # Get model instance of LLM
  570. return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
  571. def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
  572. """
  573. Get model schema
  574. """
  575. model_provider_factory = ModelProviderFactory(self.tenant_id)
  576. return model_provider_factory.get_model_schema(
  577. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  578. )
  579. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  580. """
  581. Switch preferred provider type.
  582. :param provider_type:
  583. :return:
  584. """
  585. if provider_type == self.preferred_provider_type:
  586. return
  587. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  588. return
  589. # get preferred provider
  590. preferred_model_provider = (
  591. db.session.query(TenantPreferredModelProvider)
  592. .filter(
  593. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  594. TenantPreferredModelProvider.provider_name == self.provider.provider,
  595. )
  596. .first()
  597. )
  598. if preferred_model_provider:
  599. preferred_model_provider.preferred_provider_type = provider_type.value
  600. else:
  601. preferred_model_provider = TenantPreferredModelProvider()
  602. preferred_model_provider.tenant_id = self.tenant_id
  603. preferred_model_provider.provider_name = self.provider.provider
  604. preferred_model_provider.preferred_provider_type = provider_type.value
  605. db.session.add(preferred_model_provider)
  606. db.session.commit()
  607. def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  608. """
  609. Extract secret input form variables.
  610. :param credential_form_schemas:
  611. :return:
  612. """
  613. secret_input_form_variables = []
  614. for credential_form_schema in credential_form_schemas:
  615. if credential_form_schema.type == FormType.SECRET_INPUT:
  616. secret_input_form_variables.append(credential_form_schema.variable)
  617. return secret_input_form_variables
  618. def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  619. """
  620. Obfuscated credentials.
  621. :param credentials: credentials
  622. :param credential_form_schemas: credential form schemas
  623. :return:
  624. """
  625. # Get provider credential secret variables
  626. credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
  627. # Obfuscate provider credentials
  628. copy_credentials = credentials.copy()
  629. for key, value in copy_credentials.items():
  630. if key in credential_secret_variables:
  631. copy_credentials[key] = encrypter.obfuscated_token(value)
  632. return copy_credentials
  633. def get_provider_model(
  634. self, model_type: ModelType, model: str, only_active: bool = False
  635. ) -> Optional[ModelWithProviderEntity]:
  636. """
  637. Get provider model.
  638. :param model_type: model type
  639. :param model: model name
  640. :param only_active: return active model only
  641. :return:
  642. """
  643. provider_models = self.get_provider_models(model_type, only_active)
  644. for provider_model in provider_models:
  645. if provider_model.model == model:
  646. return provider_model
  647. return None
  648. def get_provider_models(
  649. self, model_type: Optional[ModelType] = None, only_active: bool = False
  650. ) -> list[ModelWithProviderEntity]:
  651. """
  652. Get provider models.
  653. :param model_type: model type
  654. :param only_active: only active models
  655. :return:
  656. """
  657. model_provider_factory = ModelProviderFactory(self.tenant_id)
  658. provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
  659. model_types: list[ModelType] = []
  660. if model_type:
  661. model_types.append(model_type)
  662. else:
  663. model_types = list(provider_schema.supported_model_types)
  664. # Group model settings by model type and model
  665. model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
  666. for model_setting in self.model_settings:
  667. model_setting_map[model_setting.model_type][model_setting.model] = model_setting
  668. if self.using_provider_type == ProviderType.SYSTEM:
  669. provider_models = self._get_system_provider_models(
  670. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  671. )
  672. else:
  673. provider_models = self._get_custom_provider_models(
  674. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  675. )
  676. if only_active:
  677. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  678. # resort provider_models
  679. return sorted(provider_models, key=lambda x: x.model_type.value)
  680. def _get_system_provider_models(
  681. self,
  682. model_types: Sequence[ModelType],
  683. provider_schema: ProviderEntity,
  684. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  685. ) -> list[ModelWithProviderEntity]:
  686. """
  687. Get system provider models.
  688. :param model_types: model types
  689. :param provider_schema: provider schema
  690. :param model_setting_map: model setting map
  691. :return:
  692. """
  693. provider_models = []
  694. for model_type in model_types:
  695. for m in provider_schema.models:
  696. if m.model_type != model_type:
  697. continue
  698. status = ModelStatus.ACTIVE
  699. if m.model in model_setting_map:
  700. model_setting = model_setting_map[m.model_type][m.model]
  701. if model_setting.enabled is False:
  702. status = ModelStatus.DISABLED
  703. provider_models.append(
  704. ModelWithProviderEntity(
  705. model=m.model,
  706. label=m.label,
  707. model_type=m.model_type,
  708. features=m.features,
  709. fetch_from=m.fetch_from,
  710. model_properties=m.model_properties,
  711. deprecated=m.deprecated,
  712. provider=SimpleModelProviderEntity(self.provider),
  713. status=status,
  714. )
  715. )
  716. if self.provider.provider not in original_provider_configurate_methods:
  717. original_provider_configurate_methods[self.provider.provider] = []
  718. for configurate_method in provider_schema.configurate_methods:
  719. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  720. should_use_custom_model = False
  721. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  722. should_use_custom_model = True
  723. for quota_configuration in self.system_configuration.quota_configurations:
  724. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  725. continue
  726. restrict_models = quota_configuration.restrict_models
  727. if len(restrict_models) == 0:
  728. break
  729. if should_use_custom_model:
  730. if original_provider_configurate_methods[self.provider.provider] == [
  731. ConfigurateMethod.CUSTOMIZABLE_MODEL
  732. ]:
  733. # only customizable model
  734. for restrict_model in restrict_models:
  735. copy_credentials = (
  736. self.system_configuration.credentials.copy()
  737. if self.system_configuration.credentials
  738. else {}
  739. )
  740. if restrict_model.base_model_name:
  741. copy_credentials["base_model_name"] = restrict_model.base_model_name
  742. try:
  743. custom_model_schema = self.get_model_schema(
  744. model_type=restrict_model.model_type,
  745. model=restrict_model.model,
  746. credentials=copy_credentials,
  747. )
  748. except Exception as ex:
  749. logger.warning(f"get custom model schema failed, {ex}")
  750. if not custom_model_schema:
  751. continue
  752. if custom_model_schema.model_type not in model_types:
  753. continue
  754. status = ModelStatus.ACTIVE
  755. if (
  756. custom_model_schema.model_type in model_setting_map
  757. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  758. ):
  759. model_setting = model_setting_map[custom_model_schema.model_type][
  760. custom_model_schema.model
  761. ]
  762. if model_setting.enabled is False:
  763. status = ModelStatus.DISABLED
  764. provider_models.append(
  765. ModelWithProviderEntity(
  766. model=custom_model_schema.model,
  767. label=custom_model_schema.label,
  768. model_type=custom_model_schema.model_type,
  769. features=custom_model_schema.features,
  770. fetch_from=FetchFrom.PREDEFINED_MODEL,
  771. model_properties=custom_model_schema.model_properties,
  772. deprecated=custom_model_schema.deprecated,
  773. provider=SimpleModelProviderEntity(self.provider),
  774. status=status,
  775. )
  776. )
  777. # if llm name not in restricted llm list, remove it
  778. restrict_model_names = [rm.model for rm in restrict_models]
  779. for model in provider_models:
  780. if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
  781. model.status = ModelStatus.NO_PERMISSION
  782. elif not quota_configuration.is_valid:
  783. model.status = ModelStatus.QUOTA_EXCEEDED
  784. return provider_models
  785. def _get_custom_provider_models(
  786. self,
  787. model_types: Sequence[ModelType],
  788. provider_schema: ProviderEntity,
  789. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  790. ) -> list[ModelWithProviderEntity]:
  791. """
  792. Get custom provider models.
  793. :param model_types: model types
  794. :param provider_schema: provider schema
  795. :param model_setting_map: model setting map
  796. :return:
  797. """
  798. provider_models = []
  799. credentials = None
  800. if self.custom_configuration.provider:
  801. credentials = self.custom_configuration.provider.credentials
  802. for model_type in model_types:
  803. if model_type not in self.provider.supported_model_types:
  804. continue
  805. for m in provider_schema.models:
  806. if m.model_type != model_type:
  807. continue
  808. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  809. load_balancing_enabled = False
  810. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  811. model_setting = model_setting_map[m.model_type][m.model]
  812. if model_setting.enabled is False:
  813. status = ModelStatus.DISABLED
  814. if len(model_setting.load_balancing_configs) > 1:
  815. load_balancing_enabled = True
  816. provider_models.append(
  817. ModelWithProviderEntity(
  818. model=m.model,
  819. label=m.label,
  820. model_type=m.model_type,
  821. features=m.features,
  822. fetch_from=m.fetch_from,
  823. model_properties=m.model_properties,
  824. deprecated=m.deprecated,
  825. provider=SimpleModelProviderEntity(self.provider),
  826. status=status,
  827. load_balancing_enabled=load_balancing_enabled,
  828. )
  829. )
  830. # custom models
  831. for model_configuration in self.custom_configuration.models:
  832. if model_configuration.model_type not in model_types:
  833. continue
  834. try:
  835. custom_model_schema = self.get_model_schema(
  836. model_type=model_configuration.model_type,
  837. model=model_configuration.model,
  838. credentials=model_configuration.credentials,
  839. )
  840. except Exception as ex:
  841. logger.warning(f"get custom model schema failed, {ex}")
  842. continue
  843. if not custom_model_schema:
  844. continue
  845. status = ModelStatus.ACTIVE
  846. load_balancing_enabled = False
  847. if (
  848. custom_model_schema.model_type in model_setting_map
  849. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  850. ):
  851. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  852. if model_setting.enabled is False:
  853. status = ModelStatus.DISABLED
  854. if len(model_setting.load_balancing_configs) > 1:
  855. load_balancing_enabled = True
  856. provider_models.append(
  857. ModelWithProviderEntity(
  858. model=custom_model_schema.model,
  859. label=custom_model_schema.label,
  860. model_type=custom_model_schema.model_type,
  861. features=custom_model_schema.features,
  862. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  863. model_properties=custom_model_schema.model_properties,
  864. deprecated=custom_model_schema.deprecated,
  865. provider=SimpleModelProviderEntity(self.provider),
  866. status=status,
  867. load_balancing_enabled=load_balancing_enabled,
  868. )
  869. )
  870. return provider_models
  871. class ProviderConfigurations(BaseModel):
  872. """
  873. Model class for provider configuration dict.
  874. """
  875. tenant_id: str
  876. configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
  877. def __init__(self, tenant_id: str):
  878. super().__init__(tenant_id=tenant_id)
  879. def get_models(
  880. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  881. ) -> list[ModelWithProviderEntity]:
  882. """
  883. Get available models.
  884. If preferred provider type is `system`:
  885. Get the current **system mode** if provider supported,
  886. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  887. If there is no model configured in custom mode, it is treated as no_configure.
  888. system > custom > no_configure
  889. If preferred provider type is `custom`:
  890. If custom credentials are configured, it is treated as custom mode.
  891. Otherwise, get the current **system mode** if supported,
  892. If all system modes are not available (no quota), it is treated as no_configure.
  893. custom > system > no_configure
  894. If real mode is `system`, use system credentials to get models,
  895. paid quotas > provider free quotas > system free quotas
  896. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  897. If real mode is `custom`, use workspace custom credentials to get models,
  898. include pre-defined models, custom models(manual append).
  899. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  900. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  901. model status marked as `active` is available.
  902. :param provider: provider name
  903. :param model_type: model type
  904. :param only_active: only active models
  905. :return:
  906. """
  907. all_models = []
  908. for provider_configuration in self.values():
  909. if provider and provider_configuration.provider.provider != provider:
  910. continue
  911. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  912. return all_models
  913. def to_list(self) -> list[ProviderConfiguration]:
  914. """
  915. Convert to list.
  916. :return:
  917. """
  918. return list(self.values())
  919. def __getitem__(self, key):
  920. if "/" not in key:
  921. key = str(ModelProviderID(key))
  922. return self.configurations[key]
  923. def __setitem__(self, key, value):
  924. self.configurations[key] = value
  925. def __iter__(self):
  926. return iter(self.configurations)
  927. def values(self) -> Iterator[ProviderConfiguration]:
  928. return iter(self.configurations.values())
  929. def get(self, key, default=None) -> ProviderConfiguration | None:
  930. if "/" not in key:
  931. key = str(ModelProviderID(key))
  932. return self.configurations.get(key, default) # type: ignore
  933. class ProviderModelBundle(BaseModel):
  934. """
  935. Provider model bundle.
  936. """
  937. configuration: ProviderConfiguration
  938. model_type_instance: AIModel
  939. # pydantic configs
  940. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())