Sergio Sacristán преди 6 месеца
родител
ревизия
28de676956
променени са 4 файла, в които са добавени 249 реда и са изтрити 203 реда
  1. 2 203
      api/app.py
  2. 213 0
      api/app_factory.py
  3. 24 0
      api/tests/integration_tests/controllers/app_fixture.py
  4. 10 0
      api/tests/integration_tests/controllers/test_controllers.py

+ 2 - 203
api/app.py

@@ -10,44 +10,19 @@ if os.environ.get("DEBUG", "false").lower() != "true":
     grpc.experimental.gevent.init_gevent()
     grpc.experimental.gevent.init_gevent()
 
 
 import json
 import json
-import logging
-import sys
 import threading
 import threading
 import time
 import time
 import warnings
 import warnings
-from logging.handlers import RotatingFileHandler
 
 
-from flask import Flask, Response, request
-from flask_cors import CORS
-from werkzeug.exceptions import Unauthorized
+from flask import Response
 
 
-import contexts
-from commands import register_commands
-from configs import dify_config
+from app_factory import create_app
 
 
 # DO NOT REMOVE BELOW
 # DO NOT REMOVE BELOW
 from events import event_handlers  # noqa: F401
 from events import event_handlers  # noqa: F401
-from extensions import (
-    ext_celery,
-    ext_code_based_extension,
-    ext_compress,
-    ext_database,
-    ext_hosting_provider,
-    ext_login,
-    ext_mail,
-    ext_migrate,
-    ext_proxy_fix,
-    ext_redis,
-    ext_sentry,
-    ext_storage,
-)
-from extensions.ext_database import db
-from extensions.ext_login import login_manager
-from libs.passport import PassportService
 
 
 # TODO: Find a way to avoid importing models here
 # TODO: Find a way to avoid importing models here
 from models import account, dataset, model, source, task, tool, tools, web  # noqa: F401
 from models import account, dataset, model, source, task, tool, tools, web  # noqa: F401
-from services.account_service import AccountService
 
 
 # DO NOT REMOVE ABOVE
 # DO NOT REMOVE ABOVE
 
 
@@ -60,188 +35,12 @@ if hasattr(time, "tzset"):
     time.tzset()
     time.tzset()
 
 
 
 
-class DifyApp(Flask):
-    pass
-
-
 # -------------
 # -------------
 # Configuration
 # Configuration
 # -------------
 # -------------
-
-
 config_type = os.getenv("EDITION", default="SELF_HOSTED")  # ce edition first
 config_type = os.getenv("EDITION", default="SELF_HOSTED")  # ce edition first
 
 
 
 
