app_factory.py 6.2 KB


  1. import os
  2. if os.environ.get("DEBUG", "false").lower() != "true":
  3. from gevent import monkey
  4. monkey.patch_all()
  5. import grpc.experimental.gevent
  6. grpc.experimental.gevent.init_gevent()
  7. import json
  8. import logging
  9. import sys
  10. from logging.handlers import RotatingFileHandler
  11. from flask import Flask, Response, request
  12. from flask_cors import CORS
  13. from werkzeug.exceptions import Unauthorized
  14. import contexts
  15. from commands import register_commands
  16. from configs import dify_config
  17. from extensions import (
  18. ext_celery,
  19. ext_code_based_extension,
  20. ext_compress,
  21. ext_database,
  22. ext_hosting_provider,
  23. ext_login,
  24. ext_mail,
  25. ext_migrate,
  26. ext_proxy_fix,
  27. ext_redis,
  28. ext_sentry,
  29. ext_storage,
  30. )
  31. from extensions.ext_database import db
  32. from extensions.ext_login import login_manager
  33. from libs.passport import PassportService
  34. from services.account_service import AccountService
  class DifyApp(Flask):
      """Project-specific Flask application type; adds no behavior on top of ``Flask``."""
  37. # ----------------------------
  38. # Application Factory Function
  39. # ----------------------------
  def create_flask_app_with_configs() -> Flask:
      """Create a bare application whose config is populated from ``dify_config``.

      After loading the configuration mapping, every config entry whose value is
      a string, a number, a boolean or ``None`` is mirrored into ``os.environ``
      (``None`` becomes an empty string). Entries of any other type are not
      exported.

      Returns:
          The newly created ``DifyApp`` instance, without extensions or
          blueprints attached.
      """
      flask_app = DifyApp(__name__)
      flask_app.config.from_mapping(dify_config.model_dump())

      # Mirror simple config values into the process environment.
      for name, setting in flask_app.config.items():
          if setting is None:
              os.environ[name] = ""
          elif isinstance(setting, str):
              os.environ[name] = setting
          elif isinstance(setting, (bool, int, float)):
              os.environ[name] = str(setting)

      return flask_app
  def create_app() -> Flask:
      """Build the fully wired application instance.

      Steps, in order:
        1. create the base app and set its secret key from ``SECRET_KEY``;
        2. configure root logging from ``LOG_LEVEL`` / ``LOG_FORMAT`` /
           ``LOG_DATEFORMAT``; when ``LOG_FILE`` is set, log to a rotating file
           (1 GiB per file, 5 backups) and to stdout, otherwise
           ``logging.basicConfig`` installs its default handler;
        3. when ``LOG_TZ`` is set, render log timestamps in that time zone;
        4. initialize extensions, register blueprints and CLI commands.

      Returns:
          The configured Flask application.

      Raises:
          KeyError: if ``SECRET_KEY`` is missing from the configuration.
      """
      app = create_flask_app_with_configs()
      app.secret_key = app.config["SECRET_KEY"]

      log_handlers = None
      log_file = app.config.get("LOG_FILE")
      if log_file:
          log_dir = os.path.dirname(log_file)
          # A bare file name (e.g. "app.log") has an empty directory part and
          # os.makedirs("") raises FileNotFoundError, so only create real paths.
          if log_dir:
              os.makedirs(log_dir, exist_ok=True)
          log_handlers = [
              RotatingFileHandler(
                  filename=log_file,
                  maxBytes=1024 * 1024 * 1024,
                  backupCount=5,
              ),
              logging.StreamHandler(sys.stdout),
          ]

      # force=True replaces any handlers already attached to the root logger.
      logging.basicConfig(
          level=app.config.get("LOG_LEVEL"),
          format=app.config.get("LOG_FORMAT"),
          datefmt=app.config.get("LOG_DATEFORMAT"),
          handlers=log_handlers,
          force=True,
      )

      log_tz = app.config.get("LOG_TZ")
      if log_tz:
          from datetime import datetime

          import pytz

          timezone = pytz.timezone(log_tz)

          def time_converter(seconds):
              # Convert the epoch timestamp straight into an aware datetime in
              # the target zone. Going through a naive UTC datetime and then
              # astimezone() would treat the naive value as system-local time
              # and shift the result on hosts whose zone is not UTC.
              return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

          for handler in logging.root.handlers:
              handler.formatter.converter = time_converter

      initialize_extensions(app)
      register_blueprints(app)
      register_commands(app)

      return app
  def initialize_extensions(app):
      """Attach every extension to the given application, in a fixed order.

      Args:
          app: The Flask application instance the extensions are bound to.
      """
      ext_compress.init_app(app)
      # These extensions use their own initialization signatures.
      ext_code_based_extension.init()
      ext_database.init_app(app)
      ext_migrate.init(app, db)

      # The remaining extensions share the ``init_app(app)`` entry point;
      # the tuple order is the initialization order.
      remaining_extensions = (
          ext_redis,
          ext_storage,
          ext_celery,
          ext_login,
          ext_mail,
          ext_hosting_provider,
          ext_sentry,
          ext_proxy_fix,
      )
      for extension in remaining_extensions:
          extension.init_app(app)
  107. # Flask-Login configuration
  @login_manager.request_loader
  def load_user_from_request(request_from_flask_login):
      """Resolve the logged-in account for the current request.

      Only requests routed to the ``console`` and ``inner_api`` blueprints are
      authenticated here; for any other blueprint ``None`` is returned.

      The token is taken from the ``Authorization: Bearer <token>`` header, or,
      when that header is absent, from the ``_token`` query parameter. It is
      verified with ``PassportService`` and the ``user_id`` of the decoded
      payload is used to load the account. When an account is found, its
      current tenant id is stored in ``contexts.tenant_id``.

      Args:
          request_from_flask_login: Request object supplied by Flask-Login.
              Unused; the Flask ``request`` proxy is read instead.

      Returns:
          The value returned by ``AccountService.load_logged_in_account``, or
          ``None`` for blueprints that are not handled here.

      Raises:
          Unauthorized: if no token is supplied or the Authorization header is
              not of the form ``Bearer <token>``.
      """
      if request.blueprint not in {"console", "inner_api"}:
          return None

      auth_header = request.headers.get("Authorization", "")
      if not auth_header:
          auth_token = request.args.get("_token")
          if not auth_token:
              raise Unauthorized("Invalid Authorization token.")
      else:
          # Split once on whitespace and require both a scheme and a token.
          # Checking the part count (instead of merely looking for a space)
          # also rejects headers such as "Bearer " that would otherwise fail
          # tuple unpacking with a ValueError.
          header_parts = auth_header.split(None, 1)
          if len(header_parts) != 2:
              raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
          auth_scheme, auth_token = header_parts
          auth_scheme = auth_scheme.lower()
          if auth_scheme != "bearer":
              raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

      decoded = PassportService().verify(auth_token)
      user_id = decoded.get("user_id")

      logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
      if logged_in_account:
          contexts.tenant_id.set(logged_in_account.current_tenant_id)
      return logged_in_account
  @login_manager.unauthorized_handler
  def unauthorized_handler():
      """Build the JSON 401 response sent when a request is not authorized."""
      payload = json.dumps({"code": "unauthorized", "message": "Unauthorized."})
      return Response(payload, status=401, content_type="application/json")
  140. # register blueprint routers
  def register_blueprints(app):
      """Apply CORS settings to the API blueprints and register them on ``app``.

      The service API, web, console and files blueprints each get their own CORS
      configuration before being registered; the inner API blueprint is
      registered without CORS. Allowed origins for the web and console
      blueprints come from ``WEB_API_CORS_ALLOW_ORIGINS`` and
      ``CONSOLE_CORS_ALLOW_ORIGINS`` in the app config.

      Args:
          app: The Flask application the blueprints are registered on.
      """
      # Imported here rather than at module level, as in the original code.
      from controllers.console import bp as console_app_bp
      from controllers.files import bp as files_bp
      from controllers.inner_api import bp as inner_api_bp
      from controllers.service_api import bp as service_api_bp
      from controllers.web import bp as web_bp

      # Shared option values; each CORS call receives its own list copy.
      http_methods = ("GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH")
      app_headers = ("Content-Type", "Authorization", "X-App-Code")
      exposed_headers = ("X-Version", "X-Env")

      # Service API
      CORS(
          service_api_bp,
          allow_headers=list(app_headers),
          methods=list(http_methods),
      )
      app.register_blueprint(service_api_bp)

      # Web app API
      web_origins = app.config["WEB_API_CORS_ALLOW_ORIGINS"]
      CORS(
          web_bp,
          resources={r"/*": {"origins": web_origins}},
          supports_credentials=True,
          allow_headers=list(app_headers),
          methods=list(http_methods),
          expose_headers=list(exposed_headers),
      )
      app.register_blueprint(web_bp)

      # Console API
      console_origins = app.config["CONSOLE_CORS_ALLOW_ORIGINS"]
      CORS(
          console_app_bp,
          resources={r"/*": {"origins": console_origins}},
          supports_credentials=True,
          allow_headers=["Content-Type", "Authorization"],
          methods=list(http_methods),
          expose_headers=list(exposed_headers),
      )
      app.register_blueprint(console_app_bp)

      # Files API
      CORS(files_bp, allow_headers=["Content-Type"], methods=list(http_methods))
      app.register_blueprint(files_bp)

      # Inner API (no CORS configuration)
      app.register_blueprint(inner_api_bp)