Преглед на файлове

修改知识库列表查询

zhouyuexiang преди 2 месеца
родител
ревизия
a2d80984b0
променени са 3 файла, в които са добавени 110 реда и са изтрити 3 реда
  1. 9 2
      api/controllers/console/datasets/datasets.py
  2. 3 0
      api/fields/dataset_fields.py
  3. 98 1
      api/services/dataset_service.py

+ 9 - 2
api/controllers/console/datasets/datasets.py

@@ -58,13 +58,18 @@ class DatasetListApi(Resource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
+        auth_type = request.args.get("authType", default=None, type=int)
+        creator_dept = request.args.get("creatorDept")
+        creator = request.args.get("creator", default=None, type=str)
         category_ids = request.args.getlist("category_ids")
         include_all = request.args.get("include_all", default="false").lower() == "true"
+
         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
-            datasets, total = DatasetService.get_datasets(
-                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, category_ids, include_all
+            datasets, total = DatasetService.get_datasets2(
+                page, limit, current_user.current_tenant_id, current_user, search, tag_ids,
+                category_ids, auth_type, creator_dept, creator, include_all
             )
 
         # check embedding setting
@@ -79,6 +84,8 @@ class DatasetListApi(Resource):
 
         data = marshal(datasets, dataset_detail_fields)
         for item in data:
+            # 返回编辑授权
+            item["has_edit_permission"] = DatasetService.has_edit_permission(current_user.id,item["id"])
             # convert embedding_model_provider to plugin standard format
             if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
                 item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))

+ 3 - 0
api/fields/dataset_fields.py

@@ -82,6 +82,9 @@ dataset_detail_fields = {
     "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
     "doc_metadata": fields.List(fields.Nested(doc_metadata_fields)),
     "built_in_field_enabled": fields.Boolean,
+    "dept_id": fields.String,
+    "edit_auth": fields.Integer,
+    "has_edit_permission": fields.Boolean,
 }
 
 dataset_query_detail_fields = {

+ 98 - 1
api/services/dataset_service.py

@@ -9,7 +9,7 @@ from collections import Counter
 from typing import Any, Optional
 
 from flask_login import current_user  # type: ignore
-from sqlalchemy import func, text
+from sqlalchemy import func, literal, text
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 
@@ -279,6 +279,20 @@ class DatasetService:
         return stats
 
     @staticmethod
+    def has_edit_permission(account_id: str, dataset_id: str) -> bool:
+        result = (
+            db.session.query(DatasetPermissionAll)
+            .filter_by(
+                account_id=account_id,
+                dataset_id=dataset_id,
+                has_edit_permission=True
+            )
+            .first()
+        )
+
+        return result is not None
+
+    @staticmethod
     def get_datasets_edit_permission(dataset_id):
         results = (
             db.session.query(DatasetPermissionAll.account_id, Account.email)
@@ -292,6 +306,7 @@ class DatasetService:
 
         return edit_permission_list
 
+
     @staticmethod
     def get_datasets_read_permission(dataset_id):
         results = (
@@ -748,6 +763,88 @@ class DatasetService:
             "count": 0,
         }
 
+    @staticmethod
+    def get_datasets2(
+            page, per_page, tenant_id=None, user=None,
+            search=None, tag_ids=None, category_ids=None,
+            auth_type=None, creator_dept=None, creator=None, include_all=False
+    ):
+        user_id = user.id
+        queries = []
+        
+        # 1. 创建:row.created_by == loginUserId,
+        query1 = Dataset.query.filter(Dataset.created_by == user_id)
+        # 2. 编辑:row.deptId == loginDeptId && row.editAuth == 2
+        query2 = Dataset.query.join(Account, Dataset.dept_id == Account.dept_id)
+        query2 = query2.filter(Dataset.edit_auth == 2, Account.id == user_id)
+        # 3.授权编辑:row.editUserIds.includes(loginUserId)
+        query3 = Dataset.query.join(
+            DatasetPermissionAll,
+            Dataset.id == DatasetPermissionAll.dataset_id
+        ).filter(
+            DatasetPermissionAll.account_id == user_id,
+            DatasetPermissionAll.has_edit_permission == True
+        )
+        # 4.授权可见:row.lookUserIds.includes(loginUserId)
+        query4 = Dataset.query.join(
+            DatasetPermissionAll,
+            Dataset.id == DatasetPermissionAll.dataset_id
+        ).filter(
+            DatasetPermissionAll.account_id == user_id,
+            DatasetPermissionAll.has_read_permission == True
+        )
+
+        # 根据 auth_type 选择要使用的查询
+        if auth_type is None:
+            # 如果 auth_type 为空,使用所有查询
+            queries = [query1, query2, query3, query4]
+        elif auth_type == 1:
+            queries = [query1]
+        elif auth_type == 2:
+            queries = [query2]
+        elif auth_type == 3:
+            queries = [query3]
+        elif auth_type == 4:
+            queries = [query4]
+
+        # 合并查询
+        if not queries:
+            return [], 0
+            
+        union_query = queries[0]
+        for query in queries[1:]:
+            union_query = union_query.union(query)
+
+        # 添加创建人部门过滤
+        if creator_dept:
+            union_query = union_query.join(Account, Dataset.created_by == Account.id)
+            union_query = union_query.filter(Account.dept_id == literal(str(creator_dept)))
+
+        # 添加创建人过滤
+        if creator:
+            union_query = union_query.filter(Dataset.created_by == creator)
+
+        # 其它过滤
+        if search:
+            union_query = union_query.filter(Dataset.name.ilike(f"%{search}%"))
+        if tag_ids:
+            target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
+            if target_ids:
+                union_query = union_query.filter(Dataset.id.in_(target_ids))
+            else:
+                return [], 0
+        if category_ids:
+            target_ids_by_category_ids = TagService.get_target_ids_by_tag_ids("knowledge_category",
+                                                                              tenant_id, category_ids)
+            if target_ids_by_category_ids:
+                union_query = union_query.filter(Dataset.id.in_(target_ids_by_category_ids))
+            else:
+                return [], 0
+
+        datasets = union_query.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=per_page,
+                                                                            max_per_page=100, error_out=False)
+        return datasets.items, datasets.total
+
 
 class TemplateService:
     DEFAULT_RULES: dict[str, Any] = {