import datetime
import json
from collections import defaultdict
from typing import Optional

from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
    TenantDefaultModel


class ProviderService:
    """
    Tenant-scoped operations on model providers, custom provider models,
    preferred provider types and default models.

    Methods that modify rows directly call ``db.session.commit()`` themselves
    rather than leaving the transaction to the caller.
    """

    def get_provider_list(self, tenant_id: str) -> dict:
        """
        get provider list of tenant.

        For every provider known to ``ModelProviderFactory.get_provider_rules()``,
        build a config dict holding the tenant's preferred provider type, the
        provider's model flexibility, and a ``providers`` list with one entry per
        supported provider type: ``"system:<quota_type>"`` entries and/or a
        ``"custom"`` entry. Entries start as placeholders and are then filled in
        from the tenant's valid ``Provider`` / ``ProviderModel`` rows. Returned
        credentials are obfuscated.

        :param tenant_id: tenant whose providers are listed
        :return: dict mapping provider name -> provider config dict
        """
        # get rules for all providers
        model_provider_rules = ModelProviderFactory.get_provider_rules()

        model_provider_names = list(model_provider_rules)
        # only providers that allow custom credentials AND per-model configuration
        # have ProviderModel rows worth loading
        configurable_model_provider_names = [
            name
            for name, rule in model_provider_rules.items()
            if 'custom' in rule['support_provider_types']
            and rule['model_flexibility'] == 'configurable'
        ]

        # get all providers for the tenant (newest first)
        # NOTE: `== True` (not `is True`) is required so SQLAlchemy builds a SQL comparison.
        providers = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name.in_(model_provider_names),
                Provider.is_valid == True  # noqa: E712
            ).order_by(Provider.created_at.desc()).all()

        provider_name_to_provider_dict = defaultdict(list)
        for provider in providers:
            provider_name_to_provider_dict[provider.provider_name].append(provider)

        # get all configurable provider models for the tenant (newest first)
        provider_models = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name.in_(configurable_model_provider_names),
                ProviderModel.is_valid == True  # noqa: E712
            ).order_by(ProviderModel.created_at.desc()).all()

        provider_name_to_provider_model_dict = defaultdict(list)
        for provider_model in provider_models:
            provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model)

        # get all preferred provider type for the tenant
        preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name.in_(model_provider_names)
            ).all()

        provider_name_to_preferred_provider_type_dict = {
            preferred_provider_type.provider_name: preferred_provider_type
            for preferred_provider_type in preferred_provider_types
        }

        providers_list = {}

        for model_provider_name, model_provider_rule in model_provider_rules.items():
            # get preferred provider type
            preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name)
            preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
                tenant_id,
                model_provider_name,
                preferred_model_provider
            )

            provider_config_dict = {
                "preferred_provider_type": preferred_provider_type,
                "model_flexibility": model_provider_rule['model_flexibility'],
            }

            provider_parameter_dict = self._default_provider_parameters(model_provider_name, model_provider_rule)

            model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)

            current_providers = provider_name_to_provider_dict[model_provider_name]
            for provider in current_providers:
                if provider.provider_type == ProviderType.SYSTEM.value:
                    quota_type = provider.quota_type
                    key = f'{ProviderType.SYSTEM.value}:{quota_type}'
                    # rows for quota types the rules no longer support are ignored
                    if key in provider_parameter_dict:
                        provider_parameter_dict[key]['is_valid'] = provider.is_valid
                        provider_parameter_dict[key]['quota_used'] = provider.quota_used
                        provider_parameter_dict[key]['quota_limit'] = provider.quota_limit
                        provider_parameter_dict[key]['last_used'] = provider.last_used
                elif provider.provider_type == ProviderType.CUSTOM.value \
                        and ProviderType.CUSTOM.value in provider_parameter_dict:
                    # if custom
                    key = ProviderType.CUSTOM.value
                    provider_parameter_dict[key]['last_used'] = provider.last_used
                    provider_parameter_dict[key]['is_valid'] = provider.is_valid

                    if model_provider_rule['model_flexibility'] == 'fixed':
                        # fixed providers keep one provider-level credential set
                        provider_parameter_dict[key]['config'] = model_provider_class(provider=provider) \
                            .get_provider_credentials(obfuscated=True)
                    else:
                        # configurable providers keep credentials per model
                        current_provider_models = provider_name_to_provider_model_dict[model_provider_name]
                        models = []
                        if current_provider_models:
                            # a single provider instance serves every model's credential lookup
                            model_provider = model_provider_class(provider=provider)
                            models = [
                                {
                                    "model_name": provider_model.model_name,
                                    "model_type": provider_model.model_type,
                                    "config": model_provider.get_model_credentials(
                                        provider_model.model_name,
                                        ModelType.value_of(provider_model.model_type),
                                        obfuscated=True
                                    ),
                                    "is_valid": provider_model.is_valid
                                }
                                for provider_model in current_provider_models
                            ]

                        provider_parameter_dict[key]['models'] = models

            provider_config_dict['providers'] = list(provider_parameter_dict.values())
            providers_list[model_provider_name] = provider_config_dict

        return providers_list

    @staticmethod
    def _default_provider_parameters(model_provider_name: str, model_provider_rule: dict) -> dict:
        """
        Build the placeholder provider entries for one provider, before tenant rows are applied.

        :param model_provider_name: provider name
        :param model_provider_rule: the provider's rule dict from ModelProviderFactory
        :return: dict keyed by ``"system:<quota_type>"`` and/or ``"custom"``
        """
        provider_parameter_dict = {}
        if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types']:
            # iterate the enum (not the rule list) so entry order follows ProviderQuotaType
            for quota_type_enum in ProviderQuotaType:
                quota_type = quota_type_enum.value
                if quota_type in model_provider_rule['system_config']['supported_quota_types']:
                    key = ProviderType.SYSTEM.value + ':' + quota_type
                    provider_parameter_dict[key] = {
                        "provider_name": model_provider_name,
                        "provider_type": ProviderType.SYSTEM.value,
                        "config": None,
                        "is_valid": False,  # need update
                        "quota_type": quota_type,
                        "quota_unit": model_provider_rule['system_config']['quota_unit'],  # need update
                        # only the trial quota has a rule-defined default limit
                        "quota_limit": 0 if quota_type != ProviderQuotaType.TRIAL.value else
                        model_provider_rule['system_config']['quota_limit'],  # need update
                        "quota_used": 0,  # need update
                        "last_used": None  # need update
                    }

        if ProviderType.CUSTOM.value in model_provider_rule['support_provider_types']:
            provider_parameter_dict[ProviderType.CUSTOM.value] = {
                "provider_name": model_provider_name,
                "provider_type": ProviderType.CUSTOM.value,
                "config": None,  # need update
                "models": [],  # need update
                "is_valid": False,
                "last_used": None  # need update
            }

        return provider_parameter_dict

    def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
        """
        validate custom provider config.

        Only applies to providers with 'fixed' model flexibility that support
        the CUSTOM provider type.

        :param provider_name:
        :param config: provider-level credentials to validate
        :return:
        :raises ValueError: if the provider is not 'fixed' or does not support CUSTOM.
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        # get model provider rules
        model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)

        if model_provider_rules['model_flexibility'] != 'fixed':
            raise ValueError('Only support fixed model provider')

        # only support provider type CUSTOM
        if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']:
            raise ValueError('Only support provider type CUSTOM')

        # validate provider config
        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider_class.is_provider_credentials_valid_or_raise(config)

    def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
        """
        save custom provider config.

        Validates the config, encrypts it with the provider class, then updates
        the tenant's existing CUSTOM Provider row (marking it valid and bumping
        ``updated_at``) or inserts a new one.

        :param tenant_id:
        :param provider_name:
        :param config: plaintext provider-level credentials
        :return:
        :raises ValueError: see ``custom_provider_config_validate``.
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        # validate custom provider config
        self.custom_provider_config_validate(provider_name, config)

        # get provider
        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        encrypted_config = model_provider_class.encrypt_provider_credentials(tenant_id, config)

        # save provider
        if provider:
            provider.encrypted_config = json.dumps(encrypted_config)
            provider.is_valid = True
            provider.updated_at = datetime.datetime.utcnow()
            db.session.commit()
        else:
            provider = Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=json.dumps(encrypted_config),
                is_valid=True
            )
            db.session.add(provider)
            db.session.commit()

    def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
        """
        delete custom provider.

        Before deleting, tries to switch the tenant's preference for this
        provider to SYSTEM. Does nothing if no CUSTOM row exists.

        :param tenant_id:
        :param provider_name:
        :return:
        """
        # get provider
        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        if provider:
            try:
                self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
            except ValueError:
                # Deliberate best-effort: switch_preferred_provider raises ValueError when the
                # provider does not support the SYSTEM type, in which case there is nothing
                # to fall back to and the deletion should still proceed.
                pass

            db.session.delete(provider)
            db.session.commit()

    def custom_provider_model_config_validate(self,
                                              provider_name: str,
                                              model_name: str,
                                              model_type: str,
                                              config: dict) -> None:
        """
        validate custom provider model config.

        Only applies to providers with 'configurable' model flexibility that
        support the CUSTOM provider type.

        :param provider_name:
        :param model_name:
        :param model_type: model type value, converted with ``ModelType.value_of``
        :param config: model-level credentials to validate
        :return:
        :raises ValueError: if the provider is not 'configurable' or does not support CUSTOM.
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        # get model provider rules
        model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)

        if model_provider_rules['model_flexibility'] != 'configurable':
            raise ValueError('Only support configurable model provider')

        # only support provider type CUSTOM
        if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']:
            raise ValueError('Only support provider type CUSTOM')

        # validate provider model config
        model_type = ModelType.value_of(model_type)
        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider_class.is_model_credentials_valid_or_raise(model_name, model_type, config)

    def add_or_save_custom_provider_model_config(self,
                                                 tenant_id: str,
                                                 provider_name: str,
                                                 model_name: str,
                                                 model_type: str,
                                                 config: dict) -> None:
        """
        Add or save custom provider model config.

        Ensures a valid CUSTOM Provider row exists for the tenant. It creates one,
        or re-validates an invalid one and clears its provider-level
        ``encrypted_config``. Then it upserts the ProviderModel row holding the
        encrypted model credentials.

        :param tenant_id:
        :param provider_name:
        :param model_name:
        :param model_type: model type value, stored as given on the ProviderModel row
        :param config: plaintext model-level credentials
        :return:
        :raises ValueError: see ``custom_provider_model_config_validate``.
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        # validate custom provider model config
        self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)

        # get provider
        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        if not provider:
            provider = Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=True
            )
            db.session.add(provider)
            db.session.commit()
        elif not provider.is_valid:
            provider.is_valid = True
            provider.encrypted_config = None
            db.session.commit()

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        encrypted_config = model_provider_class.encrypt_model_credentials(
            tenant_id,
            model_name,
            ModelType.value_of(model_type),
            config
        )

        # get provider model
        provider_model = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name == provider_name,
                ProviderModel.model_name == model_name,
                ProviderModel.model_type == model_type
            ).first()

        if provider_model:
            provider_model.encrypted_config = json.dumps(encrypted_config)
            provider_model.is_valid = True
            db.session.commit()
        else:
            provider_model = ProviderModel(
                tenant_id=tenant_id,
                provider_name=provider_name,
                model_name=model_name,
                model_type=model_type,
                encrypted_config=json.dumps(encrypted_config),
                is_valid=True
            )
            db.session.add(provider_model)
            db.session.commit()

    def delete_custom_provider_model(self,
                                     tenant_id: str,
                                     provider_name: str,
                                     model_name: str,
                                     model_type: str) -> None:
        """
        delete custom provider model.

        Does nothing if no matching ProviderModel row exists.

        :param tenant_id:
        :param provider_name:
        :param model_name:
        :param model_type:
        :return:
        """
        # get provider model
        provider_model = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name == provider_name,
                ProviderModel.model_name == model_name,
                ProviderModel.model_type == model_type
            ).first()

        if provider_model:
            db.session.delete(provider_model)
            db.session.commit()

    def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
        """
        switch preferred provider.

        Upserts the tenant's TenantPreferredModelProvider row for this provider.
        Returns without writing when the provider class reports that the system
        provider type is not supported.

        :param tenant_id:
        :param provider_name:
        :param preferred_provider_type: ProviderType value (e.g. ProviderType.SYSTEM.value)
        :return:
        :raises ValueError: for an unknown type, or a type the provider's rules do not support.
        """
        provider_type = ProviderType.value_of(preferred_provider_type)
        if not provider_type:
            raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')

        model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
        if preferred_provider_type not in model_provider_rules['support_provider_types']:
            raise ValueError(f'Not support provider type: {preferred_provider_type}')

        model_provider = ModelProviderFactory.get_model_provider_class(provider_name)
        # presumably without system support there is nothing to choose between — TODO confirm
        if not model_provider.is_provider_type_system_supported():
            return

        # get preferred provider
        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name == provider_name
            ).first()

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = preferred_provider_type
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                preferred_provider_type=preferred_provider_type
            )
            db.session.add(preferred_model_provider)

        db.session.commit()

    def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
        """
        get default model of model type.

        Thin wrapper over ``ModelFactory.get_default_model``.

        :param tenant_id:
        :param model_type: model type value, converted with ``ModelType.value_of``
        :return: the tenant's default model, or None
        """
        return ModelFactory.get_default_model(tenant_id, ModelType.value_of(model_type))

    def update_default_model_of_model_type(self,
                                           tenant_id: str,
                                           model_type: str,
                                           provider_name: str,
                                           model_name: str) -> TenantDefaultModel:
        """
        update default model of model type.

        Thin wrapper over ``ModelFactory.update_default_model``.

        :param tenant_id:
        :param model_type: model type value, converted with ``ModelType.value_of``
        :param provider_name:
        :param model_name:
        :return: the updated default model
        """
        return ModelFactory.update_default_model(tenant_id, ModelType.value_of(model_type), provider_name, model_name)

    def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
        """
        get valid model list.

        For each known provider, asks ``ModelProviderFactory`` for the tenant's
        preferred provider instance. Providers with none are skipped. The
        instance's supported models of ``model_type`` are listed. System-type
        providers also report quota information.

        :param tenant_id:
        :param model_type: model type value, converted with ``ModelType.value_of``
        :return: list of model description dicts
        """
        valid_model_list = []

        # get model provider rules
        model_provider_rules = ModelProviderFactory.get_provider_rules()
        for model_provider_name, model_provider_rule in model_provider_rules.items():
            model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
            if not model_provider:
                continue

            model_list = model_provider.get_supported_model_list(ModelType.value_of(model_type))
            provider = model_provider.provider
            for model in model_list:
                valid_model_dict = {
                    "model_name": model['id'],
                    "model_type": model_type,
                    "model_provider": {
                        "provider_name": provider.provider_name,
                        "provider_type": provider.provider_type
                    },
                    'features': []
                }

                if 'features' in model:
                    valid_model_dict['features'] = model['features']

                if provider.provider_type == ProviderType.SYSTEM.value:
                    valid_model_dict['model_provider']['quota_type'] = provider.quota_type
                    valid_model_dict['model_provider']['quota_unit'] = model_provider_rule['system_config']['quota_unit']
                    valid_model_dict['model_provider']['quota_limit'] = provider.quota_limit
                    valid_model_dict['model_provider']['quota_used'] = provider.quota_used

                valid_model_list.append(valid_model_dict)

        return valid_model_list

    def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
            -> ModelKwargsRules:
        """
        get model parameter rules.

        It depends on preferred provider in use.

        :param tenant_id:
        :param model_provider_name:
        :param model_name:
        :param model_type: model type value, converted with ``ModelType.value_of``
        :return: the provider's parameter rules, or an empty ``ModelKwargsRules()``
                 when the tenant has no preferred provider for ``model_provider_name``
        """
        # get model provider
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
        if not model_provider:
            # get empty model provider
            return ModelKwargsRules()

        # get model parameter rules
        return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))