Просмотр исходного кода

refactor: introduce storage factory and speed up api startup by importing storage client on demand (#9086)

Bowen Liang, 6 месяцев назад
Родитель
Коммит
b360feb4c1

+ 49 - 31
api/extensions/ext_storage.py

@@ -4,16 +4,9 @@ from typing import Union
 
 from flask import Flask
 
-from extensions.storage.aliyun_storage import AliyunStorage
-from extensions.storage.azure_storage import AzureStorage
-from extensions.storage.baidu_storage import BaiduStorage
-from extensions.storage.google_storage import GoogleStorage
-from extensions.storage.huawei_storage import HuaweiStorage
-from extensions.storage.local_storage import LocalStorage
-from extensions.storage.oci_storage import OCIStorage
-from extensions.storage.s3_storage import S3Storage
-from extensions.storage.tencent_storage import TencentStorage
-from extensions.storage.volcengine_storage import VolcengineStorage
+from configs import dify_config
+from extensions.storage.base_storage import BaseStorage
+from extensions.storage.storage_type import StorageType
 
 
 class Storage:
@@ -21,27 +14,52 @@ class Storage:
         self.storage_runner = None
 
     def init_app(self, app: Flask):
-        storage_type = app.config.get("STORAGE_TYPE")
-        if storage_type == "s3":
-            self.storage_runner = S3Storage(app=app)
-        elif storage_type == "azure-blob":
-            self.storage_runner = AzureStorage(app=app)
-        elif storage_type == "aliyun-oss":
-            self.storage_runner = AliyunStorage(app=app)
-        elif storage_type == "google-storage":
-            self.storage_runner = GoogleStorage(app=app)
-        elif storage_type == "tencent-cos":
-            self.storage_runner = TencentStorage(app=app)
-        elif storage_type == "oci-storage":
-            self.storage_runner = OCIStorage(app=app)
-        elif storage_type == "huawei-obs":
-            self.storage_runner = HuaweiStorage(app=app)
-        elif storage_type == "baidu-obs":
-            self.storage_runner = BaiduStorage(app=app)
-        elif storage_type == "volcengine-tos":
-            self.storage_runner = VolcengineStorage(app=app)
-        else:
-            self.storage_runner = LocalStorage(app=app)
+        storage_factory = self.get_storage_factory(dify_config.STORAGE_TYPE)
+        self.storage_runner = storage_factory(app=app)
+
+    @staticmethod
+    def get_storage_factory(storage_type: str) -> type[BaseStorage]:
+        match storage_type:
+            case StorageType.S3:
+                from extensions.storage.aws_s3_storage import AwsS3Storage
+
+                return AwsS3Storage
+            case StorageType.AZURE_BLOB:
+                from extensions.storage.azure_blob_storage import AzureBlobStorage
+
+                return AzureBlobStorage
+            case StorageType.ALIYUN_OSS:
+                from extensions.storage.aliyun_oss_storage import AliyunOssStorage
+
+                return AliyunOssStorage
+            case StorageType.GOOGLE_STORAGE:
+                from extensions.storage.google_cloud_storage import GoogleCloudStorage
+
+                return GoogleCloudStorage
+            case StorageType.TENCENT_COS:
+                from extensions.storage.tencent_cos_storage import TencentCosStorage
+
+                return TencentCosStorage
+            case StorageType.OCI_STORAGE:
+                from extensions.storage.oracle_oci_storage import OracleOCIStorage
+
+                return OracleOCIStorage
+            case StorageType.HUAWEI_OBS:
+                from extensions.storage.huawei_obs_storage import HuaweiObsStorage
+
+                return HuaweiObsStorage
+            case StorageType.BAIDU_OBS:
+                from extensions.storage.baidu_obs_storage import BaiduObsStorage
+
+                return BaiduObsStorage
+            case StorageType.VOLCENGINE_TOS:
+                from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
+
+                return VolcengineTosStorage
+            case StorageType.LOCAL | _:
+                from extensions.storage.local_fs_storage import LocalFsStorage
+
+                return LocalFsStorage
 
     def save(self, filename, data):
         try:

+ 2 - 2
api/extensions/storage/aliyun_storage.py

@@ -7,8 +7,8 @@ from flask import Flask
 from extensions.storage.base_storage import BaseStorage
 
 
-class AliyunStorage(BaseStorage):
-    """Implementation for aliyun storage."""
+class AliyunOssStorage(BaseStorage):
+    """Implementation for Aliyun OSS storage."""
 
     def __init__(self, app: Flask):
         super().__init__(app)

+ 2 - 2
api/extensions/storage/s3_storage.py

@@ -9,8 +9,8 @@ from flask import Flask
 from extensions.storage.base_storage import BaseStorage
 
 
-class S3Storage(BaseStorage):
-    """Implementation for s3 storage."""
+class AwsS3Storage(BaseStorage):
+    """Implementation for Amazon Web Services S3 storage."""
 
     def __init__(self, app: Flask):
         super().__init__(app)

+ 2 - 2
api/extensions/storage/azure_storage.py

@@ -8,8 +8,8 @@ from extensions.ext_redis import redis_client
 from extensions.storage.base_storage import BaseStorage
 
 
-class AzureStorage(BaseStorage):
-    """Implementation for azure storage."""
+class AzureBlobStorage(BaseStorage):
+    """Implementation for Azure Blob storage."""
 
     def __init__(self, app: Flask):
         super().__init__(app)

+ 2 - 2
api/extensions/storage/baidu_storage.py

@@ -10,8 +10,8 @@ from flask import Flask
 from extensions.storage.base_storage import BaseStorage
 
 
-class BaiduStorage(BaseStorage):
-    """Implementation for baidu obs storage."""
+class BaiduObsStorage(BaseStorage):
+    """Implementation for Baidu OBS storage."""
 
     def __init__(self, app: Flask):
         super().__init__(app)

+ 2 - 2
api/extensions/storage/google_storage.py

@@ -10,8 +10,8 @@ from google.cloud import storage as google_cloud_storage
 from extensions.storage.base_storage import BaseStorage
 
 
-class GoogleStorage(BaseStorage):
-    """Implementation for google storage."""
+class GoogleCloudStorage(BaseStorage):
+    """Implementation for Google Cloud storage."""
 
     def __init__(self, app: Flask):
         super().__init__(app)

+ 2 - 2
api/extensions/storage/huawei_storage.py

@@ -6,8 +6,8 @@ from obs import ObsClient
 from extensions.storage.base_storage import BaseStorage
 
 
-class HuaweiStorage(BaseStorage):
-    """Implementation for huawei obs storage."""
+class HuaweiObsStorage(BaseStorage):
+    """Implementation for Huawei OBS storage."""
 
     def __init__(self, app: Flask):
         super().__init__(app)

+ 2 - 2
api/extensions/storage/local_storage.py

@@ -8,8 +8,8 @@ from flask import Flask
 from extensions.storage.base_storage import BaseStorage
 
 
-class LocalStorage(BaseStorage):
-    """Implementation for local storage."""
+class LocalFsStorage(BaseStorage):
+    """Implementation for local filesystem storage."""
 
     def __init__(self, app: Flask):
         super().__init__(app)

+ 3 - 1
api/extensions/storage/oci_storage.py

@@ -8,7 +8,9 @@ from flask import Flask
 from extensions.storage.base_storage import BaseStorage
 
 
-class OCIStorage(BaseStorage):
+class OracleOCIStorage(BaseStorage):
+    """Implementation for Oracle OCI storage."""
+
     def __init__(self, app: Flask):
         super().__init__(app)
         app_config = self.app.config

+ 14 - 0
api/extensions/storage/storage_type.py

@@ -0,0 +1,14 @@
+from enum import Enum
+
+
+class StorageType(str, Enum):
+    ALIYUN_OSS = "aliyun-oss"
+    AZURE_BLOB = "azure-blob"
+    BAIDU_OBS = "baidu-obs"
+    GOOGLE_STORAGE = "google-storage"
+    HUAWEI_OBS = "huawei-obs"
+    LOCAL = "local"
+    OCI_STORAGE = "oci-storage"
+    S3 = "s3"
+    TENCENT_COS = "tencent-cos"
+    VOLCENGINE_TOS = "volcengine-tos"

+ 2 - 2
api/extensions/storage/tencent_storage.py

@@ -6,8 +6,8 @@ from qcloud_cos import CosConfig, CosS3Client
 from extensions.storage.base_storage import BaseStorage
 
 
-class TencentStorage(BaseStorage):
-    """Implementation for tencent cos storage."""
+class TencentCosStorage(BaseStorage):
+    """Implementation for Tencent Cloud COS storage."""
 
     def __init__(self, app: Flask):
         super().__init__(app)

+ 1 - 1
api/extensions/storage/volcengine_storage.py

@@ -6,7 +6,7 @@ from flask import Flask
 from extensions.storage.base_storage import BaseStorage
 
 
-class VolcengineStorage(BaseStorage):
+class VolcengineTosStorage(BaseStorage):
     """Implementation for Volcengine TOS storage."""
 
     def __init__(self, app: Flask):

+ 1 - 1
api/poetry.lock

@@ -10595,4 +10595,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "fd183812f910faf4e840267501c571db5d758ad6eb328d106ba6f79a0322a555"
+content-hash = "34ba8efcc67da342036ef075b693f59fdc67d246f40b857c9c1bd6f80c7283bd"

+ 1 - 1
api/pyproject.toml

@@ -113,7 +113,6 @@ authlib = "1.3.1"
 azure-ai-inference = "~1.0.0b3"
 azure-ai-ml = "~1.20.0"
 azure-identity = "1.16.1"
-azure-storage-blob = "12.13.0"
 beautifulsoup4 = "4.12.2"
 boto3 = "1.35.17"
 bs4 = "~0.0.1"
@@ -221,6 +220,7 @@ yfinance = "~0.2.40"
 # Required for storage clients
 ############################################################
 [tool.poetry.group.storage.dependencies]
+azure-storage-blob = "12.13.0"
 bce-python-sdk = "~0.9.23"
 cos-python-sdk-v5 = "1.9.30"
 esdk-obs-python = "3.24.6.1"