provider_service.py 24 KB


  1. import datetime
  2. import json
  3. import logging
  4. import os
  5. from collections import defaultdict
  6. from typing import Optional
  7. import requests
  8. from core.model_providers.model_factory import ModelFactory
  9. from extensions.ext_database import db
  10. from core.model_providers.model_provider_factory import ModelProviderFactory
  11. from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
  12. from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
  13. TenantDefaultModel
  14. class ProviderService:
  def get_provider_list(self, tenant_id: str):
      """
      get provider list of tenant.

      For every provider returned by ModelProviderFactory.get_provider_rules()
      this builds a dict holding the tenant's preferred provider type, the
      provider's model flexibility and a 'providers' list with one entry per
      supported system quota type and/or one entry for the custom provider.
      Entries start from defaults and are then overwritten with values from
      the tenant's valid Provider / ProviderModel rows.

      :param tenant_id:
      :return: dict keyed by provider name
      """
      # get rules for all providers
      model_provider_rules = ModelProviderFactory.get_provider_rules()
      model_provider_names = list(model_provider_rules)

      for model_provider_name, model_provider_rule in model_provider_rules.items():
          # 'system_config' may be absent or empty, treat both as "no quota types"
          system_config = model_provider_rule.get('system_config') or {}
          if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
                  and 'trial' in (system_config.get('supported_quota_types') or []):
              # NOTE(review): the return value is discarded; presumably this call is made
              # for its side effects on the tenant's trial provider - confirm.
              ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

      configurable_model_provider_names = [
          name
          for name, rule in model_provider_rules.items()
          if 'custom' in rule['support_provider_types']
          and rule['model_flexibility'] == 'configurable'
      ]

      # get all providers for the tenant
      providers = db.session.query(Provider) \
          .filter(
              Provider.tenant_id == tenant_id,
              Provider.provider_name.in_(model_provider_names),
              Provider.is_valid == True
          ).order_by(Provider.created_at.desc()).all()

      provider_name_to_provider_dict = defaultdict(list)
      for provider in providers:
          provider_name_to_provider_dict[provider.provider_name].append(provider)

      # get all configurable provider models for the tenant
      provider_models = db.session.query(ProviderModel) \
          .filter(
              ProviderModel.tenant_id == tenant_id,
              ProviderModel.provider_name.in_(configurable_model_provider_names),
              ProviderModel.is_valid == True
          ).order_by(ProviderModel.created_at.desc()).all()

      provider_name_to_provider_model_dict = defaultdict(list)
      for provider_model in provider_models:
          provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model)

      # get all preferred provider type for the tenant
      preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
          .filter(
              TenantPreferredModelProvider.tenant_id == tenant_id,
              TenantPreferredModelProvider.provider_name.in_(model_provider_names)
          ).all()

      provider_name_to_preferred_provider_type_dict = {
          preferred_provider_type.provider_name: preferred_provider_type
          for preferred_provider_type in preferred_provider_types
      }

      providers_list = {}
      for model_provider_name, model_provider_rule in model_provider_rules.items():
          # get preferred provider type
          preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
          preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
              tenant_id,
              model_provider_name,
              preferred_model_provider
          )

          provider_config_dict = {
              "preferred_provider_type": preferred_provider_type,
              "model_flexibility": model_provider_rule['model_flexibility'],
          }

          # default entries, keyed by 'system:<quota_type>' or 'custom'
          provider_parameter_dict = {}
          if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types']:
              # same tolerance for a missing/empty system_config as in the first pass above
              system_config = model_provider_rule.get('system_config') or {}
              supported_quota_types = system_config.get('supported_quota_types') or []
              for quota_type_enum in ProviderQuotaType:
                  quota_type = quota_type_enum.value
                  if quota_type not in supported_quota_types:
                      continue

                  key = ProviderType.SYSTEM.value + ':' + quota_type
                  provider_parameter_dict[key] = {
                      "provider_name": model_provider_name,
                      "provider_type": ProviderType.SYSTEM.value,
                      "config": None,
                      "is_valid": False,  # need update
                      "quota_type": quota_type,
                      "quota_unit": system_config['quota_unit'],  # need update
                      # only the trial quota starts from the configured limit
                      "quota_limit": 0 if quota_type != ProviderQuotaType.TRIAL.value else
                      system_config['quota_limit'],  # need update
                      "quota_used": 0,  # need update
                      "last_used": None  # need update
                  }

          if ProviderType.CUSTOM.value in model_provider_rule['support_provider_types']:
              provider_parameter_dict[ProviderType.CUSTOM.value] = {
                  "provider_name": model_provider_name,
                  "provider_type": ProviderType.CUSTOM.value,
                  "config": None,  # need update
                  "models": [],  # need update
                  "is_valid": False,
                  "last_used": None  # need update
              }

          model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)

          # overwrite the defaults with the tenant's stored provider rows
          current_providers = provider_name_to_provider_dict[model_provider_name]
          for provider in current_providers:
              last_used = int(provider.last_used.timestamp()) if provider.last_used else None

              if provider.provider_type == ProviderType.SYSTEM.value:
                  key = f'{ProviderType.SYSTEM.value}:{provider.quota_type}'
                  if key in provider_parameter_dict:
                      provider_parameter_dict[key]['is_valid'] = provider.is_valid
                      provider_parameter_dict[key]['quota_used'] = provider.quota_used
                      provider_parameter_dict[key]['quota_limit'] = provider.quota_limit
                      provider_parameter_dict[key]['last_used'] = last_used
              elif provider.provider_type == ProviderType.CUSTOM.value \
                      and ProviderType.CUSTOM.value in provider_parameter_dict:
                  # if custom
                  key = ProviderType.CUSTOM.value
                  provider_parameter_dict[key]['last_used'] = last_used
                  provider_parameter_dict[key]['is_valid'] = provider.is_valid

                  # one provider object per row is enough for every credential lookup below
                  model_provider = model_provider_class(provider=provider)
                  if model_provider_rule['model_flexibility'] == 'fixed':
                      provider_parameter_dict[key]['config'] = model_provider \
                          .get_provider_credentials(obfuscated=True)
                  else:
                      current_provider_models = provider_name_to_provider_model_dict[model_provider_name]
                      provider_parameter_dict[key]['models'] = [
                          {
                              "model_name": current_provider_model.model_name,
                              "model_type": current_provider_model.model_type,
                              "config": model_provider.get_model_credentials(
                                  current_provider_model.model_name,
                                  ModelType.value_of(current_provider_model.model_type),
                                  obfuscated=True
                              ),
                              "is_valid": current_provider_model.is_valid
                          }
                          for current_provider_model in current_provider_models
                      ]

          provider_config_dict['providers'] = list(provider_parameter_dict.values())
          providers_list[model_provider_name] = provider_config_dict

      return providers_list
  def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
      """
      validate custom provider config.

      Only providers whose model flexibility is 'fixed' and which support the
      CUSTOM provider type are accepted; the credential check itself is
      delegated to the provider class.

      :param provider_name:
      :param config:
      :return:
      :raises ValueError: When the provider is not fixed or does not support CUSTOM.
      :raises CredentialsValidateFailedError: When the config credential verification fails.
      """
      provider_rule = ModelProviderFactory.get_provider_rule(provider_name)

      is_fixed = provider_rule['model_flexibility'] == 'fixed'
      if not is_fixed:
          raise ValueError('Only support fixed model provider')

      # only support provider type CUSTOM
      supports_custom = ProviderType.CUSTOM.value in provider_rule['support_provider_types']
      if not supports_custom:
          raise ValueError('Only support provider type CUSTOM')

      # hand the actual credential check to the provider implementation
      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      provider_class.is_provider_credentials_valid_or_raise(config)
  def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
      """
      save custom provider config.

      Validates the config, encrypts it through the provider class and stores
      it as JSON on the tenant's CUSTOM provider row, creating the row when it
      does not exist yet.

      :param tenant_id:
      :param provider_name:
      :param config:
      :return:
      """
      # reject invalid configs before touching the database
      self.custom_provider_config_validate(provider_name, config)

      custom_provider_query = db.session.query(Provider).filter(
          Provider.tenant_id == tenant_id,
          Provider.provider_name == provider_name,
          Provider.provider_type == ProviderType.CUSTOM.value
      )
      existing_provider = custom_provider_query.first()

      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      encrypted_config = provider_class.encrypt_provider_credentials(tenant_id, config)
      serialized_config = json.dumps(encrypted_config)

      if not existing_provider:
          # first custom config for this tenant/provider
          db.session.add(Provider(
              tenant_id=tenant_id,
              provider_name=provider_name,
              provider_type=ProviderType.CUSTOM.value,
              encrypted_config=serialized_config,
              is_valid=True
          ))
      else:
          existing_provider.encrypted_config = serialized_config
          existing_provider.is_valid = True
          existing_provider.updated_at = datetime.datetime.utcnow()

      db.session.commit()
  def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
      """
      delete custom provider.

      Before the tenant's CUSTOM provider row is deleted, the preferred
      provider type is switched to SYSTEM on a best-effort basis. Nothing
      happens when no CUSTOM row exists.

      :param tenant_id:
      :param provider_name:
      :return:
      """
      # get provider
      provider = db.session.query(Provider) \
          .filter(
              Provider.tenant_id == tenant_id,
              Provider.provider_name == provider_name,
              Provider.provider_type == ProviderType.CUSTOM.value
          ).first()

      if not provider:
          return

      try:
          self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
      except ValueError as e:
          # switch_preferred_provider raises ValueError e.g. when the provider does not
          # support the SYSTEM type; deletion must still go ahead, so only record it.
          logging.debug("Skip switching preferred provider of %s to system: %s", provider_name, e)

      db.session.delete(provider)
      db.session.commit()
  def custom_provider_model_config_validate(self,
                                            provider_name: str,
                                            model_name: str,
                                            model_type: str,
                                            config: dict) -> None:
      """
      validate custom provider model config.

      Only providers whose model flexibility is 'configurable' and which
      support the CUSTOM provider type are accepted; the credential check
      itself is delegated to the provider class.

      :param provider_name:
      :param model_name:
      :param model_type:
      :param config:
      :return:
      :raises ValueError: When the provider is not configurable or does not support CUSTOM.
      :raises CredentialsValidateFailedError: When the config credential verification fails.
      """
      provider_rule = ModelProviderFactory.get_provider_rule(provider_name)

      is_configurable = provider_rule['model_flexibility'] == 'configurable'
      if not is_configurable:
          raise ValueError('Only support configurable model provider')

      # only support provider type CUSTOM
      supports_custom = ProviderType.CUSTOM.value in provider_rule['support_provider_types']
      if not supports_custom:
          raise ValueError('Only support provider type CUSTOM')

      # the provider class expects the ModelType value, not the raw string
      resolved_model_type = ModelType.value_of(model_type)
      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      provider_class.is_model_credentials_valid_or_raise(model_name, resolved_model_type, config)
  def add_or_save_custom_provider_model_config(self,
                                               tenant_id: str,
                                               provider_name: str,
                                               model_name: str,
                                               model_type: str,
                                               config: dict) -> None:
      """
      Add or save custom provider model config.

      Validates the config, makes sure a valid CUSTOM provider row exists for
      the tenant, then stores the encrypted model credentials as JSON on the
      matching ProviderModel row (created when missing).

      :param tenant_id:
      :param provider_name:
      :param model_name:
      :param model_type:
      :param config:
      :return:
      """
      # reject invalid configs before touching the database
      self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)

      custom_provider = db.session.query(Provider).filter(
          Provider.tenant_id == tenant_id,
          Provider.provider_name == provider_name,
          Provider.provider_type == ProviderType.CUSTOM.value
      ).first()

      if not custom_provider:
          # no custom provider row yet: create a valid one without provider-level config
          custom_provider = Provider(
              tenant_id=tenant_id,
              provider_name=provider_name,
              provider_type=ProviderType.CUSTOM.value,
              is_valid=True
          )
          db.session.add(custom_provider)
          db.session.commit()
      elif not custom_provider.is_valid:
          # re-activate an invalidated row and drop its stored provider-level config
          custom_provider.is_valid = True
          custom_provider.encrypted_config = None
          db.session.commit()

      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      encrypted_config = provider_class.encrypt_model_credentials(
          tenant_id,
          model_name,
          ModelType.value_of(model_type),
          config
      )

      existing_model = db.session.query(ProviderModel).filter(
          ProviderModel.tenant_id == tenant_id,
          ProviderModel.provider_name == provider_name,
          ProviderModel.model_name == model_name,
          ProviderModel.model_type == model_type
      ).first()

      if not existing_model:
          db.session.add(ProviderModel(
              tenant_id=tenant_id,
              provider_name=provider_name,
              model_name=model_name,
              model_type=model_type,
              encrypted_config=json.dumps(encrypted_config),
              is_valid=True
          ))
      else:
          existing_model.encrypted_config = json.dumps(encrypted_config)
          existing_model.is_valid = True

      db.session.commit()
  def delete_custom_provider_model(self,
                                   tenant_id: str,
                                   provider_name: str,
                                   model_name: str,
                                   model_type: str) -> None:
      """
      delete custom provider model.

      Removes the ProviderModel row matching tenant, provider, model name and
      model type; does nothing when no such row exists.

      :param tenant_id:
      :param provider_name:
      :param model_name:
      :param model_type:
      :return:
      """
      matching_models = db.session.query(ProviderModel).filter(
          ProviderModel.tenant_id == tenant_id,
          ProviderModel.provider_name == provider_name,
          ProviderModel.model_name == model_name,
          ProviderModel.model_type == model_type
      )
      target_model = matching_models.first()

      if not target_model:
          return

      db.session.delete(target_model)
      db.session.commit()
  def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
      """
      switch preferred provider.

      Stores preferred_provider_type as the tenant's preference for the given
      provider. Returns without writing anything when the provider class
      reports that the system provider type is not supported.

      :param tenant_id:
      :param provider_name:
      :param preferred_provider_type:
      :return:
      :raises ValueError: When the type is invalid or not supported by the provider.
      """
      if not ProviderType.value_of(preferred_provider_type):
          raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')

      provider_rule = ModelProviderFactory.get_provider_rule(provider_name)
      if preferred_provider_type not in provider_rule['support_provider_types']:
          raise ValueError(f'Not support provider type: {preferred_provider_type}')

      provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
      if not provider_class.is_provider_type_system_supported():
          return

      # look up the tenant's stored preference for this provider, if any
      preference = db.session.query(TenantPreferredModelProvider).filter(
          TenantPreferredModelProvider.tenant_id == tenant_id,
          TenantPreferredModelProvider.provider_name == provider_name
      ).first()

      if not preference:
          preference = TenantPreferredModelProvider(
              tenant_id=tenant_id,
              provider_name=provider_name,
              preferred_provider_type=preferred_provider_type
          )
          db.session.add(preference)
      else:
          preference.preferred_provider_type = preferred_provider_type

      db.session.commit()
  def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
      """
      get default model of model type.

      Thin wrapper around ModelFactory.get_default_model that first converts
      the model type string with ModelType.value_of.

      :param tenant_id:
      :param model_type:
      :return:
      """
      resolved_model_type = ModelType.value_of(model_type)
      default_model = ModelFactory.get_default_model(tenant_id, resolved_model_type)
      return default_model
  def update_default_model_of_model_type(self,
                                         tenant_id: str,
                                         model_type: str,
                                         provider_name: str,
                                         model_name: str) -> TenantDefaultModel:
      """
      update default model of model type.

      Thin wrapper around ModelFactory.update_default_model that first
      converts the model type string with ModelType.value_of.

      :param tenant_id:
      :param model_type:
      :param provider_name:
      :param model_name:
      :return:
      """
      resolved_model_type = ModelType.value_of(model_type)
      updated_default = ModelFactory.update_default_model(
          tenant_id,
          resolved_model_type,
          provider_name,
          model_name
      )
      return updated_default
  def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
      """
      get valid model list.

      Walks every known provider, skips those for which no preferred model
      provider is available for the tenant, and collects one dict per model
      the provider supports for the requested model type.

      :param tenant_id:
      :param model_type:
      :return: list of model dicts
      """
      # optional keys of a supported-model entry and the output key each maps to
      optional_fields = (('mode', 'model_mode'), ('features', 'features'))

      valid_model_list = []
      for provider_name, provider_rule in ModelProviderFactory.get_provider_rules().items():
          model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name)
          if not model_provider:
              continue

          supported_models = model_provider.get_supported_model_list(ModelType.value_of(model_type))
          provider = model_provider.provider

          for model in supported_models:
              entry = {
                  "model_name": model['id'],
                  "model_display_name": model['name'],
                  "model_type": model_type,
                  "model_provider": {
                      "provider_name": provider.provider_name,
                      "provider_type": provider.provider_type
                  },
                  'features': []
              }

              for source_key, target_key in optional_fields:
                  if source_key in model:
                      entry[target_key] = model[source_key]

              # quota details are only reported for system providers
              if provider.provider_type == ProviderType.SYSTEM.value:
                  entry['model_provider'].update({
                      'quota_type': provider.quota_type,
                      'quota_unit': provider_rule['system_config']['quota_unit'],
                      'quota_limit': provider.quota_limit,
                      'quota_used': provider.quota_used
                  })

              valid_model_list.append(entry)

      return valid_model_list
  def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
          -> ModelKwargsRules:
      """
      get model parameter rules.

      It depends on preferred provider in use: the rules come from the
      tenant's preferred model provider, and an empty ModelKwargsRules is
      returned when no such provider is available.

      :param tenant_id:
      :param model_provider_name:
      :param model_name:
      :param model_type:
      :return:
      """
      preferred_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

      if preferred_provider:
          resolved_model_type = ModelType.value_of(model_type)
          return preferred_provider.get_model_parameter_rules(model_name, resolved_model_type)

      # no provider available: fall back to empty rules
      return ModelKwargsRules()
  def free_quota_submit(self, tenant_id: str, provider_name: str):
      """
      submit a free quota application for a provider.

      POSTs the tenant id and provider name to the free quota apply server
      configured through the FREE_QUOTA_APPLY_BASE_URL and
      FREE_QUOTA_APPLY_API_KEY environment variables.

      :param tenant_id:
      :param provider_name:
      :return: dict with 'type' plus 'redirect_url' when the type is 'redirect',
               otherwise 'type' plus 'result'
      :raises ValueError: When the base url is not configured, the server answers
                          with an error status, or its 'code' is not 'success'.
      """
      api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
      api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
      if not api_base_url:
          # fail with a clear message instead of a TypeError on None + str
          raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured")

      api_url = api_base_url + '/api/v1/providers/apply'

      headers = {
          'Content-Type': 'application/json',
          'Authorization': f"Bearer {api_key}"
      }

      # (connect, read) timeout so an unresponsive server cannot block the caller forever
      response = requests.post(
          api_url,
          headers=headers,
          json={'workspace_id': tenant_id, 'provider_name': provider_name},
          timeout=(5, 30)
      )

      if not response.ok:
          logging.error("Request FREE QUOTA APPLY SERVER Error: %s ", response.status_code)
          raise ValueError(f"Error: {response.status_code} ")

      # parse the body once and reuse it
      rst = response.json()
      if rst["code"] != 'success':
          raise ValueError(
              f"error: {rst['message']}"
          )

      if rst['type'] == 'redirect':
          return {
              'type': rst['type'],
              'redirect_url': rst['redirect_url']
          }

      return {
          'type': rst['type'],
          'result': 'success'
      }
  def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
      """
      verify free quota qualification of a provider.

      POSTs the tenant id, provider name and (when given) token to the free
      quota apply server configured through the FREE_QUOTA_APPLY_BASE_URL and
      FREE_QUOTA_APPLY_API_KEY environment variables.

      :param tenant_id:
      :param provider_name:
      :param token: optional, only sent when truthy
      :return: dict with 'result', 'provider_name' and 'flag'; 'reason' is added
               when the server reports the tenant as not qualified
      :raises ValueError: When the base url is not configured, the server answers
                          with an error status, or its 'code' is not 'success'.
      """
      api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
      api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
      if not api_base_url:
          # fail with a clear message instead of a TypeError on None + str
          raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured")

      api_url = api_base_url + '/api/v1/providers/qualification-verify'

      headers = {
          'Content-Type': 'application/json',
          'Authorization': f"Bearer {api_key}"
      }

      json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
      if token:
          json_data['token'] = token

      # (connect, read) timeout so an unresponsive server cannot block the caller forever
      response = requests.post(api_url, headers=headers, json=json_data, timeout=(5, 30))

      if not response.ok:
          logging.error("Request FREE QUOTA APPLY SERVER Error: %s ", response.status_code)
          raise ValueError(f"Error: {response.status_code} ")

      rst = response.json()
      if rst["code"] != 'success':
          raise ValueError(
              f"error: {rst['message']}"
          )

      data = rst['data']
      if data['qualified'] is True:
          return {
              'result': 'success',
              'provider_name': provider_name,
              'flag': True
          }

      # not qualified: pass the server's reason through to the caller
      return {
          'result': 'success',
          'provider_name': provider_name,
          'flag': False,
          'reason': data['reason']
      }