Procházet zdrojové kódy

Merge branch '1.1.3-lxg' into 1.1.3-master

liangxunge před 3 měsíci
rodič
revize
db950af0c4

+ 1 - 1
api/controllers/console/__init__.py

@@ -43,7 +43,7 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
 api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
 
 # Import other controllers
-from . import admin, apikey, extension, feature, ping, setup, version
+from . import admin, apikey, extension, external_application, feature, ping, setup, version
 
 # Import app controllers
 from .app import (

+ 2 - 1
api/controllers/console/datasets/datasets.py

@@ -58,12 +58,13 @@ class DatasetListApi(Resource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
+        category_ids = request.args.getlist("category_ids")
         include_all = request.args.get("include_all", default="false").lower() == "true"
         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets(
-                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
+                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, category_ids, include_all
             )
 
         # check embedding setting

+ 108 - 0
api/controllers/console/external_application.py

@@ -0,0 +1,108 @@
+from flask import request
+from flask_login import login_required
+from flask_restful import Resource, marshal, marshal_with, reqparse
+from werkzeug.exceptions import NotFound
+
+from controllers.console import api
+from controllers.console.wraps import account_initialization_required, setup_required
+from fields.external_application_fields import external_application_fields
+from models.external_application import ExternalApplication
+from services.external_application_service import ExternalApplicationService
+
+
+class ExternalApplicationListApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        search = request.args.get("search", default=None, type=str)
+        type = request.args.get("type", default=None, type=str)
+        url = request.args.get("url", default=None, type=str)
+        method = request.args.get("method", default=None, type=str)
+
+        external_applications, total = ExternalApplicationService.get_external_applications(
+            page, limit, search, type, url, method)
+
+        data = marshal(external_applications, external_application_fields)
+        response = {"data": data, "has_more": len(external_applications) == limit, "limit": limit,
+                    "total": total, "page": page}
+        return response, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(external_application_fields)
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name", location="json", nullable=False, required=True, help="Name must be between 1 to 50 characters."
+        )
+        parser.add_argument(
+            "type", type=str, location="json", choices=ExternalApplication.EXTERNAL_APPLICATION_TYPE_LIST,
+            nullable=False, help="Invalid external_application type."
+        )
+        parser.add_argument(
+            "url", type=str, location="json", nullable=False, help="Invalid external_application url."
+        )
+        parser.add_argument(
+            "method", type=str, location="json", nullable=False, help="Invalid external_application method."
+        )
+        parser.add_argument(
+            "status", type=bool, location="json", nullable=False, help="Invalid external_application status."
+        )
+        args = parser.parse_args()
+        external_application = ExternalApplicationService.save_external_application(args)
+        return external_application, 200
+
+class ExternalApplicationApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(external_application_fields)
+    def get(self, external_application_id):
+        external_application = ExternalApplicationService.get_external_application(external_application_id)
+        if not external_application:
+            raise NotFound(f"ExternalApplication with id {id} not found.")
+        return external_application, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(external_application_fields)
+    def patch(self, external_application_id):
+        external_application_id = str(external_application_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name", location="json", nullable=False, required=True, help="Name must be between 1 to 50 characters."
+        )
+        parser.add_argument(
+            "type", type=str, location="json", nullable=False, help="Invalid external_application type."
+        )
+        parser.add_argument(
+            "url", type=str, location="json", nullable=False, help="Invalid external_application url."
+        )
+        parser.add_argument(
+            "method", type=str, location="json", nullable=False, help="Invalid external_application method."
+        )
+        parser.add_argument(
+            "status", type=bool, location="json", nullable=False, help="Invalid external_application status."
+        )
+        args = parser.parse_args()
+        external_application = ExternalApplicationService.update_external_application(args, external_application_id)
+        return external_application, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, external_application_id):
+        external_application_id = str(external_application_id)
+        ExternalApplicationService.delete_external_application(external_application_id)
+        return 200
+
+api.add_resource(ExternalApplicationListApi, '/external_applications')
+api.add_resource(ExternalApplicationApi, '/external_applications/<external_application_id>')

+ 15 - 1
api/controllers/console/tag/tags.py

@@ -1,6 +1,6 @@
 from flask import request
 from flask_login import current_user  # type: ignore
-from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+from flask_restful import Resource, marshal, marshal_with, reqparse  # type: ignore
 from werkzeug.exceptions import Forbidden
 
 from controllers.console import api
@@ -51,6 +51,19 @@ class TagListApi(Resource):
 
         return response, 200
 
+class TagPageListApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        tag_type = request.args.get("tag_type", type=str, default="")
+        tags, total = TagService.get_page_tags(page, limit, tag_type, current_user.current_tenant_id)
+        data = marshal(tags, tag_fields)
+        response = {"data": data, "has_more": len(tags) == limit, "limit": limit, "total": total, "page": page}
+        return response, 200
 
 class TagUpdateDeleteApi(Resource):
     @setup_required
