|
@@ -105,14 +105,8 @@ class ProviderManager:
|
|
|
# Construct ProviderConfiguration objects for each provider
|
|
|
for provider_entity in provider_entities:
|
|
|
provider_name = provider_entity.provider
|
|
|
-
|
|
|
- provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider)
|
|
|
- if not provider_records:
|
|
|
- provider_records = []
|
|
|
-
|
|
|
- provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider)
|
|
|
- if not provider_model_records:
|
|
|
- provider_model_records = []
|
|
|
+ provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
|
|
+ provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
|
|
|
|
|
# Convert to custom configuration
|
|
|
custom_configuration = self._to_custom_configuration(
|
|
@@ -134,38 +128,24 @@ class ProviderManager:
|
|
|
|
|
|
if preferred_provider_type_record:
|
|
|
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
|
|
+ elif custom_configuration.provider or custom_configuration.models:
|
|
|
+ preferred_provider_type = ProviderType.CUSTOM
|
|
|
+ elif system_configuration.enabled:
|
|
|
+ preferred_provider_type = ProviderType.SYSTEM
|
|
|
else:
|
|
|
- if custom_configuration.provider or custom_configuration.models:
|
|
|
- preferred_provider_type = ProviderType.CUSTOM
|
|
|
- elif system_configuration.enabled:
|
|
|
- preferred_provider_type = ProviderType.SYSTEM
|
|
|
- else:
|
|
|
- preferred_provider_type = ProviderType.CUSTOM
|
|
|
+ preferred_provider_type = ProviderType.CUSTOM
|
|
|
|
|
|
using_provider_type = preferred_provider_type
|
|
|
+ has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations)
|
|
|
+
|
|
|
if preferred_provider_type == ProviderType.SYSTEM:
|
|
|
- if not system_configuration.enabled:
|
|
|
+ if not system_configuration.enabled or not has_valid_quota:
|
|
|
using_provider_type = ProviderType.CUSTOM
|
|
|
|
|
|
- has_valid_quota = False
|
|
|
- for quota_configuration in system_configuration.quota_configurations:
|
|
|
- if quota_configuration.is_valid:
|
|
|
- has_valid_quota = True
|
|
|
- break
|
|
|
-
|
|
|
- if not has_valid_quota:
|
|
|
- using_provider_type = ProviderType.CUSTOM
|
|
|
else:
|
|
|
if not custom_configuration.provider and not custom_configuration.models:
|
|
|
- if system_configuration.enabled:
|
|
|
- has_valid_quota = False
|
|
|
- for quota_configuration in system_configuration.quota_configurations:
|
|
|
- if quota_configuration.is_valid:
|
|
|
- has_valid_quota = True
|
|
|
- break
|
|
|
-
|
|
|
- if has_valid_quota:
|
|
|
- using_provider_type = ProviderType.SYSTEM
|
|
|
+ if system_configuration.enabled and has_valid_quota:
|
|
|
+ using_provider_type = ProviderType.SYSTEM
|
|
|
|
|
|
provider_configuration = ProviderConfiguration(
|
|
|
tenant_id=tenant_id,
|
|
@@ -233,30 +213,17 @@ class ProviderManager:
|
|
|
)
|
|
|
|
|
|
if available_models:
|
|
|
- found = False
|
|
|
- for available_model in available_models:
|
|
|
- if available_model.model == "gpt-4":
|
|
|
- default_model = TenantDefaultModel(
|
|
|
- tenant_id=tenant_id,
|
|
|
- model_type=model_type.to_origin_model_type(),
|
|
|
- provider_name=available_model.provider.provider,
|
|
|
- model_name=available_model.model
|
|
|
- )
|
|
|
- db.session.add(default_model)
|
|
|
- db.session.commit()
|
|
|
- found = True
|
|
|
- break
|
|
|
-
|
|
|
- if not found:
|
|
|
- available_model = available_models[0]
|
|
|
- default_model = TenantDefaultModel(
|
|
|
- tenant_id=tenant_id,
|
|
|
- model_type=model_type.to_origin_model_type(),
|
|
|
- provider_name=available_model.provider.provider,
|
|
|
- model_name=available_model.model
|
|
|
- )
|
|
|
- db.session.add(default_model)
|
|
|
- db.session.commit()
|
|
|
+ available_model = next((model for model in available_models if model.model == "gpt-4"),
|
|
|
+ available_models[0])
|
|
|
+
|
|
|
+ default_model = TenantDefaultModel(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ model_type=model_type.to_origin_model_type(),
|
|
|
+ provider_name=available_model.provider.provider,
|
|
|
+ model_name=available_model.model
|
|
|
+ )
|
|
|
+ db.session.add(default_model)
|
|
|
+ db.session.commit()
|
|
|
|
|
|
if not default_model:
|
|
|
return None
|