소스 검색

refactor: Add @staticmethod decorator in `api/core` (#7652)

Shota Totsuka 7 달 전
부모
커밋
430e100142
4개의 변경된 파일52개의 추가작업 그리고 27개의 파일을 삭제
  1. 12 6
      api/core/hosting_configuration.py
  2. 18 9
      api/core/indexing_runner.py
  3. 6 4
      api/core/model_manager.py
  4. 16 8
      api/core/provider_manager.py

+ 12 - 6
api/core/hosting_configuration.py

@@ -58,7 +58,8 @@ class HostingConfiguration:
 
 
         self.moderation_config = self.init_moderation_config(config)
         self.moderation_config = self.init_moderation_config(config)
 
 
-    def init_azure_openai(self, app_config: Config) -> HostingProvider:
+    @staticmethod
+    def init_azure_openai(app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TIMES
         quota_unit = QuotaUnit.TIMES
         if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
         if app_config.get("HOSTED_AZURE_OPENAI_ENABLED"):
             credentials = {
             credentials = {
@@ -145,7 +146,8 @@ class HostingConfiguration:
             quota_unit=quota_unit,
             quota_unit=quota_unit,
         )
         )
 
 
-    def init_anthropic(self, app_config: Config) -> HostingProvider:
+    @staticmethod
+    def init_anthropic(app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
         quota_unit = QuotaUnit.TOKENS
         quotas = []
         quotas = []
 
 
@@ -180,7 +182,8 @@ class HostingConfiguration:
             quota_unit=quota_unit,
             quota_unit=quota_unit,
         )
         )
 
 
-    def init_minimax(self, app_config: Config) -> HostingProvider:
+    @staticmethod
+    def init_minimax(app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
         quota_unit = QuotaUnit.TOKENS
         if app_config.get("HOSTED_MINIMAX_ENABLED"):
         if app_config.get("HOSTED_MINIMAX_ENABLED"):
             quotas = [FreeHostingQuota()]
             quotas = [FreeHostingQuota()]
@@ -197,7 +200,8 @@ class HostingConfiguration:
             quota_unit=quota_unit,
             quota_unit=quota_unit,
         )
         )
 
 
-    def init_spark(self, app_config: Config) -> HostingProvider:
+    @staticmethod
+    def init_spark(app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
         quota_unit = QuotaUnit.TOKENS
         if app_config.get("HOSTED_SPARK_ENABLED"):
         if app_config.get("HOSTED_SPARK_ENABLED"):
             quotas = [FreeHostingQuota()]
             quotas = [FreeHostingQuota()]
@@ -214,7 +218,8 @@ class HostingConfiguration:
             quota_unit=quota_unit,
             quota_unit=quota_unit,
         )
         )
 
 
-    def init_zhipuai(self, app_config: Config) -> HostingProvider:
+    @staticmethod
+    def init_zhipuai(app_config: Config) -> HostingProvider:
         quota_unit = QuotaUnit.TOKENS
         quota_unit = QuotaUnit.TOKENS
         if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
         if app_config.get("HOSTED_ZHIPUAI_ENABLED"):
             quotas = [FreeHostingQuota()]
             quotas = [FreeHostingQuota()]
@@ -231,7 +236,8 @@ class HostingConfiguration:
             quota_unit=quota_unit,
             quota_unit=quota_unit,
         )
         )
 
 