@@ -136,6 +149,7 @@ class TagBindingDeleteApi(Resource):
 
 
 api.add_resource(TagListApi, "/tags")
+api.add_resource(TagPageListApi, "/tags/page")
 api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>")
 api.add_resource(TagBindingCreateApi, "/tag-bindings/create")
 api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove")

+ 12 - 0
api/fields/dataset_fields.py

@@ -76,6 +76,7 @@ dataset_detail_fields = {
     "embedding_available": fields.Boolean,
     "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
     "tags": fields.List(fields.Nested(tag_fields)),
+    "categories": fields.List(fields.Nested(tag_fields)),
     "doc_form": fields.String,
     "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
     "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
@@ -98,3 +99,14 @@ dataset_metadata_fields = {
     "type": fields.String,
     "name": fields.String,
 }
+
+dataset_category_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "description": fields.String,
+    "tenant_id": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "updated_by": fields.String,
+    "updated_at": TimestampField,
+}

+ 16 - 0
api/fields/external_application_fields.py

@@ -0,0 +1,16 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
+external_application_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "type": fields.String,
+    "url": fields.String,
+    "method": fields.String,
+    "status": fields.Boolean,
+    "updated_at": TimestampField,
+    "updated_by": fields.String,
+    "created_at": TimestampField,
+    "created_by": fields.String,
+}

+ 16 - 0
api/models/dataset.py

@@ -177,6 +177,22 @@ class Dataset(db.Model):  # type: ignore[name-defined]
         return tags or []
 
     @property
+    def categories(self):
+        categories = (
+            db.session.query(Tag)
+            .join(TagBinding, Tag.id == TagBinding.tag_id)
+            .filter(
+                TagBinding.target_id == self.id,
+                TagBinding.tenant_id == self.tenant_id,
+                Tag.tenant_id == self.tenant_id,
+                Tag.type == "knowledge_category",
+            )
+            .all()
+        )
+
+        return categories or []
+
+    @property
     def external_knowledge_info(self):
         if self.provider != "external":
             return None

+ 24 - 0
api/models/external_application.py

@@ -0,0 +1,24 @@
+from sqlalchemy import func
+
+from .engine import db
+from .types import StringUUID
+
+
+class ExternalApplication(db.Model):
+    __tablename__ = "external_applications"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="external_application_pkey"),
+    )
+
+    EXTERNAL_APPLICATION_TYPE_LIST = ["QUESTION_ANSWER", "SEARCH", "RECOMMEND"]
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    name = db.Column(db.String(255), nullable=False)
+    type = db.Column(db.String(255), nullable=False)
+    url = db.Column(db.String(255), nullable=False)
+    method = db.Column(db.String(255), nullable=False)
+    status = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
+    created_by = db.Column(StringUUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    updated_by = db.Column(StringUUID, nullable=True)
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

+ 1 - 1
api/models/model.py

@@ -1766,7 +1766,7 @@ class Tag(Base):
         db.Index("tag_name_idx", "name"),
     )
 
-    TAG_TYPE_LIST = ["knowledge", "app"]
+    TAG_TYPE_LIST = ["knowledge", "app", "knowledge_category"]
 
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id = db.Column(StringUUID, nullable=True)

+ 10 - 1
api/services/dataset_service.py

@@ -76,7 +76,8 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
 
 class DatasetService:
     @staticmethod
-    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
+    def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, category_ids=None,
+                     include_all=False):
         query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
 
         if user:
@@ -129,6 +130,14 @@ class DatasetService:
             else:
                 return [], 0
 
+        if category_ids:
+            target_ids_by_category_ids = TagService.get_target_ids_by_tag_ids("knowledge_category", tenant_id,
+                                                                              category_ids)
+            if target_ids_by_category_ids:
+                query = query.filter(Dataset.id.in_(target_ids_by_category_ids))
+            else:
+                return [], 0
+
         datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
 
         return datasets.items, datasets.total

+ 5 - 0
api/services/errors/external_application.py

@@ -0,0 +1,5 @@
+from services.errors.base import BaseServiceError
+
+
+class ExternalApplicationNameDuplicateError(BaseServiceError):
+    pass

+ 5 - 0
api/services/errors/tag.py

@@ -0,0 +1,5 @@
+from services.errors.base import BaseServiceError
+
+
+class TagNameDuplicateError(BaseServiceError):
+    pass

+ 87 - 0
api/services/external_application_service.py

