Pārlūkot izejas kodu

feat: support plugin permission management

Yeuoly 7 mēneši atpakaļ
vecāks
revīzija
c657378d06

+ 56 - 69
api/controllers/console/datasets/datasets_document.py

@@ -5,8 +5,7 @@ from datetime import datetime, timezone
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
-from sqlalchemy import asc, desc, select
-from sqlalchemy.orm import Session
+from sqlalchemy import asc, desc
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
@@ -105,8 +104,7 @@ class GetProcessRuleApi(Resource):
         rules = DocumentService.DEFAULT_RULES["rules"]
         if document_id:
             # get the latest process rule
-            with Session(db.engine) as session:
-                document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none()
+            document = Document.query.get_or_404(document_id)
 
             dataset = DatasetService.get_dataset(document.dataset_id)
 
@@ -169,77 +167,66 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        with Session(db.engine) as session:
-            query = session.query(Document).filter_by(
-                dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id
-            )
+        query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
 
-            if search:
-                search = f"%{search}%"
-                query = query.filter(Document.name.like(search))
+        if search:
+            search = f"%{search}%"
+            query = query.filter(Document.name.like(search))
 
-            if sort.startswith("-"):
-                sort_logic = desc
-                sort = sort[1:]
-            else:
-                sort_logic = asc
+        if sort.startswith("-"):
+            sort_logic = desc
+            sort = sort[1:]
+        else:
+            sort_logic = asc
 
-            if sort == "hit_count":
-                sub_query = (
-                    db.select(
-                        DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")
-                    )
-                    .group_by(DocumentSegment.document_id)
-                    .subquery()
-                )
+        if sort == "hit_count":
+            sub_query = (
+                db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
+                .group_by(DocumentSegment.document_id)
+                .subquery()
+            )
 
-                query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
-                    sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
-                    sort_logic(Document.position),
-                )
-            elif sort == "created_at":
-                query = query.order_by(
-                    sort_logic(Document.created_at),
-                    sort_logic(Document.position),
-                )
-            else:
-                query = query.order_by(
-                    desc(Document.created_at),
-                    desc(Document.position),
-                )
+            query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
+                sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
+                sort_logic(Document.position),
+            )
+        elif sort == "created_at":
+            query = query.order_by(
+                sort_logic(Document.created_at),
+                sort_logic(Document.position),
+            )
+        else:
+            query = query.order_by(
+                desc(Document.created_at),
+                desc(Document.position),
+            )
 
-            paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
-            documents = paginated_documents.items
-            if fetch:
-                for document in documents:
-                    completed_segments = (
-                        session.query(DocumentSegment)
-                        .filter(
-                            DocumentSegment.completed_at.isnot(None),
-                            DocumentSegment.document_id == str(document.id),
-                            DocumentSegment.status != "re_segment",
-                        )
-                        .count()
-                    )
-                    total_segments = (
-                        session.query(DocumentSegment)
-                        .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
-                        .count()
-                    )
-                    document.completed_segments = completed_segments
-                    document.total_segments = total_segments
-                data = marshal(documents, document_with_segments_fields)
-            else:
-                data = marshal(documents, document_fields)
-            response = {
-                "data": data,
-                "has_more": len(documents) == limit,
-                "limit": limit,
-                "total": paginated_documents.total,
-                "page": page,
-            }
+        paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        documents = paginated_documents.items
+        if fetch:
+            for document in documents:
+                completed_segments = DocumentSegment.query.filter(
+                    DocumentSegment.completed_at.isnot(None),
+                    DocumentSegment.document_id == str(document.id),
+                    DocumentSegment.status != "re_segment",
+                ).count()
+                total_segments = DocumentSegment.query.filter(
+                    DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
+                ).count()
+                document.completed_segments = completed_segments
+                document.total_segments = total_segments
+            data = marshal(documents, document_with_segments_fields)
+        else:
+            data = marshal(documents, document_fields)
+        response = {
+            "data": data,
+            "has_more": len(documents) == limit,
+            "limit": limit,
+            "total": paginated_documents.total,
+            "page": page,
+        }
 
-            return response
+        return response
 
     documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}
 

+ 56 - 0
api/controllers/console/workspace/__init__.py