-# ----------------------------
-# Application Factory Function
-# ----------------------------
-
-
-def create_flask_app_with_configs() -> Flask:
-    """
-    create a raw flask app
-    with configs loaded from .env file
-    """
-    dify_app = DifyApp(__name__)
-    dify_app.config.from_mapping(dify_config.model_dump())
-
-    # populate configs into system environment variables
-    for key, value in dify_app.config.items():
-        if isinstance(value, str):
-            os.environ[key] = value
-        elif isinstance(value, int | float | bool):
-            os.environ[key] = str(value)
-        elif value is None:
-            os.environ[key] = ""
-
-    return dify_app
-
-
-def create_app() -> Flask:
-    app = create_flask_app_with_configs()
-
-    app.secret_key = app.config["SECRET_KEY"]
-
-    log_handlers = None
-    log_file = app.config.get("LOG_FILE")
-    if log_file:
-        log_dir = os.path.dirname(log_file)
-        os.makedirs(log_dir, exist_ok=True)
-        log_handlers = [
-            RotatingFileHandler(
-                filename=log_file,
-                maxBytes=1024 * 1024 * 1024,
-                backupCount=5,
-            ),
-            logging.StreamHandler(sys.stdout),
-        ]
-
-    logging.basicConfig(
-        level=app.config.get("LOG_LEVEL"),
-        format=app.config.get("LOG_FORMAT"),
-        datefmt=app.config.get("LOG_DATEFORMAT"),
-        handlers=log_handlers,
-        force=True,
-    )
-    log_tz = app.config.get("LOG_TZ")
-    if log_tz:
-        from datetime import datetime
-
-        import pytz
-
-        timezone = pytz.timezone(log_tz)
-
-        def time_converter(seconds):
-            return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
-
-        for handler in logging.root.handlers:
-            handler.formatter.converter = time_converter
-    initialize_extensions(app)
-    register_blueprints(app)
-    register_commands(app)
-
-    return app
-
-
-def initialize_extensions(app):
-    # Since the application instance is now created, pass it to each Flask
-    # extension instance to bind it to the Flask application instance (app)
-    ext_compress.init_app(app)
-    ext_code_based_extension.init()
-    ext_database.init_app(app)
-    ext_migrate.init(app, db)
-    ext_redis.init_app(app)
-    ext_storage.init_app(app)
-    ext_celery.init_app(app)
-    ext_login.init_app(app)
-    ext_mail.init_app(app)
-    ext_hosting_provider.init_app(app)
-    ext_sentry.init_app(app)
-    ext_proxy_fix.init_app(app)
-
-
-# Flask-Login configuration
-@login_manager.request_loader
-def load_user_from_request(request_from_flask_login):
-    """Load user based on the request."""
-    if request.blueprint not in {"console", "inner_api"}:
-        return None
-    # Check if the user_id contains a dot, indicating the old format
-    auth_header = request.headers.get("Authorization", "")
-    if not auth_header:
-        auth_token = request.args.get("_token")
-        if not auth_token:
-            raise Unauthorized("Invalid Authorization token.")
-    else:
-        if " " not in auth_header:
-            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
-        auth_scheme, auth_token = auth_header.split(None, 1)
-        auth_scheme = auth_scheme.lower()
-        if auth_scheme != "bearer":
-            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
-
-    decoded = PassportService().verify(auth_token)
-    user_id = decoded.get("user_id")
-
-    logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
-    if logged_in_account:
-        contexts.tenant_id.set(logged_in_account.current_tenant_id)
-    return logged_in_account
-
-
-@login_manager.unauthorized_handler
-def unauthorized_handler():
-    """Handle unauthorized requests."""
-    return Response(
-        json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
-        status=401,
-        content_type="application/json",
-    )
-
-
-# register blueprint routers
-def register_blueprints(app):
-    from controllers.console import bp as console_app_bp
-    from controllers.files import bp as files_bp
-    from controllers.inner_api import bp as inner_api_bp
-    from controllers.service_api import bp as service_api_bp
-    from controllers.web import bp as web_bp
-
-    CORS(
-        service_api_bp,
-        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
-        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
-    )
-    app.register_blueprint(service_api_bp)
-
-    CORS(
-        web_bp,
-        resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
-        supports_credentials=True,
-        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
-        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
-        expose_headers=["X-Version", "X-Env"],
-    )
-
-    app.register_blueprint(web_bp)
-
-    CORS(
-        console_app_bp,
-        resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
-        supports_credentials=True,
-        allow_headers=["Content-Type", "Authorization"],
-        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
-        expose_headers=["X-Version", "X-Env"],
-    )
-
-    app.register_blueprint(console_app_bp)
-
-    CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
-    app.register_blueprint(files_bp)
-
-    app.register_blueprint(inner_api_bp)
-
-
 # create app
 # create app
 app = create_app()
 app = create_app()
 celery = app.extensions["celery"]
 celery = app.extensions["celery"]

+ 213 - 0
api/app_factory.py