@@ -0,0 +1,87 @@
+import uuid
+from datetime import datetime
+from typing import Optional
+
+from flask_login import current_user
+from werkzeug.exceptions import NotFound
+
+from models import db
+from models.external_application import ExternalApplication
+from services.errors.external_application import ExternalApplicationNameDuplicateError
+
+
+class ExternalApplicationService:
+    @staticmethod
+    def get_external_applications(page, per_page, search=None, type=None, url=None, method=None):
+        query = ExternalApplication.query.order_by(ExternalApplication.created_at.desc())
+
+        if search:
+            query = ExternalApplication.query.filter(ExternalApplication.name.ilike(f"%{search}%"))
+
+        if type:
+            query = ExternalApplication.query.filter(ExternalApplication.type == type)
+
+        if url:
+            query = ExternalApplication.query.filter(ExternalApplication.url == url)
+
+        if method:
+            query = ExternalApplication.query.filter(ExternalApplication.method == method)
+
+        external_applications = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
+        return external_applications.items, external_applications.total
+
+    @staticmethod
+    def get_external_application(external_application_id: str) -> Optional[ExternalApplication]:
+        external_application = (ExternalApplication.query
+                                .filter(ExternalApplication.id == external_application_id).first())
+        return external_application
+
+    @staticmethod
+    def get_external_application_by_name(name: str) -> Optional[ExternalApplication]:
+        external_application = ExternalApplication.query.filter_by(name=name).first()
+        return external_application
+
+    @staticmethod
+    def save_external_application(args: dict) -> ExternalApplication:
+        name = args["name"]
+        external_application = ExternalApplicationService.get_external_application_by_name(name)
+        if external_application:
+            raise ExternalApplicationNameDuplicateError(f"ExternalApplication with name {name} already exists.")
+        external_application = ExternalApplication(
+            id=str(uuid.uuid4()),
+            name=name,
+            type=args["type"],
+            url=args["url"],
+            method=args["method"],
+            status=args["status"],
+            created_by=current_user.id,
+            updated_by=current_user.id,
+        )
+        db.session.add(external_application)
+        db.session.commit()
+        return external_application
+
+    @staticmethod
+    def update_external_application(args: dict, external_application_id: str) -> ExternalApplication:
+        external_application = ExternalApplicationService.get_external_application(external_application_id)
+        if not external_application:
+            raise NotFound("ExternalApplication not found")
+        external_application.name = args["name"]
+        external_application.type = args["type"]
+        external_application.url = args["url"]
+        external_application.method = args["method"]
+        external_application.status = args["status"]
+        external_application.updated_by = current_user.id
+        external_application.updated_at = datetime.now()
+        db.session.add(external_application)
+        db.session.commit()
+        return external_application
+
+    @staticmethod
+    def delete_external_application(external_application_id):
+        external_application = ExternalApplicationService.get_external_application(external_application_id)
+        if not external_application:
+            raise NotFound("ExternalApplication not found")
+        db.session.delete(external_application)
+        db.session.commit()
+        return external_application

+ 36 - 4
api/services/tag_service.py

@@ -8,6 +8,7 @@ from werkzeug.exceptions import NotFound
 from extensions.ext_database import db
 from models.dataset import Dataset
 from models.model import App, Tag, TagBinding
+from services.errors.tag import TagNameDuplicateError
 
 
 class TagService:
@@ -25,6 +26,12 @@ class TagService:
         return results
 
     @staticmethod
+    def get_tag_by_tag_name(tag_type: str, tenant_id: str, tag_name: str) -> Optional[Tag]:
+        tag: Optional[Tag] = db.session.query(Tag).filter(Tag.type == tag_type, Tag.tenant_id == tenant_id,
+                                                          Tag.name == tag_name).first()
+        return tag
+
+    @staticmethod
     def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
         tags = (
             db.session.query(Tag)
@@ -61,13 +68,38 @@ class TagService:
         return tags or []
 
     @staticmethod
+    def get_tag(tag_id: str):
+        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
+        if not tag:
+            raise NotFound("Tag not found")
+        return tag
+
+    @staticmethod
+    def get_page_tags(page, per_page, tag_type, tenant_id):
+        query = (
+            db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
+            .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
+            .filter(Tag.type == tag_type, Tag.tenant_id == tenant_id)
+        )
+        query = query.group_by(Tag.id, Tag.type, Tag.name)
+        tags = query.paginate(page=page, per_page=per_page, error_out=False)
+        return tags.items, tags.total
+
+    @staticmethod
     def save_tags(args: dict) -> Tag:
+        name = args["name"]
+        type = args["type"]
+        tenant_id = current_user.current_tenant_id
+        tag = TagService.get_tag_by_tag_name(type, tenant_id, name)
+        if tag:
+            raise TagNameDuplicateError(f"Tag with name {name} already exists.")
+
         tag = Tag(
             id=str(uuid.uuid4()),
-            name=args["name"],
-            type=args["type"],
+            name=name,
+            type=type,
             created_by=current_user.id,
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=tenant_id,
         )
         db.session.add(tag)
         db.session.commit()
@@ -138,7 +170,7 @@ class TagService:
 
     @staticmethod
     def check_target_exists(type: str, target_id: str):
-        if type == "knowledge":
+        if type in {"knowledge", "knowledge_category"}:
             dataset = (
                 db.session.query(Dataset)
                 .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)