Browse Source

refactor: extract db configs and celery configs into dify config (#5491)

Bowen Liang 10 months ago
parent
commit
f67b164b0d

+ 0 - 44
api/config.py

@@ -3,19 +3,6 @@ import os
 import dotenv
 
 DEFAULTS = {
-    'DB_USERNAME': 'postgres',
-    'DB_PASSWORD': '',
-    'DB_HOST': 'localhost',
-    'DB_PORT': '5432',
-    'DB_DATABASE': 'dify',
-    'DB_CHARSET': '',
-    'SQLALCHEMY_DATABASE_URI_SCHEME': 'postgresql',
-    'SQLALCHEMY_POOL_SIZE': 30,
-    'SQLALCHEMY_MAX_OVERFLOW': 10,
-    'SQLALCHEMY_POOL_RECYCLE': 3600,
-    'SQLALCHEMY_POOL_PRE_PING': 'False',
-    'SQLALCHEMY_ECHO': 'False',
-    'CELERY_BACKEND': 'database',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
     'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
@@ -68,37 +55,6 @@ class Config:
             'WEB_API_CORS_ALLOW_ORIGINS', '*')
 
         # ------------------------
-        # Database Configurations.
-        # ------------------------
-        db_credentials = {
-            key: get_env(key) for key in
-            ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE', 'DB_CHARSET']
-        }
-        self.SQLALCHEMY_DATABASE_URI_SCHEME = get_env('SQLALCHEMY_DATABASE_URI_SCHEME')
-
-        db_extras = f"?client_encoding={db_credentials['DB_CHARSET']}" if db_credentials['DB_CHARSET'] else ""
-
-        self.SQLALCHEMY_DATABASE_URI = f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}{db_extras}"
-        self.SQLALCHEMY_ENGINE_OPTIONS = {
-            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
-            'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')),
-            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')),
-            'pool_pre_ping': get_bool_env('SQLALCHEMY_POOL_PRE_PING'),
-            'connect_args': {'options': '-c timezone=UTC'},
-        }
-
-        self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
-
-        # ------------------------
-        # Celery worker Configurations.
-        # ------------------------
-        self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
-        self.CELERY_BACKEND = get_env('CELERY_BACKEND')
-        self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
-            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
-        self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
-
-        # ------------------------
         # Platform Configurations.
         # ------------------------
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')

+ 108 - 2
api/configs/middleware/__init__.py

@@ -1,6 +1,6 @@
-from typing import Optional
+from typing import Any, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt, computed_field
 
 from configs.middleware.redis_config import RedisConfig
 from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
@@ -49,8 +49,114 @@ class KeywordStoreConfigs(BaseModel):
     )
 
 
+class DatabaseConfigs:
+    DB_HOST: str = Field(
+        description='db host',
+        default='localhost',
+    )
+
+    DB_PORT: PositiveInt = Field(
+        description='db port',
+        default=5432,
+    )
+
+    DB_USERNAME: str = Field(
+        description='db username',
+        default='postgres',
+    )
+
+    DB_PASSWORD: str = Field(
+        description='db password',
+        default='',
+    )
+
+    DB_DATABASE: str = Field(
+        description='db database',
+        default='dify',
+    )
+
+    DB_CHARSET: str = Field(
+        description='db charset',
+        default='',
+    )
+
+    SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
+        description='db uri scheme',
+        default='postgresql',
+    )
+
+    @computed_field
+    @property
+    def SQLALCHEMY_DATABASE_URI(self) -> str:
+        db_extras = f"?client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else ""
+        return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
+                f"{self.DB_USERNAME}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
+                f"{db_extras}")
+
+    SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
+        description='pool size of SqlAlchemy',
+        default=30,
+    )
+
+    SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
+        description='max overflows for SqlAlchemy',
+        default=10,
+    )
+
+    SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
+        description='SqlAlchemy pool recycle',
+        default=3600,
+    )
+
+    SQLALCHEMY_POOL_PRE_PING: bool = Field(
+        description='whether to enable pool pre-ping in SqlAlchemy',
+        default=False,
+    )
+
+    SQLALCHEMY_ECHO: bool = Field(
+        description='whether to enable SqlAlchemy echo',
+        default=False,
+    )
+
+    @computed_field
+    @property
+    def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
+        return {
+            'pool_size': self.SQLALCHEMY_POOL_SIZE,
+            'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW,
+            'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE,
+            'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING,
+            'connect_args': {'options': '-c timezone=UTC'},
+        }
+
+
+class CeleryConfigs(DatabaseConfigs):
+    CELERY_BACKEND: str = Field(
+        description='Celery backend, available values are `database`, `redis`',
+        default='database',
+    )
+
+    CELERY_BROKER_URL: Optional[str] = Field(
+        description='CELERY_BROKER_URL',
+        default=None,
+    )
+
+    @computed_field
+    @property
+    def CELERY_RESULT_BACKEND(self) -> str:
+        return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
+            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
+
+    @computed_field
+    @property
+    def BROKER_USE_SSL(self) -> bool:
+        return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
+
+
 class MiddlewareConfig(
     # place the configs in alphabet order
+    CeleryConfigs,
+    DatabaseConfigs,
     KeywordStoreConfigs,
     RedisConfig,
 

+ 0 - 1
api/migrations/README

@@ -1,2 +1 @@
 Single-database configuration for Flask.
-

+ 11 - 0
api/tests/unit_tests/configs/test_dify_config.py

@@ -60,3 +60,14 @@ def test_flask_configs(example_env_file):
     assert config['CONSOLE_API_URL'] == 'https://example.com'
     # fallback to alias choices value as CONSOLE_API_URL
     assert config['FILES_URL'] == 'https://example.com'
+
+    assert config['SQLALCHEMY_DATABASE_URI'] == 'postgresql://postgres:@localhost:5432/dify'
+    assert config['SQLALCHEMY_ENGINE_OPTIONS'] == {
+        'connect_args': {
+            'options': '-c timezone=UTC',
+        },
+        'max_overflow': 10,
+        'pool_pre_ping': False,
+        'pool_recycle': 3600,
+        'pool_size': 30,
+    }