tag_service.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import uuid
  2. from flask_login import current_user
  3. from sqlalchemy import func
  4. from werkzeug.exceptions import NotFound
  5. from extensions.ext_database import db
  6. from models.dataset import Dataset
  7. from models.model import App, Tag, TagBinding
  8. class TagService:
  9. @staticmethod
  10. def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
  11. query = (
  12. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  13. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  14. .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
  15. )
  16. if keyword:
  17. query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
  18. query = query.group_by(Tag.id)
  19. results = query.order_by(Tag.created_at.desc()).all()
  20. return results
  21. @staticmethod
  22. def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
  23. tags = (
  24. db.session.query(Tag)
  25. .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
  26. .all()
  27. )
  28. if not tags:
  29. return []
  30. tag_ids = [tag.id for tag in tags]
  31. tag_bindings = (
  32. db.session.query(TagBinding.target_id)
  33. .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
  34. .all()
  35. )
  36. if not tag_bindings:
  37. return []
  38. results = [tag_binding.target_id for tag_binding in tag_bindings]
  39. return results
  40. @staticmethod
  41. def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
  42. tags = (
  43. db.session.query(Tag)
  44. .join(TagBinding, Tag.id == TagBinding.tag_id)
  45. .filter(
  46. TagBinding.target_id == target_id,
  47. TagBinding.tenant_id == current_tenant_id,
  48. Tag.tenant_id == current_tenant_id,
  49. Tag.type == tag_type,
  50. )
  51. .all()
  52. )
  53. return tags or []
  54. @staticmethod
  55. def save_tags(args: dict) -> Tag:
  56. tag = Tag(
  57. id=str(uuid.uuid4()),
  58. name=args["name"],
  59. type=args["type"],
  60. created_by=current_user.id,
  61. tenant_id=current_user.current_tenant_id,
  62. )
  63. db.session.add(tag)
  64. db.session.commit()
  65. return tag
  66. @staticmethod
  67. def update_tags(args: dict, tag_id: str) -> Tag:
  68. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  69. if not tag:
  70. raise NotFound("Tag not found")
  71. tag.name = args["name"]
  72. db.session.commit()
  73. return tag
  74. @staticmethod
  75. def get_tag_binding_count(tag_id: str) -> int:
  76. count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
  77. return count
  78. @staticmethod
  79. def delete_tag(tag_id: str):
  80. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  81. if not tag:
  82. raise NotFound("Tag not found")
  83. db.session.delete(tag)
  84. # delete tag binding
  85. tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
  86. if tag_bindings:
  87. for tag_binding in tag_bindings:
  88. db.session.delete(tag_binding)
  89. db.session.commit()
  90. @staticmethod
  91. def save_tag_binding(args):
  92. # check if target exists
  93. TagService.check_target_exists(args["type"], args["target_id"])
  94. # save tag binding
  95. for tag_id in args["tag_ids"]:
  96. tag_binding = (
  97. db.session.query(TagBinding)
  98. .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
  99. .first()
  100. )
  101. if tag_binding:
  102. continue
  103. new_tag_binding = TagBinding(
  104. tag_id=tag_id,
  105. target_id=args["target_id"],
  106. tenant_id=current_user.current_tenant_id,
  107. created_by=current_user.id,
  108. )
  109. db.session.add(new_tag_binding)
  110. db.session.commit()
  111. @staticmethod
  112. def delete_tag_binding(args):
  113. # check if target exists
  114. TagService.check_target_exists(args["type"], args["target_id"])
  115. # delete tag binding
  116. tag_bindings = (
  117. db.session.query(TagBinding)
  118. .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
  119. .first()
  120. )
  121. if tag_bindings:
  122. db.session.delete(tag_bindings)
  123. db.session.commit()
  124. @staticmethod
  125. def check_target_exists(type: str, target_id: str):
  126. if type == "knowledge":
  127. dataset = (
  128. db.session.query(Dataset)
  129. .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
  130. .first()
  131. )
  132. if not dataset:
  133. raise NotFound("Dataset not found")
  134. elif type == "app":
  135. app = (
  136. db.session.query(App)
  137. .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
  138. .first()
  139. )
  140. if not app:
  141. raise NotFound("App not found")
  142. else:
  143. raise NotFound("Invalid binding type")