| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 | 
							- from abc import ABC, abstractmethod
 
- from datetime import datetime
 
- from typing import Type, Optional
 
- from flask import current_app
 
- from pydantic import BaseModel
 
- from core.model_providers.error import QuotaExceededError, LLMBadRequestError
 
- from extensions.ext_database import db
 
- from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
 
- from core.model_providers.models.entity.provider import ProviderQuotaUnit
 
- from core.model_providers.rules import provider_rules
 
- from models.provider import Provider, ProviderType, ProviderModel
 
class BaseModelProvider(BaseModel, ABC):
    """Abstract base for model providers bound to one ``Provider`` database row.

    Subclasses supply the provider name, fixed model lists, model classes and
    credential handling. This base class implements model listing and
    system-quota bookkeeping on top of ``provider_rules`` and the database.
    """

    # The provider record this instance operates on.
    provider: Provider

    class Config:
        """Configuration for this pydantic object."""

        # ``Provider`` is not a pydantic type, so pydantic must accept it as-is.
        arbitrary_types_allowed = True

    @property
    @abstractmethod
    def provider_name(self):
        """
        Returns the name of a provider.

        The value is used as the lookup key into ``provider_rules``.
        """
        raise NotImplementedError

    def get_rules(self):
        """
        Returns the rules of a provider.

        :return: the ``provider_rules`` entry registered under ``self.provider_name``
        """
        return provider_rules[self.provider_name]

    def get_supported_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get supported model object list for use.

        If the provider's rules do not list ``custom`` in
        ``support_provider_types``, or its ``model_flexibility`` is absent or
        ``fixed``, the subclass's fixed list is returned. Otherwise the list is
        built from this tenant's valid ``ProviderModel`` rows for this provider
        and model type, ordered by creation time (oldest first).

        :param model_type: the kind of model to list
        :return: list of ``{'id': model_name, 'name': model_name}`` dicts
        """
        rules = self.get_rules()

        # A missing 'model_flexibility' entry is treated exactly like 'fixed'.
        if ('custom' not in rules['support_provider_types']
                or rules.get('model_flexibility', 'fixed') == 'fixed'):
            return self._get_fixed_model_list(model_type)

        # get configurable provider models
        provider_models = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_type == model_type.value,
            # `== True` builds a SQL expression; it is not a Python comparison.
            ProviderModel.is_valid == True  # noqa: E712
        ).order_by(ProviderModel.created_at.asc()).all()

        return [{
            'id': provider_model.model_name,
            'name': provider_model.model_name
        } for provider_model in provider_models]

    @abstractmethod
    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        get supported model object list for use.

        Used when the provider's model list is not user-configurable.

        :param model_type: the kind of model to list
        :return: list of model dicts, same shape as ``get_supported_model_list``
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_class(self, model_type: ModelType) -> Type:
        """
        get specific model class.

        :param model_type: the kind of model whose class is requested
        :return: the class implementing that model type for this provider
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        check provider credentials valid.

        :param credentials: provider-level credentials to validate
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        encrypt provider credentials for save.

        :param tenant_id: tenant that owns the credentials
        :param credentials: plain provider credentials
        :return: credentials in the form to be persisted
        """
        raise NotImplementedError

    @abstractmethod
    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param obfuscated: whether secret values should be obfuscated
        :return: provider credentials
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name: name of the model
        :param model_type: kind of the model
        :param credentials: model-level credentials to validate
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id: tenant that owns the credentials
        :param model_name: name of the model
        :param model_type: kind of the model
        :param credentials: plain model credentials
        :return: credentials in the form to be persisted
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name: name of the model
        :param model_type: kind of the model
        :return: the parameter rules for that model
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name: name of the model
        :param model_type: kind of the model
        :param obfuscated: whether secret values should be obfuscated
        :return: model credentials
        """
        raise NotImplementedError

    @classmethod
    def is_provider_type_system_supported(cls) -> bool:
        """
        Return True when the app's ``EDITION`` config value is ``'CLOUD'``.

        Reads ``current_app.config``, so it requires a Flask app context.
        """
        return current_app.config['EDITION'] == 'CLOUD'

    def check_quota_over_limit(self):
        """
        check provider quota over limit.

        Only applies to system-type providers whose rules support the
        ``system`` provider type; otherwise this is a no-op.

        :raises QuotaExceededError: if this provider's row is not valid or its
            ``quota_used`` has reached ``quota_limit``
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        provider = db.session.query(Provider).filter(
            db.and_(
                Provider.id == self.provider.id,
                Provider.is_valid == True,  # noqa: E712 -- SQL expression
                Provider.quota_limit > Provider.quota_used
            )
        ).first()

        if not provider:
            raise QuotaExceededError()

    def deduct_quota(self, used_tokens: int = 0) -> None:
        """
        deduct available quota when the provider type is system.

        Nothing is deducted unless the provider is of system type, its rules
        support the ``system`` provider type, and ``should_deduct_quota()``
        returns True. When ``system_config.quota_unit`` is the tokens unit,
        ``used_tokens`` is deducted; otherwise one unit per call is deducted.

        The UPDATE matches rows by tenant, provider name, provider type and
        quota type (not by id), and only rows that still have quota left.
        The session is then committed.

        :param used_tokens: tokens consumed by the request
        :return:
        """
        if self.provider.provider_type != ProviderType.SYSTEM.value:
            return

        rules = self.get_rules()
        if 'system' not in rules['support_provider_types']:
            return

        if not self.should_deduct_quota():
            return

        # Default to per-call (TIMES) accounting when no unit is configured;
        # a null 'system_config' is treated the same as a missing one.
        system_config = rules.get('system_config') or {}
        quota_unit = system_config.get('quota_unit', ProviderQuotaUnit.TIMES.value)

        used_quota = used_tokens if quota_unit == ProviderQuotaUnit.TOKENS.value else 1

        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name,
            Provider.provider_type == self.provider.provider_type,
            Provider.quota_type == self.provider.quota_type,
            Provider.quota_limit > Provider.quota_used
        ).update({'quota_used': Provider.quota_used + used_quota})
        db.session.commit()

    def should_deduct_quota(self):
        """
        Hook deciding whether ``deduct_quota`` actually deducts.

        The base implementation returns False, so no quota is deducted unless
        a subclass overrides this.
        """
        return False

    def update_last_used(self) -> None:
        """
        update last used time.

        Sets ``last_used`` to ``datetime.utcnow()`` (a naive UTC timestamp) on
        every ``Provider`` row of this tenant with this provider name,
        regardless of provider type, then commits the session.

        :return:
        """
        db.session.query(Provider).filter(
            Provider.tenant_id == self.provider.tenant_id,
            Provider.provider_name == self.provider.provider_name
        ).update({'last_used': datetime.utcnow()})
        db.session.commit()

    def get_payment_info(self) -> Optional[dict]:
        """
        get payment info if the provider is payable.

        The base implementation always returns None.

        :return: payment info dict, or None when not payable
        """
        return None

    def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel:
        """
        get provider model.

        :param model_name: name of the configured model
        :param model_type: kind of the model
        :return: this tenant's valid ``ProviderModel`` row for this provider
        :raises LLMBadRequestError: if no such valid row exists
        """
        provider_model = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.provider.tenant_id,
            ProviderModel.provider_name == self.provider.provider_name,
            ProviderModel.model_name == model_name,
            ProviderModel.model_type == model_type.value,
            ProviderModel.is_valid == True  # noqa: E712 -- SQL expression
        ).first()

        if not provider_model:
            raise LLMBadRequestError(f"The model {model_name} does not exist. "
                                     f"Please check the configuration.")

        return provider_model
 
class CredentialsValidateFailedError(Exception):
    """Exception type for failed credential validation.

    Not raised by the code in this module's visible section; presumably
    raised by provider subclasses' ``*_credentials_valid_or_raise`` methods
    (confirm against callers).
    """
 
 
  |