@@ -0,0 +1,213 @@
+import os
+
+if os.environ.get("DEBUG", "false").lower() != "true":
+    from gevent import monkey
+
+    monkey.patch_all()
+
+    import grpc.experimental.gevent
+
+    grpc.experimental.gevent.init_gevent()
+
+import json
+import logging
+import sys
+from logging.handlers import RotatingFileHandler
+
+from flask import Flask, Response, request
+from flask_cors import CORS
+from werkzeug.exceptions import Unauthorized
+
+import contexts
+from commands import register_commands
+from configs import dify_config
+from extensions import (
+    ext_celery,
+    ext_code_based_extension,
+    ext_compress,
+    ext_database,
+    ext_hosting_provider,
+    ext_login,
+    ext_mail,
+    ext_migrate,
+    ext_proxy_fix,
+    ext_redis,
+    ext_sentry,
+    ext_storage,
+)
+from extensions.ext_database import db
+from extensions.ext_login import login_manager
+from libs.passport import PassportService
+from services.account_service import AccountService
+
+
+class DifyApp(Flask):
+    pass
+
+
+# ----------------------------
+# Application Factory Function
+# ----------------------------
+def create_flask_app_with_configs() -> Flask:
+    """
+    create a raw flask app
+    with configs loaded from .env file
+    """
+    dify_app = DifyApp(__name__)
+    dify_app.config.from_mapping(dify_config.model_dump())
+
+    # populate configs into system environment variables
+    for key, value in dify_app.config.items():
+        if isinstance(value, str):
+            os.environ[key] = value
+        elif isinstance(value, int | float | bool):
+            os.environ[key] = str(value)
+        elif value is None:
+            os.environ[key] = ""
+
+    return dify_app
+
+
+def create_app() -> Flask:
+    app = create_flask_app_with_configs()
+
+    app.secret_key = app.config["SECRET_KEY"]
+
+    log_handlers = None
+    log_file = app.config.get("LOG_FILE")
+    if log_file:
+        log_dir = os.path.dirname(log_file)
+        os.makedirs(log_dir, exist_ok=True)
+        log_handlers = [
+            RotatingFileHandler(
+                filename=log_file,
+                maxBytes=1024 * 1024 * 1024,
+                backupCount=5,
+            ),
+            logging.StreamHandler(sys.stdout),
+        ]
+
+    logging.basicConfig(
+        level=app.config.get("LOG_LEVEL"),
+        format=app.config.get("LOG_FORMAT"),
+        datefmt=app.config.get("LOG_DATEFORMAT"),
+        handlers=log_handlers,
+        force=True,
+    )
+    log_tz = app.config.get("LOG_TZ")
+    if log_tz:
+        from datetime import datetime
+
+        import pytz
+
+        timezone = pytz.timezone(log_tz)
+
+        def time_converter(seconds):
+            return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
+
+        for handler in logging.root.handlers:
+            handler.formatter.converter = time_converter
+    initialize_extensions(app)
+    register_blueprints(app)
+    register_commands(app)
+
+    return app
+
+
+def initialize_extensions(app):
+    # Since the application instance is now created, pass it to each Flask
+    # extension instance to bind it to the Flask application instance (app)
+    ext_compress.init_app(app)
+    ext_code_based_extension.init()
+    ext_database.init_app(app)
+    ext_migrate.init(app, db)
+    ext_redis.init_app(app)
+    ext_storage.init_app(app)
+    ext_celery.init_app(app)
+    ext_login.init_app(app)
+    ext_mail.init_app(app)
+    ext_hosting_provider.init_app(app)
+    ext_sentry.init_app(app)
+    ext_proxy_fix.init_app(app)
+
+
+# Flask-Login configuration
+@login_manager.request_loader
+def load_user_from_request(request_from_flask_login):
+    """Load user based on the request."""
+    if request.blueprint not in {"console", "inner_api"}:
+        return None
+    # Check if the user_id contains a dot, indicating the old format
+    auth_header = request.headers.get("Authorization", "")
+    if not auth_header:
+        auth_token = request.args.get("_token")
+        if not auth_token:
+            raise Unauthorized("Invalid Authorization token.")
+    else:
+        if " " not in auth_header:
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+        auth_scheme, auth_token = auth_header.split(None, 1)
+        auth_scheme = auth_scheme.lower()
+        if auth_scheme != "bearer":
+            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+
+    decoded = PassportService().verify(auth_token)
+    user_id = decoded.get("user_id")
+
+    logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
+    if logged_in_account:
+        contexts.tenant_id.set(logged_in_account.current_tenant_id)
+    return logged_in_account
+
+
+@login_manager.unauthorized_handler
+def unauthorized_handler():
+    """Handle unauthorized requests."""
+    return Response(
+        json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
+        status=401,
+        content_type="application/json",
+    )
+
+
+# register blueprint routers
+def register_blueprints(app):
+    from controllers.console import bp as console_app_bp
+    from controllers.files import bp as files_bp
+    from controllers.inner_api import bp as inner_api_bp
+    from controllers.service_api import bp as service_api_bp
+    from controllers.web import bp as web_bp
+
+    CORS(
+        service_api_bp,
+        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+    )
+    app.register_blueprint(service_api_bp)
+
+    CORS(
+        web_bp,
+        resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
+        supports_credentials=True,
+        allow_headers=["Content-Type", "Authorization", "X-App-Code"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        expose_headers=["X-Version", "X-Env"],
+    )
+
+    app.register_blueprint(web_bp)
+
+    CORS(
+        console_app_bp,
+        resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
+        supports_credentials=True,
+        allow_headers=["Content-Type", "Authorization"],
+        methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
+        expose_headers=["X-Version", "X-Env"],
+    )
+
+    app.register_blueprint(console_app_bp)
+
+    CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
+    app.register_blueprint(files_bp)
+
+    app.register_blueprint(inner_api_bp)

+ 24 - 0
api/tests/integration_tests/controllers/app_fixture.py

@@ -0,0 +1,24 @@
+import pytest
+
+from app_factory import create_app
+
+mock_user = type(
+    "MockUser",
+    (object,),
+    {
+        "is_authenticated": True,
+        "id": "123",
+        "is_editor": True,
+        "is_dataset_editor": True,
+        "status": "active",
+        "get_id": "123",
+        "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b",
+    },
+)
+
+
+@pytest.fixture
+def app():
+    app = create_app()
+    app.config["LOGIN_DISABLED"] = True
+    return app

+ 10 - 0
api/tests/integration_tests/controllers/test_controllers.py

@@ -0,0 +1,10 @@
+from unittest.mock import patch
+
+from app_fixture import app, mock_user
+
+
+def test_post_requires_login(app):
+    with app.test_client() as client:
+        with patch("flask_login.utils._get_user", mock_user):
+            response = client.get("/console/api/data-source/integrates")
+            assert response.status_code == 200