@@ -0,0 +1,56 @@
+from functools import wraps
+
+from flask_login import current_user
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from extensions.ext_database import db
+from models.account import TenantPluginPermission
+
+
+def plugin_permission_required(
+    install_required: bool = False,
+    debug_required: bool = False,
+):
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            user = current_user
+            tenant_id = user.current_tenant_id
+
+            with Session(db.engine) as session:
+                permission = (
+                    session.query(TenantPluginPermission)
+                    .filter(
+                        TenantPluginPermission.tenant_id == tenant_id,
+                    )
+                    .first()
+                )
+
+                if not permission:
+                    # no permission set, allow access for everyone
+                    return view(*args, **kwargs)
+
+                if install_required:
+                    if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
+                        raise Forbidden()
+                    if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
+                        if not user.is_admin_or_owner:
+                            raise Forbidden()
+                    if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
+                        pass
+
+                if debug_required:
+                    if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
+                        raise Forbidden()
+                    if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
+                        if not user.is_admin_or_owner:
+                            raise Forbidden()
+                    if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
+                        pass
+
+            return view(*args, **kwargs)
+
+        return decorated
+
+    return interceptor

+ 78 - 82
api/controllers/console/workspace/plugin.py

@@ -8,9 +8,12 @@ from werkzeug.exceptions import Forbidden
 from configs import dify_config
 from controllers.console import api
 from controllers.console.setup import setup_required
+from controllers.console.workspace import plugin_permission_required
 from controllers.console.wraps import account_initialization_required
 from core.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import login_required
+from models.account import TenantPluginPermission
+from services.plugin.plugin_permission_service import PluginPermissionService
 from services.plugin.plugin_service import PluginService
 
 
