Browse Source

标签改造,涉及以下内容:
1. 添加标签类型 knowledge_category,代表知识库的类型;
2. 改造 tag 相关接口,适配深圳智能问答 dify 改造需求;
3. 知识库实体添加关联类型列表,并改造知识库分页查询接口,支持根据类型进行分页查询。

liangxunge 3 weeks ago
parent
commit
4b982b0955

+ 2 - 1
api/controllers/console/datasets/datasets.py

@@ -58,12 +58,13 @@ class DatasetListApi(Resource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
+        category_ids = request.args.getlist("category_ids")
         include_all = request.args.get("include_all", default="false").lower() == "true"
         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets(
-                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
+                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, category_ids, include_all
             )
 
         # check embedding setting

+ 12 - 0
api/fields/dataset_fields.py

@@ -76,6 +76,7 @@ dataset_detail_fields = {
     "embedding_available": fields.Boolean,
     "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
     "tags": fields.List(fields.Nested(tag_fields)),
+    "categories": fields.List(fields.Nested(tag_fields)),
     "doc_form": fields.String,
     "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
     "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
@@ -98,3 +99,14 @@ dataset_metadata_fields = {
     "type": fields.String,
     "name": fields.String,
 }
+
+# Field map describing a dataset category record: identity, name/description,
+# owning tenant, and created/updated audit columns (timestamps go through
+# TimestampField, everything else is rendered as a string).
+dataset_category_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "description": fields.String,
+    "tenant_id": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "updated_by": fields.String,
+    "updated_at": TimestampField,
+}

+ 16 - 0
api/models/dataset.py

@@ -177,6 +177,22 @@ class Dataset(db.Model):  # type: ignore[name-defined]
         return tags or []
 
     @property
+    def categories(self):
+        categories = (
+            db.session.query(Tag)
+            .join(TagBinding, Tag.id == TagBinding.tag_id)
+            .filter(
+                TagBinding.target_id == self.id,
+                TagBinding.tenant_id == self.tenant_id,
+                Tag.tenant_id == self.tenant_id,
+                Tag.type == "knowledge_category",
+            )
+            .all()
+        )
+
+        return categories or []
+
+    @property
     def external_knowledge_info(self):
         if self.provider != "external":
             return None

+ 1 - 1
api/models/model.py

@@ -1766,7 +1766,7 @@ class Tag(Base):
         db.Index("tag_name_idx", "name"),
     )
 
-    TAG_TYPE_LIST = ["knowledge", "app"]
+    TAG_TYPE_LIST = ["knowledge", "app", "knowledge_category"]
 
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id = db.Column(StringUUID, nullable=True)

+ 9 - 1
api/services/dataset_service.py

@@ -76,7 +76,8 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
 
 class DatasetService:
     @staticmethod
-    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
+    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, category_ids=None,
+                     include_all=False):
         query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
 
         if user:
@@ -129,6 +130,13 @@ class DatasetService:
             else:
                 return [], 0
 
+        if category_ids:
+            target_ids_by_category_ids = TagService.get_target_ids_by_tag_ids("knowledge_category", tenant_id,
+                                                                              category_ids)
+            if target_ids_by_category_ids:
+                query = query.filter(Dataset.id.in_(target_ids_by_category_ids))
+            return [], 0
+
         datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
 
         return datasets.items, datasets.total

+ 5 - 0
api/services/errors/tag.py

@@ -0,0 +1,5 @@
+from services.errors.base import BaseServiceError
+
+
+class TagNameDuplicateError(BaseServiceError):
+    pass

+ 35 - 4
api/services/tag_service.py

@@ -8,6 +8,7 @@ from werkzeug.exceptions import NotFound
 from extensions.ext_database import db
 from models.dataset import Dataset
 from models.model import App, Tag, TagBinding
+from services.errors.tag import TagNameDuplicateError
 
 
 class TagService:
@@ -25,6 +26,12 @@ class TagService:
         return results
 
     @staticmethod
+    def get_tag_by_tag_name(tag_type: str, tenant_id: str, tag_name: str) -> Optional[Tag]:
+        tag: Optional[Tag] = db.session.query(Tag).filter(Tag.type == tag_type, Tag.tenant_id == tenant_id,
+                                                          Tag.name == tag_name).first()
+        return tag
+
+    @staticmethod
     def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
         tags = (
             db.session.query(Tag)
@@ -61,13 +68,37 @@ class TagService:
         return tags or []
 
     @staticmethod
+    def get_tag(tag_id: str):
+        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
+        if not tag:
+            raise NotFound("Tag not found")
+        return tag
+
+    @staticmethod
+    def get_page_tags(page, per_page, tag_type, tenant_id):
+        query = (
+            db.session.query(Tag)
+            .filter(Tag.tenant_id == tenant_id, Tag.type == tag_type)
+            .order_by(Tag.created_at.desc())
+        )
+        tags = query.paginate(page=page, per_page=per_page, error_out=False)
+        return tags.items, tags.total
+
+    @staticmethod
     def save_tags(args: dict) -> Tag:
+        name = args["name"]
+        type = args["type"]
+        tenant_id = current_user.current_tenant_id
+        tag = TagService.get_tag_by_tag_name(type, tenant_id, name)
+        if tag:
+            raise TagNameDuplicateError(f"Tag with name {name} already exists.")
+
         tag = Tag(
             id=str(uuid.uuid4()),
-            name=args["name"],
-            type=args["type"],
+            name=name,
+            type=type,
             created_by=current_user.id,
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=tenant_id,
         )
         db.session.add(tag)
         db.session.commit()
@@ -138,7 +169,7 @@ class TagService:
 
     @staticmethod
     def check_target_exists(type: str, target_id: str):
-        if type == "knowledge":
+        if type in {"knowledge", "knowledge_category"}:
             dataset = (
                 db.session.query(Dataset)
                 .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)