Explorar o código

Refactor part of the ProviderManager code to improve readability (#4524)

非法操作 hai 11 meses
pai
achega
3efb5fe7e2
Modificáronse 1 ficheiros con 23 adicións e 56 borrados
  1. 23 56
      api/core/provider_manager.py

+ 23 - 56
api/core/provider_manager.py

@@ -105,14 +105,8 @@ class ProviderManager:
         # Construct ProviderConfiguration objects for each provider
         for provider_entity in provider_entities:
             provider_name = provider_entity.provider
-
-            provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider)
-            if not provider_records:
-                provider_records = []
-
-            provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider)
-            if not provider_model_records:
-                provider_model_records = []
+            provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
+            provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
 
             # Convert to custom configuration
             custom_configuration = self._to_custom_configuration(
@@ -134,38 +128,24 @@ class ProviderManager:
 
             if preferred_provider_type_record:
                 preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
+            elif custom_configuration.provider or custom_configuration.models:
+                preferred_provider_type = ProviderType.CUSTOM
+            elif system_configuration.enabled:
+                preferred_provider_type = ProviderType.SYSTEM
             else:
-                if custom_configuration.provider or custom_configuration.models:
-                    preferred_provider_type = ProviderType.CUSTOM
-                elif system_configuration.enabled:
-                    preferred_provider_type = ProviderType.SYSTEM
-                else:
-                    preferred_provider_type = ProviderType.CUSTOM
+                preferred_provider_type = ProviderType.CUSTOM
 
             using_provider_type = preferred_provider_type
+            has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations)
+
             if preferred_provider_type == ProviderType.SYSTEM:
-                if not system_configuration.enabled:
+                if not system_configuration.enabled or not has_valid_quota:
                     using_provider_type = ProviderType.CUSTOM
 
-                has_valid_quota = False
-                for quota_configuration in system_configuration.quota_configurations:
-                    if quota_configuration.is_valid:
-                        has_valid_quota = True
-                        break
-
-                if not has_valid_quota:
-                    using_provider_type = ProviderType.CUSTOM
             else:
                 if not custom_configuration.provider and not custom_configuration.models:
-                    if system_configuration.enabled:
-                        has_valid_quota = False
-                        for quota_configuration in system_configuration.quota_configurations:
-                            if quota_configuration.is_valid:
-                                has_valid_quota = True
-                                break
-
-                        if has_valid_quota:
-                            using_provider_type = ProviderType.SYSTEM
+                    if system_configuration.enabled and has_valid_quota:
+                        using_provider_type = ProviderType.SYSTEM
 
             provider_configuration = ProviderConfiguration(
                 tenant_id=tenant_id,
@@ -233,30 +213,17 @@ class ProviderManager:
             )
 
             if available_models:
-                found = False
-                for available_model in available_models:
-                    if available_model.model == "gpt-4":
-                        default_model = TenantDefaultModel(
-                            tenant_id=tenant_id,
-                            model_type=model_type.to_origin_model_type(),
-                            provider_name=available_model.provider.provider,
-                            model_name=available_model.model
-                        )
-                        db.session.add(default_model)
-                        db.session.commit()
-                        found = True
-                        break
-
-                if not found:
-                    available_model = available_models[0]
-                    default_model = TenantDefaultModel(
-                        tenant_id=tenant_id,
-                        model_type=model_type.to_origin_model_type(),
-                        provider_name=available_model.provider.provider,
-                        model_name=available_model.model
-                    )
-                    db.session.add(default_model)
-                    db.session.commit()
+                available_model = next((model for model in available_models if model.model == "gpt-4"),
+                                       available_models[0])
+
+                default_model = TenantDefaultModel(
+                    tenant_id=tenant_id,
+                    model_type=model_type.to_origin_model_type(),
+                    provider_name=available_model.provider.provider,
+                    model_name=available_model.model
+                )
+                db.session.add(default_model)
+                db.session.commit()
 
         if not default_model:
             return None