-    def init_moderation_config(self, app_config: Config) -> HostedModerationConfig:
+    @staticmethod
+    def init_moderation_config(app_config: Config) -> HostedModerationConfig:
         if app_config.get("HOSTED_MODERATION_ENABLED") \
         if app_config.get("HOSTED_MODERATION_ENABLED") \
                 and app_config.get("HOSTED_MODERATION_PROVIDERS"):
                 and app_config.get("HOSTED_MODERATION_PROVIDERS"):
             return HostedModerationConfig(
             return HostedModerationConfig(

+ 18 - 9
api/core/indexing_runner.py

@@ -411,7 +411,8 @@ class IndexingRunner:
 
 
         return text_docs
         return text_docs
 
 
-    def filter_string(self, text):
+    @staticmethod
+    def filter_string(text):
         text = re.sub(r'<\|', '<', text)
         text = re.sub(r'<\|', '<', text)
         text = re.sub(r'\|>', '>', text)
         text = re.sub(r'\|>', '>', text)
         text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
         text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
@@ -419,7 +420,8 @@ class IndexingRunner:
         text = re.sub('\uFFFE', '', text)
         text = re.sub('\uFFFE', '', text)
         return text
         return text
 
 
-    def _get_splitter(self, processing_rule: DatasetProcessRule,
+    @staticmethod
+    def _get_splitter(processing_rule: DatasetProcessRule,
                       embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
                       embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
         """
         """
         Get the NodeParser object according to the processing rule.
         Get the NodeParser object according to the processing rule.
@@ -611,7 +613,8 @@ class IndexingRunner:
 
 
         return all_documents
         return all_documents
 
 
-    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
+    @staticmethod
+    def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str:
         """
         """
         Clean the document text according to the processing rules.
         Clean the document text according to the processing rules.
         """
         """
@@ -640,7 +643,8 @@ class IndexingRunner:
 
 
         return text
         return text
 
 
-    def format_split_text(self, text):
+    @staticmethod
+    def format_split_text(text):
         regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
         regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
         matches = re.findall(regex, text, re.UNICODE)
         matches = re.findall(regex, text, re.UNICODE)
 
 
@@ -704,7 +708,8 @@ class IndexingRunner:
             }
             }
         )
         )
 
 
-    def _process_keyword_index(self, flask_app, dataset_id, document_id, documents):
+    @staticmethod
+    def _process_keyword_index(flask_app, dataset_id, document_id, documents):
         with flask_app.app_context():
         with flask_app.app_context():
             dataset = Dataset.query.filter_by(id=dataset_id).first()
             dataset = Dataset.query.filter_by(id=dataset_id).first()
             if not dataset:
             if not dataset:
@@ -758,13 +763,15 @@ class IndexingRunner:
 
 
             return tokens
             return tokens
 
 
-    def _check_document_paused_status(self, document_id: str):
+    @staticmethod
+    def _check_document_paused_status(document_id: str):
         indexing_cache_key = 'document_{}_is_paused'.format(document_id)
         indexing_cache_key = 'document_{}_is_paused'.format(document_id)
         result = redis_client.get(indexing_cache_key)
         result = redis_client.get(indexing_cache_key)
         if result:
         if result:
             raise DocumentIsPausedException()
             raise DocumentIsPausedException()
 
 
-    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
+    @staticmethod
+    def _update_document_index_status(document_id: str, after_indexing_status: str,
                                       extra_update_params: Optional[dict] = None) -> None:
                                       extra_update_params: Optional[dict] = None) -> None:
         """
         """
         Update the document indexing status.
         Update the document indexing status.
@@ -786,14 +793,16 @@ class IndexingRunner:
         DatasetDocument.query.filter_by(id=document_id).update(update_params)
         DatasetDocument.query.filter_by(id=document_id).update(update_params)
         db.session.commit()
         db.session.commit()
 
 
-    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
+    @staticmethod
+    def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
         """
         """
         Update the document segment by document id.
         Update the document segment by document id.
         """
         """
         DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()
         db.session.commit()
 
 
-    def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
+    @staticmethod
+    def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
         """
         """
         Batch add segments index processing
         Batch add segments index processing
         """
         """

+ 6 - 4
api/core/model_manager.py

@@ -44,7 +44,8 @@ class ModelInstance:
             credentials=self.credentials
             credentials=self.credentials
         )
         )
 
 
-    def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
+    @staticmethod
+    def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict:
         """
         """
         Fetch credentials from provider model bundle
         Fetch credentials from provider model bundle
         :param provider_model_bundle: provider model bundle
         :param provider_model_bundle: provider model bundle
@@ -63,7 +64,8 @@ class ModelInstance:
 
 
         return credentials
         return credentials
 
 
