import uuid
from typing import Optional

from flask_login import current_user  # type: ignore
from sqlalchemy import func
from werkzeug.exceptions import NotFound

from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding
from services.errors.tag import TagNameDuplicateError

# Escape character used for LIKE/ILIKE patterns built from user input.
_LIKE_ESCAPE_CHAR = "\\"


def _escape_like_pattern(value: str) -> str:
    """Escape LIKE wildcards in *value* so it is matched literally.

    The escape character itself is escaped first so that the escapes added
    for ``%`` and ``_`` are not doubled afterwards.
    """
    return (
        value.replace(_LIKE_ESCAPE_CHAR, _LIKE_ESCAPE_CHAR * 2)
        .replace("%", _LIKE_ESCAPE_CHAR + "%")
        .replace("_", _LIKE_ESCAPE_CHAR + "_")
    )


class TagService:
    """Static helpers for querying and mutating tags and tag bindings."""

    @staticmethod
    def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
        """List tags of one type for a tenant together with their binding counts.

        Args:
            tag_type: Value matched against ``Tag.type``.
            current_tenant_id: Value matched against ``Tag.tenant_id``.
            keyword: Optional case-insensitive substring to match in the tag
                name. ``%`` and ``_`` in the keyword are matched literally.

        Returns:
            Rows of ``(id, type, name, binding_count)`` ordered by
            ``Tag.created_at`` descending.
        """
        query = (
            db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
            .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
        )
        if keyword:
            # Escape wildcards: otherwise a keyword such as "_" or "%" would
            # match every tag instead of tags containing that character.
            pattern = f"%{_escape_like_pattern(keyword)}%"
            query = query.filter(Tag.name.ilike(pattern, escape=_LIKE_ESCAPE_CHAR))
        query = query.group_by(Tag.id, Tag.type, Tag.name)
        results: list = query.order_by(Tag.created_at.desc()).all()
        return results

    @staticmethod
    def get_tag_by_tag_name(tag_type: str, tenant_id: str, tag_name: str) -> Optional[Tag]:
        """Return the first tag with this exact type, tenant and name, or ``None``."""
        tag: Optional[Tag] = (
            db.session.query(Tag)
            .filter(Tag.type == tag_type, Tag.tenant_id == tenant_id, Tag.name == tag_name)
            .first()
        )
        return tag

    @staticmethod
    def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
        """Return the target ids bound to any of the given tags.

        Only tag ids that exist for ``current_tenant_id`` with type
        ``tag_type`` are considered. One entry is returned per binding, so a
        target bound to several of the tags appears several times.
        """
        if not tag_ids:
            # Nothing to look up; avoid issuing a query with an empty IN list.
            return []

        # Only the ids are needed, so do not load full Tag entities.
        tag_rows = (
            db.session.query(Tag.id)
            .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
            .all()
        )
        if not tag_rows:
            return []
        valid_tag_ids = [row.id for row in tag_rows]

        binding_rows = (
            db.session.query(TagBinding.target_id)
            .filter(TagBinding.tag_id.in_(valid_tag_ids), TagBinding.tenant_id == current_tenant_id)
            .all()
        )
        return [row.target_id for row in binding_rows]

    @staticmethod
    def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
        """Return the ``Tag`` objects of ``tag_type`` bound to ``target_id`` within the tenant."""
        tags: list = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == target_id,
                TagBinding.tenant_id == current_tenant_id,
                Tag.tenant_id == current_tenant_id,
                Tag.type == tag_type,
            )
            .all()
        )
        return tags

    @staticmethod
    def get_tag(tag_id: str):
        """Return the tag with the given id.

        Raises:
            NotFound: If no tag has this id.
        """
        # NOTE(review): the lookup is by id only and is not scoped to a tenant;
        # confirm that callers authorize access before relying on the result.
        tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
        if not tag:
            raise NotFound("Tag not found")
        return tag

    @staticmethod
    def get_page_tags(page, per_page, tag_type, tenant_id):
        """Return one page of tags with binding counts and the total number of tags.

        Returns:
            A ``(items, total)`` tuple where ``items`` are rows of
            ``(id, type, name, binding_count)``.
        """
        query = (
            db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
            .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(Tag.type == tag_type, Tag.tenant_id == tenant_id)
        )
        query = query.group_by(Tag.id, Tag.type, Tag.name)
        # Pagination needs an explicit ordering, otherwise the database may
        # return rows in a different order for each page request. Use the same
        # ordering as get_tags.
        query = query.order_by(Tag.created_at.desc())
        tags = query.paginate(page=page, per_page=per_page, error_out=False)
        return tags.items, tags.total

    @staticmethod
    def save_tags(args: dict) -> Tag:
        """Create a tag for the current user's tenant.

        Args:
            args: Mapping with the keys ``"name"`` and ``"type"``.

        Raises:
            TagNameDuplicateError: If a tag with the same name and type
                already exists in the tenant.
        """
        name = args["name"]
        tag_type = args["type"]
        tenant_id = current_user.current_tenant_id

        if TagService.get_tag_by_tag_name(tag_type, tenant_id, name):
            raise TagNameDuplicateError(f"Tag with name {name} already exists.")

        tag = Tag(
            id=str(uuid.uuid4()),
            name=name,
            type=tag_type,
            created_by=current_user.id,
            tenant_id=tenant_id,
        )
        db.session.add(tag)
        db.session.commit()
        return tag

    @staticmethod
    def update_tags(args: dict, tag_id: str) -> Tag:
        """Rename the tag identified by ``tag_id`` to ``args["name"]``.

        Raises:
            NotFound: If no tag has this id.
            TagNameDuplicateError: If a different tag with the same tenant and
                type already uses the new name.
        """
        tag = TagService.get_tag(tag_id)
        name = args["name"]

        # Enforce the same uniqueness rule as save_tags; renaming a tag to its
        # own current name is allowed.
        duplicate = TagService.get_tag_by_tag_name(tag.type, tag.tenant_id, name)
        if duplicate and duplicate.id != tag.id:
            raise TagNameDuplicateError(f"Tag with name {name} already exists.")

        tag.name = name
        db.session.commit()
        return tag

    @staticmethod
    def get_tag_binding_count(tag_id: str) -> int:
        """Return the number of bindings that reference ``tag_id``."""
        count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
        return count

    @staticmethod
    def delete_tag(tag_id: str):
        """Delete a tag and every binding that references it.

        Raises:
            NotFound: If no tag has this id.
        """
        tag = TagService.get_tag(tag_id)
        # Remove the bindings with one DELETE statement instead of loading
        # each row into the session first.
        db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).delete(synchronize_session=False)
        db.session.delete(tag)
        db.session.commit()

    @staticmethod
    def save_tag_binding(args):
        """Bind the tags in ``args["tag_ids"]`` to ``args["target_id"]``.

        Tags that are already bound to the target are skipped. New bindings
        are committed together after all of them have been added.

        Args:
            args: Mapping with the keys ``"type"``, ``"target_id"`` and
                ``"tag_ids"``.

        Raises:
            NotFound: If the target does not exist or the type is invalid.
        """
        target_id = args["target_id"]
        TagService.check_target_exists(args["type"], target_id)

        added = False
        for tag_id in args["tag_ids"]:
            already_bound = (
                db.session.query(TagBinding)
                .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == target_id)
                .first()
            )
            if already_bound:
                continue
            db.session.add(
                TagBinding(
                    tag_id=tag_id,
                    target_id=target_id,
                    tenant_id=current_user.current_tenant_id,
                    created_by=current_user.id,
                )
            )
            added = True

        if added:
            # Single commit so a multi-tag bind is applied all-or-nothing.
            db.session.commit()

    @staticmethod
    def delete_tag_binding(args):
        """Remove the binding between ``args["tag_id"]`` and ``args["target_id"]``, if any.

        Args:
            args: Mapping with the keys ``"type"``, ``"target_id"`` and
                ``"tag_id"``.

        Raises:
            NotFound: If the target does not exist or the type is invalid.
        """
        TagService.check_target_exists(args["type"], args["target_id"])

        binding = (
            db.session.query(TagBinding)
            .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"])
            .first()
        )
        if binding:
            db.session.delete(binding)
            db.session.commit()

    @staticmethod
    def check_target_exists(type: str, target_id: str):
        """Verify that the binding target exists in the current user's tenant.

        Args:
            type: ``"knowledge"`` or ``"knowledge_category"`` to look up a
                ``Dataset``, ``"app"`` to look up an ``App``.
            target_id: Id of the dataset or app.

        Raises:
            NotFound: If the target is missing or ``type`` is not one of the
                supported values.
        """
        tenant_id = current_user.current_tenant_id

        if type in {"knowledge", "knowledge_category"}:
            dataset = (
                db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == target_id).first()
            )
            if not dataset:
                raise NotFound("Dataset not found")
            return

        if type == "app":
            app = db.session.query(App).filter(App.tenant_id == tenant_id, App.id == target_id).first()
            if not app:
                raise NotFound("App not found")
            return

        raise NotFound("Invalid binding type")