model_load_balancing_service.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. import datetime
  2. import json
  3. import logging
  4. from json import JSONDecodeError
  5. from typing import Optional
  6. from constants import HIDDEN_VALUE
  7. from core.entities.provider_configuration import ProviderConfiguration
  8. from core.helper import encrypter
  9. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  10. from core.model_manager import LBModelManager
  11. from core.model_runtime.entities.model_entities import ModelType
  12. from core.model_runtime.entities.provider_entities import (
  13. ModelCredentialSchema,
  14. ProviderCredentialSchema,
  15. )
  16. from core.model_runtime.model_providers import model_provider_factory
  17. from core.provider_manager import ProviderManager
  18. from extensions.ext_database import db
  19. from models.provider import LoadBalancingModelConfig
  20. logger = logging.getLogger(__name__)
  21. class ModelLoadBalancingService:
  22. def __init__(self) -> None:
  23. self.provider_manager = ProviderManager()
  24. def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
  25. """
  26. enable model load balancing.
  27. :param tenant_id: workspace id
  28. :param provider: provider name
  29. :param model: model name
  30. :param model_type: model type
  31. :return:
  32. """
  33. # Get all provider configurations of the current workspace
  34. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  35. # Get provider configuration
  36. provider_configuration = provider_configurations.get(provider)
  37. if not provider_configuration:
  38. raise ValueError(f"Provider {provider} does not exist.")
  39. # Enable model load balancing
  40. provider_configuration.enable_model_load_balancing(
  41. model=model,
  42. model_type=ModelType.value_of(model_type)
  43. )
  44. def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
  45. """
  46. disable model load balancing.
  47. :param tenant_id: workspace id
  48. :param provider: provider name
  49. :param model: model name
  50. :param model_type: model type
  51. :return:
  52. """
  53. # Get all provider configurations of the current workspace
  54. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  55. # Get provider configuration
  56. provider_configuration = provider_configurations.get(provider)
  57. if not provider_configuration:
  58. raise ValueError(f"Provider {provider} does not exist.")
  59. # disable model load balancing
  60. provider_configuration.disable_model_load_balancing(
  61. model=model,
  62. model_type=ModelType.value_of(model_type)
  63. )
  64. def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
  65. -> tuple[bool, list[dict]]:
  66. """
  67. Get load balancing configurations.
  68. :param tenant_id: workspace id
  69. :param provider: provider name
  70. :param model: model name
  71. :param model_type: model type
  72. :return:
  73. """
  74. # Get all provider configurations of the current workspace
  75. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  76. # Get provider configuration
  77. provider_configuration = provider_configurations.get(provider)
  78. if not provider_configuration:
  79. raise ValueError(f"Provider {provider} does not exist.")
  80. # Convert model type to ModelType
  81. model_type = ModelType.value_of(model_type)
  82. # Get provider model setting
  83. provider_model_setting = provider_configuration.get_provider_model_setting(
  84. model_type=model_type,
  85. model=model,
  86. )
  87. is_load_balancing_enabled = False
  88. if provider_model_setting and provider_model_setting.load_balancing_enabled:
  89. is_load_balancing_enabled = True
  90. # Get load balancing configurations
  91. load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
  92. .filter(
  93. LoadBalancingModelConfig.tenant_id == tenant_id,
  94. LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
  95. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  96. LoadBalancingModelConfig.model_name == model
  97. ).order_by(LoadBalancingModelConfig.created_at).all()
  98. if provider_configuration.custom_configuration.provider:
  99. # check if the inherit configuration exists,
  100. # inherit is represented for the provider or model custom credentials
  101. inherit_config_exists = False
  102. for load_balancing_config in load_balancing_configs:
  103. if load_balancing_config.name == '__inherit__':
  104. inherit_config_exists = True
  105. break
  106. if not inherit_config_exists:
  107. # Initialize the inherit configuration
  108. inherit_config = self._init_inherit_config(tenant_id, provider, model, model_type)
  109. # prepend the inherit configuration
  110. load_balancing_configs.insert(0, inherit_config)
  111. else:
  112. # move the inherit configuration to the first
  113. for i, load_balancing_config in enumerate(load_balancing_configs[:]):
  114. if load_balancing_config.name == '__inherit__':
  115. inherit_config = load_balancing_configs.pop(i)
  116. load_balancing_configs.insert(0, inherit_config)
  117. # Get credential form schemas from model credential schema or provider credential schema
  118. credential_schemas = self._get_credential_schema(provider_configuration)
  119. # Get decoding rsa key and cipher for decrypting credentials
  120. decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
  121. # fetch status and ttl for each config
  122. datas = []
  123. for load_balancing_config in load_balancing_configs:
  124. in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
  125. tenant_id=tenant_id,
  126. provider=provider,
  127. model=model,
  128. model_type=model_type,
  129. config_id=load_balancing_config.id
  130. )
  131. try:
  132. if load_balancing_config.encrypted_config:
  133. credentials = json.loads(load_balancing_config.encrypted_config)
  134. else:
  135. credentials = {}
  136. except JSONDecodeError:
  137. credentials = {}
  138. # Get provider credential secret variables
  139. credential_secret_variables = provider_configuration.extract_secret_variables(
  140. credential_schemas.credential_form_schemas
  141. )
  142. # decrypt credentials
  143. for variable in credential_secret_variables:
  144. if variable in credentials:
  145. try:
  146. credentials[variable] = encrypter.decrypt_token_with_decoding(
  147. credentials.get(variable),
  148. decoding_rsa_key,
  149. decoding_cipher_rsa
  150. )
  151. except ValueError:
  152. pass
  153. # Obfuscate credentials
  154. credentials = provider_configuration.obfuscated_credentials(
  155. credentials=credentials,
  156. credential_form_schemas=credential_schemas.credential_form_schemas
  157. )
  158. datas.append({
  159. 'id': load_balancing_config.id,
  160. 'name': load_balancing_config.name,
  161. 'credentials': credentials,
  162. 'enabled': load_balancing_config.enabled,
  163. 'in_cooldown': in_cooldown,
  164. 'ttl': ttl
  165. })
  166. return is_load_balancing_enabled, datas
  167. def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
  168. -> Optional[dict]:
  169. """
  170. Get load balancing configuration.
  171. :param tenant_id: workspace id
  172. :param provider: provider name
  173. :param model: model name
  174. :param model_type: model type
  175. :param config_id: load balancing config id
  176. :return:
  177. """
  178. # Get all provider configurations of the current workspace
  179. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  180. # Get provider configuration
  181. provider_configuration = provider_configurations.get(provider)
  182. if not provider_configuration:
  183. raise ValueError(f"Provider {provider} does not exist.")
  184. # Convert model type to ModelType
  185. model_type = ModelType.value_of(model_type)
  186. # Get load balancing configurations
  187. load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
  188. .filter(
  189. LoadBalancingModelConfig.tenant_id == tenant_id,
  190. LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
  191. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  192. LoadBalancingModelConfig.model_name == model,
  193. LoadBalancingModelConfig.id == config_id
  194. ).first()
  195. if not load_balancing_model_config:
  196. return None
  197. try:
  198. if load_balancing_model_config.encrypted_config:
  199. credentials = json.loads(load_balancing_model_config.encrypted_config)
  200. else:
  201. credentials = {}
  202. except JSONDecodeError:
  203. credentials = {}
  204. # Get credential form schemas from model credential schema or provider credential schema
  205. credential_schemas = self._get_credential_schema(provider_configuration)
  206. # Obfuscate credentials
  207. credentials = provider_configuration.obfuscated_credentials(
  208. credentials=credentials,
  209. credential_form_schemas=credential_schemas.credential_form_schemas
  210. )
  211. return {
  212. 'id': load_balancing_model_config.id,
  213. 'name': load_balancing_model_config.name,
  214. 'credentials': credentials,
  215. 'enabled': load_balancing_model_config.enabled
  216. }
  217. def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
  218. -> LoadBalancingModelConfig:
  219. """
  220. Initialize the inherit configuration.
  221. :param tenant_id: workspace id
  222. :param provider: provider name
  223. :param model: model name
  224. :param model_type: model type
  225. :return:
  226. """
  227. # Initialize the inherit configuration
  228. inherit_config = LoadBalancingModelConfig(
  229. tenant_id=tenant_id,
  230. provider_name=provider,
  231. model_type=model_type.to_origin_model_type(),
  232. model_name=model,
  233. name='__inherit__'
  234. )
  235. db.session.add(inherit_config)
  236. db.session.commit()
  237. return inherit_config
  238. def update_load_balancing_configs(self, tenant_id: str,
  239. provider: str,
  240. model: str,
  241. model_type: str,
  242. configs: list[dict]) -> None:
  243. """
  244. Update load balancing configurations.
  245. :param tenant_id: workspace id
  246. :param provider: provider name
  247. :param model: model name
  248. :param model_type: model type
  249. :param configs: load balancing configs
  250. :return:
  251. """
  252. # Get all provider configurations of the current workspace
  253. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  254. # Get provider configuration
  255. provider_configuration = provider_configurations.get(provider)
  256. if not provider_configuration:
  257. raise ValueError(f"Provider {provider} does not exist.")
  258. # Convert model type to ModelType
  259. model_type = ModelType.value_of(model_type)
  260. if not isinstance(configs, list):
  261. raise ValueError('Invalid load balancing configs')
  262. current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
  263. .filter(
  264. LoadBalancingModelConfig.tenant_id == tenant_id,
  265. LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
  266. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  267. LoadBalancingModelConfig.model_name == model
  268. ).all()
  269. # id as key, config as value
  270. current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
  271. updated_config_ids = set()
  272. for config in configs:
  273. if not isinstance(config, dict):
  274. raise ValueError('Invalid load balancing config')
  275. config_id = config.get('id')
  276. name = config.get('name')
  277. credentials = config.get('credentials')
  278. enabled = config.get('enabled')
  279. if not name:
  280. raise ValueError('Invalid load balancing config name')
  281. if enabled is None:
  282. raise ValueError('Invalid load balancing config enabled')
  283. # is config exists
  284. if config_id:
  285. config_id = str(config_id)
  286. if config_id not in current_load_balancing_configs_dict:
  287. raise ValueError('Invalid load balancing config id: {}'.format(config_id))
  288. updated_config_ids.add(config_id)
  289. load_balancing_config = current_load_balancing_configs_dict[config_id]
  290. # check duplicate name
  291. for current_load_balancing_config in current_load_balancing_configs:
  292. if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
  293. raise ValueError('Load balancing config name {} already exists'.format(name))
  294. if credentials:
  295. if not isinstance(credentials, dict):
  296. raise ValueError('Invalid load balancing config credentials')
  297. # validate custom provider config
  298. credentials = self._custom_credentials_validate(
  299. tenant_id=tenant_id,
  300. provider_configuration=provider_configuration,
  301. model_type=model_type,
  302. model=model,
  303. credentials=credentials,
  304. load_balancing_model_config=load_balancing_config,
  305. validate=False
  306. )
  307. # update load balancing config
  308. load_balancing_config.encrypted_config = json.dumps(credentials)
  309. load_balancing_config.name = name
  310. load_balancing_config.enabled = enabled
  311. load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  312. db.session.commit()
  313. self._clear_credentials_cache(tenant_id, config_id)
  314. else:
  315. # create load balancing config
  316. if name == '__inherit__':
  317. raise ValueError('Invalid load balancing config name')
  318. # check duplicate name
  319. for current_load_balancing_config in current_load_balancing_configs:
  320. if current_load_balancing_config.name == name:
  321. raise ValueError('Load balancing config name {} already exists'.format(name))
  322. if not credentials:
  323. raise ValueError('Invalid load balancing config credentials')
  324. if not isinstance(credentials, dict):
  325. raise ValueError('Invalid load balancing config credentials')
  326. # validate custom provider config
  327. credentials = self._custom_credentials_validate(
  328. tenant_id=tenant_id,
  329. provider_configuration=provider_configuration,
  330. model_type=model_type,
  331. model=model,
  332. credentials=credentials,
  333. validate=False
  334. )
  335. # create load balancing config
  336. load_balancing_model_config = LoadBalancingModelConfig(
  337. tenant_id=tenant_id,
  338. provider_name=provider_configuration.provider.provider,
  339. model_type=model_type.to_origin_model_type(),
  340. model_name=model,
  341. name=name,
  342. encrypted_config=json.dumps(credentials)
  343. )
  344. db.session.add(load_balancing_model_config)
  345. db.session.commit()
  346. # get deleted config ids
  347. deleted_config_ids = set(current_load_balancing_configs_dict.keys()) - updated_config_ids
  348. for config_id in deleted_config_ids:
  349. db.session.delete(current_load_balancing_configs_dict[config_id])
  350. db.session.commit()
  351. self._clear_credentials_cache(tenant_id, config_id)
  352. def validate_load_balancing_credentials(self, tenant_id: str,
  353. provider: str,
  354. model: str,
  355. model_type: str,
  356. credentials: dict,
  357. config_id: Optional[str] = None) -> None:
  358. """
  359. Validate load balancing credentials.
  360. :param tenant_id: workspace id
  361. :param provider: provider name
  362. :param model_type: model type
  363. :param model: model name
  364. :param credentials: credentials
  365. :param config_id: load balancing config id
  366. :return:
  367. """
  368. # Get all provider configurations of the current workspace
  369. provider_configurations = self.provider_manager.get_configurations(tenant_id)
  370. # Get provider configuration
  371. provider_configuration = provider_configurations.get(provider)
  372. if not provider_configuration:
  373. raise ValueError(f"Provider {provider} does not exist.")
  374. # Convert model type to ModelType
  375. model_type = ModelType.value_of(model_type)
  376. load_balancing_model_config = None
  377. if config_id:
  378. # Get load balancing config
  379. load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
  380. .filter(
  381. LoadBalancingModelConfig.tenant_id == tenant_id,
  382. LoadBalancingModelConfig.provider_name == provider,
  383. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  384. LoadBalancingModelConfig.model_name == model,
  385. LoadBalancingModelConfig.id == config_id
  386. ).first()
  387. if not load_balancing_model_config:
  388. raise ValueError(f"Load balancing config {config_id} does not exist.")
  389. # Validate custom provider config
  390. self._custom_credentials_validate(
  391. tenant_id=tenant_id,
  392. provider_configuration=provider_configuration,
  393. model_type=model_type,
  394. model=model,
  395. credentials=credentials,
  396. load_balancing_model_config=load_balancing_model_config
  397. )
  398. def _custom_credentials_validate(self, tenant_id: str,
  399. provider_configuration: ProviderConfiguration,
  400. model_type: ModelType,
  401. model: str,
  402. credentials: dict,
  403. load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
  404. validate: bool = True) -> dict:
  405. """
  406. Validate custom credentials.
  407. :param tenant_id: workspace id
  408. :param provider_configuration: provider configuration
  409. :param model_type: model type
  410. :param model: model name
  411. :param credentials: credentials
  412. :param load_balancing_model_config: load balancing model config
  413. :param validate: validate credentials
  414. :return:
  415. """
  416. # Get credential form schemas from model credential schema or provider credential schema
  417. credential_schemas = self._get_credential_schema(provider_configuration)
  418. # Get provider credential secret variables
  419. provider_credential_secret_variables = provider_configuration.extract_secret_variables(
  420. credential_schemas.credential_form_schemas
  421. )
  422. if load_balancing_model_config:
  423. try:
  424. # fix origin data
  425. if load_balancing_model_config.encrypted_config:
  426. original_credentials = json.loads(load_balancing_model_config.encrypted_config)
  427. else:
  428. original_credentials = {}
  429. except JSONDecodeError:
  430. original_credentials = {}
  431. # encrypt credentials
  432. for key, value in credentials.items():
  433. if key in provider_credential_secret_variables:
  434. # if send [__HIDDEN__] in secret input, it will be same as original value
  435. if value == HIDDEN_VALUE and key in original_credentials:
  436. credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key])
  437. if validate:
  438. if isinstance(credential_schemas, ModelCredentialSchema):
  439. credentials = model_provider_factory.model_credentials_validate(
  440. provider=provider_configuration.provider.provider,
  441. model_type=model_type,
  442. model=model,
  443. credentials=credentials
  444. )
  445. else:
  446. credentials = model_provider_factory.provider_credentials_validate(
  447. provider=provider_configuration.provider.provider,
  448. credentials=credentials
  449. )
  450. for key, value in credentials.items():
  451. if key in provider_credential_secret_variables:
  452. credentials[key] = encrypter.encrypt_token(tenant_id, value)
  453. return credentials
  454. def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
  455. -> ModelCredentialSchema | ProviderCredentialSchema:
  456. """
  457. Get form schemas.
  458. :param provider_configuration: provider configuration
  459. :return:
  460. """
  461. # Get credential form schemas from model credential schema or provider credential schema
  462. if provider_configuration.provider.model_credential_schema:
  463. credential_schema = provider_configuration.provider.model_credential_schema
  464. else:
  465. credential_schema = provider_configuration.provider.provider_credential_schema
  466. return credential_schema
  467. def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
  468. """
  469. Clear credentials cache.
  470. :param tenant_id: workspace id
  471. :param config_id: load balancing config id
  472. :return:
  473. """
  474. provider_model_credentials_cache = ProviderCredentialsCache(
  475. tenant_id=tenant_id,
  476. identity_id=config_id,
  477. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
  478. )
  479. provider_model_credentials_cache.delete()