-    def _get_load_balancing_manager(self, configuration: ProviderConfiguration,
+    @staticmethod
+    def _get_load_balancing_manager(configuration: ProviderConfiguration,
                                     model_type: ModelType,
                                     model_type: ModelType,
                                     model: str,
                                     model: str,
                                     credentials: dict) -> Optional["LBModelManager"]:
                                     credentials: dict) -> Optional["LBModelManager"]:
@@ -515,8 +517,8 @@ class LBModelManager:
         res = cast(bool, res)
         res = cast(bool, res)
         return res
         return res
 
 
-    @classmethod
-    def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
+    @staticmethod
+    def get_config_in_cooldown_and_ttl(tenant_id: str,
                                        provider: str,
                                        provider: str,
                                        model_type: ModelType,
                                        model_type: ModelType,
                                        model: str,
                                        model: str,

+ 16 - 8
api/core/provider_manager.py

@@ -350,7 +350,8 @@ class ProviderManager:
 
 
         return default_model
         return default_model
 
 
-    def _get_all_providers(self, tenant_id: str) -> dict[str, list[Provider]]:
+    @staticmethod
+    def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
         """
         """
         Get all provider records of the workspace.
         Get all provider records of the workspace.
 
 
@@ -369,7 +370,8 @@ class ProviderManager:
 
 
         return provider_name_to_provider_records_dict
         return provider_name_to_provider_records_dict
 
 
-    def _get_all_provider_models(self, tenant_id: str) -> dict[str, list[ProviderModel]]:
+    @staticmethod
+    def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]:
         """
         """
         Get all provider model records of the workspace.
         Get all provider model records of the workspace.
 
 
@@ -389,7 +391,8 @@ class ProviderManager:
 
 
         return provider_name_to_provider_model_records_dict
         return provider_name_to_provider_model_records_dict
 
 
-    def _get_all_preferred_model_providers(self, tenant_id: str) -> dict[str, TenantPreferredModelProvider]:
+    @staticmethod
+    def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]:
         """
         """
         Get All preferred provider types of the workspace.
         Get All preferred provider types of the workspace.
 
 
@@ -408,7 +411,8 @@ class ProviderManager:
 
 
         return provider_name_to_preferred_provider_type_records_dict
         return provider_name_to_preferred_provider_type_records_dict
 
 
-    def _get_all_provider_model_settings(self, tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
+    @staticmethod
+    def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]:
         """
         """
         Get All provider model settings of the workspace.
         Get All provider model settings of the workspace.
 
 
@@ -427,7 +431,8 @@ class ProviderManager:
 
 
         return provider_name_to_provider_model_settings_dict
         return provider_name_to_provider_model_settings_dict
 
 
-    def _get_all_provider_load_balancing_configs(self, tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
+    @staticmethod
+    def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]:
         """
         """
         Get All provider load balancing configs of the workspace.
         Get All provider load balancing configs of the workspace.
 
 
@@ -458,7 +463,8 @@ class ProviderManager:
 
 
         return provider_name_to_provider_load_balancing_model_configs_dict
         return provider_name_to_provider_load_balancing_model_configs_dict
 
 
-    def _init_trial_provider_records(self, tenant_id: str,
+    @staticmethod
+    def _init_trial_provider_records(tenant_id: str,
                                      provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
                                      provider_name_to_provider_records_dict: dict[str, list]) -> dict[str, list]:
         """
         """
         Initialize trial provider records if not exists.
         Initialize trial provider records if not exists.
@@ -791,7 +797,8 @@ class ProviderManager:
             credentials=current_using_credentials
             credentials=current_using_credentials
         )
         )
 
 
-    def _choice_current_using_quota_type(self, quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType:
+    @staticmethod
+    def _choice_current_using_quota_type(quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType:
         """
         """
         Choice current using quota type.
         Choice current using quota type.
         paid quotas > provider free quotas > hosting trial quotas
         paid quotas > provider free quotas > hosting trial quotas
@@ -818,7 +825,8 @@ class ProviderManager:
 
 
         raise ValueError('No quota type available')
         raise ValueError('No quota type available')
 
 
-    def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
+    @staticmethod
+    def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
         """
         """
         Extract secret input form variables.
         Extract secret input form variables.