@@ -18,12 +21,9 @@ class PluginDebuggingKeyApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def get(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         return {
             "key": PluginService.get_debugging_key(tenant_id),
@@ -37,8 +37,7 @@ class PluginListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user = current_user
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
         plugins = PluginService.list(tenant_id)
         return jsonable_encoder({"plugins": plugins})
 
@@ -57,32 +56,13 @@ class PluginIconApi(Resource):
         return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
 
 
-class PluginUploadPkgApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
-        file = request.files["pkg"]
-        content = file.read()
-
-        return jsonable_encoder(PluginService.upload_pkg(tenant_id, content))
-
-
 class PluginUploadFromPkgApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(install_required=True)
     def post(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         file = request.files["pkg"]
 
@@ -100,12 +80,9 @@ class PluginUploadFromGithubApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(install_required=True)
     def post(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("repo", type=str, required=True, location="json")
@@ -124,12 +101,9 @@ class PluginInstallFromPkgApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(install_required=True)
     def post(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@@ -149,12 +123,9 @@ class PluginInstallFromGithubApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(install_required=True)
     def post(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("repo", type=str, required=True, location="json")
@@ -178,12 +149,9 @@ class PluginInstallFromMarketplaceApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(install_required=True)
     def post(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
@@ -203,15 +171,14 @@ class PluginFetchManifestApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def get(self):
-        user = current_user
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
         args = parser.parse_args()
 
-        tenant_id = user.current_tenant_id
-
         return jsonable_encoder(
             {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args["plugin_unique_identifier"]).model_dump()}
         )
@@ -221,12 +188,9 @@ class PluginFetchInstallTasksApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def get(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("page", type=int, required=True, location="args")
@@ -242,12 +206,9 @@ class PluginFetchInstallTaskApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def get(self, task_id: str):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
 
@@ -256,12 +217,9 @@ class PluginDeleteInstallTaskApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def post(self, task_id: str):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         return {"success": PluginService.delete_install_task(tenant_id, task_id)}
 
@@ -270,12 +228,9 @@ class PluginDeleteInstallTaskItemApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def post(self, task_id: str, identifier: str):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
 
@@ -284,12 +239,9 @@ class PluginUpgradeFromMarketplaceApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def post(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@@ -307,12 +259,9 @@ class PluginUpgradeFromGithubApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def post(self):
-        user = current_user
-        if not user.is_admin_or_owner:
-            raise Forbidden()
-
-        tenant_id = user.current_tenant_id
+        tenant_id = current_user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
@@ -338,18 +287,62 @@ class PluginUninstallApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @plugin_permission_required(debug_required=True)
     def post(self):
         req = reqparse.RequestParser()
         req.add_argument("plugin_installation_id", type=str, required=True, location="json")
         args = req.parse_args()
 
+        tenant_id = current_user.current_tenant_id
+
+        return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
+
+
+class PluginChangePermissionApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @plugin_permission_required(debug_required=True)
+    def post(self):
         user = current_user
         if not user.is_admin_or_owner:
             raise Forbidden()
 
+        req = reqparse.RequestParser()
+        req.add_argument("install_permission", type=str, required=True, location="json")
+        req.add_argument("debug_permission", type=str, required=True, location="json")
+        args = req.parse_args()
+
+        install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
+        debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
+
         tenant_id = user.current_tenant_id
 
-        return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
+        return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
+
+
+class PluginFetchPermissionApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        tenant_id = current_user.current_tenant_id
+
+        permission = PluginPermissionService.get_permission(tenant_id)
+        if not permission:
+            return jsonable_encoder(
+                {
+                    "install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
+                    "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
+                }
+            )
+
+        return jsonable_encoder(
+            {
+                "install_permission": permission.install_permission,
+                "debug_permission": permission.debug_permission,
+            }
+        )
 
 
 api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
@@ -368,3 +361,6 @@ api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<t
 api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
 api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
 api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
+
+api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
+api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

+ 37 - 0
api/migrations/versions/2024_10_28_0720-08ec4f75af5e_add_tenant_plugin_permisisons.py

@@ -0,0 +1,37 @@
+"""add_tenant_plugin_permisisons
+
+Revision ID: 08ec4f75af5e
+Revises: ddcc8bbef391
+Create Date: 2024-10-28 07:20:39.711124
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '08ec4f75af5e'
+down_revision = 'ddcc8bbef391'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('account_plugin_permissions',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('install_permission', sa.String(length=16), server_default='everyone', nullable=False),
+    sa.Column('debug_permission', sa.String(length=16), server_default='noone', nullable=False),
+    sa.PrimaryKeyConstraint('id', name='account_plugin_permission_pkey'),
+    sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin')
+    )
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('account_plugin_permissions')
+    # ### end Alembic commands ###

+ 26 - 0
api/models/account.py

@@ -2,6 +2,7 @@ import enum
 import json
 
 from flask_login import UserMixin
+from sqlalchemy.orm import Mapped, mapped_column
 
 from extensions.ext_database import db
 from models.base import Base
@@ -260,3 +261,28 @@ class InvitationCode(db.Model):
     used_by_account_id = db.Column(StringUUID)
     deprecated_at = db.Column(db.DateTime)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+
+class TenantPluginPermission(Base):
+    class InstallPermission(str, enum.Enum):
+        EVERYONE = "everyone"
+        ADMINS = "admins"
+        NOBODY = "noone"
+
+    class DebugPermission(str, enum.Enum):
+        EVERYONE = "everyone"
+        ADMINS = "admins"
+        NOBODY = "noone"
+
+    __tablename__ = "account_plugin_permissions"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
+        db.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
+    )
+
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    install_permission: Mapped[InstallPermission] = mapped_column(
+        db.String(16), nullable=False, server_default="everyone"
+    )
+    debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone")

+ 34 - 0
api/services/plugin/plugin_permission_service.py

@@ -0,0 +1,34 @@
+from sqlalchemy.orm import Session
+
+from extensions.ext_database import db
+from models.account import TenantPluginPermission
+
+
+class PluginPermissionService:
+    @staticmethod
+    def get_permission(tenant_id: str) -> TenantPluginPermission | None:
+        with Session(db.engine) as session:
+            return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
+
+    @staticmethod
+    def change_permission(
+        tenant_id: str,
+        install_permission: TenantPluginPermission.InstallPermission,
+        debug_permission: TenantPluginPermission.DebugPermission,
+    ):
+        with Session(db.engine) as session:
+            permission = (
+                session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
+            )
+            if not permission:
+                permission = TenantPluginPermission(
+                    tenant_id=tenant_id, install_permission=install_permission, debug_permission=debug_permission
+                )
+
+                session.add(permission)
+            else:
+                permission.install_permission = install_permission
+                permission.debug_permission = debug_permission
+
+            session.commit()
+            return True