浏览代码

refactor: check dependencies

Yeuoly 5 月之前
父节点
当前提交
ee38bd8817

+ 2 - 1
api/controllers/console/__init__.py

@@ -2,7 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
-from .app.app_import import AppImportApi, AppImportConfirmApi
+from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
 from .files import FileApi, FilePreviewApi, FileSupportTypeApi
 from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
 
@@ -21,6 +21,7 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
 # Import App
 api.add_resource(AppImportApi, "/apps/imports")
 api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
+api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
 
 # Import other controllers
 from . import admin, apikey, extension, feature, ping, setup, version

+ 20 - 1
api/controllers/console/app/app_import.py

@@ -5,14 +5,16 @@ from flask_restful import Resource, marshal_with, reqparse
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
+from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import (
     account_initialization_required,
     setup_required,
 )
 from extensions.ext_database import db
-from fields.app_fields import app_import_fields
+from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
 from libs.login import login_required
 from models import Account
+from models.model import App
 from services.app_dsl_service import AppDslService, ImportStatus
 
 
@@ -88,3 +90,20 @@ class AppImportConfirmApi(Resource):
         if result.status == ImportStatus.FAILED.value:
             return result.model_dump(mode="json"), 400
         return result.model_dump(mode="json"), 200
+
+
+class AppImportCheckDependenciesApi(Resource):
+    @setup_required
+    @login_required
+    @get_app_model
+    @account_initialization_required
+    @marshal_with(app_import_check_dependencies_fields)
+    def get(self, app_model: App):
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            result = import_service.check_dependencies(app_model=app_model)
+
+        return result.model_dump(mode="json"), 200

+ 3 - 0
api/fields/app_fields.py

@@ -207,5 +207,8 @@ app_import_fields = {
     "current_dsl_version": fields.String,
     "imported_dsl_version": fields.String,
     "error": fields.String,
+}
+
+app_import_check_dependencies_fields = {
     "leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
 }

+ 5 - 3
api/libs/helper.py

@@ -9,7 +9,7 @@ import uuid
 from collections.abc import Generator
 from datetime import datetime
 from hashlib import sha256
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, TYPE_CHECKING
 
 from flask import Response, stream_with_context
 from flask_restful import fields
@@ -19,7 +19,9 @@ from configs import dify_config
 from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
 from core.file import helpers as file_helpers
 from extensions.ext_redis import redis_client
-from models.account import Account
+
+if TYPE_CHECKING:
+    from models.account import Account
 
 
 def run(script):
@@ -196,7 +198,7 @@ class TokenManager:
     def generate_token(
         cls,
         token_type: str,
-        account: Optional[Account] = None,
+        account: Optional["Account"] = None,
         email: Optional[str] = None,
         additional_data: Optional[dict] = None,
     ) -> str:

+ 49 - 19
api/services/app_dsl_service.py

@@ -31,7 +31,8 @@ from services.workflow_service import WorkflowService
 logger = logging.getLogger(__name__)
 
 IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
-IMPORT_INFO_REDIS_EXPIRY = 2 * 60 * 60  # 2 hours
+CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
+IMPORT_INFO_REDIS_EXPIRY = 10 * 60  #  10 minutes
 CURRENT_DSL_VERSION = "0.1.4"
 DSL_MAX_SIZE = 10 * 1024 * 1024  # 10MB
 
@@ -54,10 +55,13 @@ class Import(BaseModel):
     app_id: Optional[str] = None
     current_dsl_version: str = CURRENT_DSL_VERSION
     imported_dsl_version: str = ""
-    leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
     error: str = ""
 
 
+class CheckDependenciesResult(BaseModel):
+    leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
+
+
 def _check_version_compatibility(imported_version: str) -> ImportStatus:
     """Determine import status based on version comparison"""
     try:
@@ -87,6 +91,11 @@ class PendingData(BaseModel):
     app_id: str | None
 
 
+class CheckDependenciesPendingData(BaseModel):
+    dependencies: list[PluginDependency]
+    app_id: str | None
+
+
 class AppDslService:
     def __init__(self, session: Session):
         self._session = session
@@ -243,23 +252,11 @@ class AppDslService:
                     imported_dsl_version=imported_version,
                 )
 
-            try:
-                dependencies = self.get_leaked_dependencies(account.current_tenant_id, data.get("dependencies", []))
-            except Exception as e:
-                return Import(
-                    id=import_id,
-                    status=ImportStatus.FAILED,
-                    error=str(e),
-                )
-
-            if len(dependencies) > 0:
-                return Import(
-                    id=import_id,
-                    status=ImportStatus.PENDING,
-                    app_id=app_id,
-                    imported_dsl_version=imported_version,
-                    leaked_dependencies=dependencies,
-                )
+            # Extract dependencies
+            dependencies = data.get("dependencies", [])
+            check_dependencies_pending_data = None
+            if dependencies:
+                check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
 
             # Create or update app
             app = self._create_or_update_app(
@@ -271,6 +268,7 @@ class AppDslService:
                 icon_type=icon_type,
                 icon=icon,
                 icon_background=icon_background,
+                dependencies=check_dependencies_pending_data,
             )
 
             return Import(
@@ -355,6 +353,29 @@ class AppDslService:
                 error=str(e),
             )
 
+    def check_dependencies(
+        self,
+        *,
+        app_model: App,
+    ) -> CheckDependenciesResult:
+        """Check dependencies"""
+        # Get dependencies from Redis
+        redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app_model.id}"
+        dependencies = redis_client.get(redis_key)
+        if not dependencies:
+            return CheckDependenciesResult()
+
+        # Extract dependencies
+        dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
+
+        # Get leaked dependencies
+        leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
+            tenant_id=app_model.tenant_id, dependencies=dependencies.dependencies
+        )
+        return CheckDependenciesResult(
+            leaked_dependencies=leaked_dependencies,
+        )
+
     def _create_or_update_app(
         self,
         *,
@@ -366,6 +387,7 @@ class AppDslService:
         icon_type: Optional[str] = None,
         icon: Optional[str] = None,
         icon_background: Optional[str] = None,
+        dependencies: Optional[list[PluginDependency]] = None,
     ) -> App:
         """Create a new app or update an existing one."""
         app_data = data.get("app", {})
@@ -408,6 +430,14 @@ class AppDslService:
             self._session.commit()
             app_was_created.send(app, account=account)
 
+        # save dependencies
+        if dependencies:
+            redis_client.setex(
+                f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}",
+                IMPORT_INFO_REDIS_EXPIRY,
+                CheckDependenciesPendingData(app_id=app.id, dependencies=dependencies).model_dump_json(),
+            )
+
         # Initialize app based on mode
         if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
             workflow_data = data.get("workflow")