tag_service.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import uuid
  2. from typing import Optional
  3. from flask_login import current_user # type: ignore
  4. from sqlalchemy import func
  5. from werkzeug.exceptions import NotFound
  6. from extensions.ext_database import db
  7. from models.dataset import Dataset
  8. from models.model import App, Tag, TagBinding
  9. from services.errors.tag import TagNameDuplicateError
  10. class TagService:
  11. @staticmethod
  12. def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
  13. query = (
  14. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  15. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  16. .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
  17. )
  18. if keyword:
  19. query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
  20. query = query.group_by(Tag.id, Tag.type, Tag.name)
  21. results: list = query.order_by(Tag.created_at.desc()).all()
  22. return results
  23. @staticmethod
  24. def get_tag_by_tag_name(tag_type: str, tenant_id: str, tag_name: str) -> Optional[Tag]:
  25. tag: Optional[Tag] = db.session.query(Tag).filter(Tag.type == tag_type, Tag.tenant_id == tenant_id,
  26. Tag.name == tag_name).first()
  27. return tag
  28. @staticmethod
  29. def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
  30. tags = (
  31. db.session.query(Tag)
  32. .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
  33. .all()
  34. )
  35. if not tags:
  36. return []
  37. tag_ids = [tag.id for tag in tags]
  38. tag_bindings = (
  39. db.session.query(TagBinding.target_id)
  40. .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
  41. .all()
  42. )
  43. if not tag_bindings:
  44. return []
  45. results = [tag_binding.target_id for tag_binding in tag_bindings]
  46. return results
  47. @staticmethod
  48. def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
  49. tags = (
  50. db.session.query(Tag)
  51. .join(TagBinding, Tag.id == TagBinding.tag_id)
  52. .filter(
  53. TagBinding.target_id == target_id,
  54. TagBinding.tenant_id == current_tenant_id,
  55. Tag.tenant_id == current_tenant_id,
  56. Tag.type == tag_type,
  57. )
  58. .all()
  59. )
  60. return tags or []
  61. @staticmethod
  62. def get_tag(tag_id: str):
  63. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  64. if not tag:
  65. raise NotFound("Tag not found")
  66. return tag
  67. @staticmethod
  68. def get_page_tags(page, per_page, tag_type, tenant_id):
  69. query = (
  70. db.session.query(Tag)
  71. .filter(Tag.tenant_id == tenant_id, Tag.type == tag_type)
  72. .order_by(Tag.created_at.desc())
  73. )
  74. tags = query.paginate(page=page, per_page=per_page, error_out=False)
  75. return tags.items, tags.total
  76. @staticmethod
  77. def save_tags(args: dict) -> Tag:
  78. name = args["name"]
  79. type = args["type"]
  80. tenant_id = current_user.current_tenant_id
  81. tag = TagService.get_tag_by_tag_name(type, tenant_id, name)
  82. if tag:
  83. raise TagNameDuplicateError(f"Tag with name {name} already exists.")
  84. tag = Tag(
  85. id=str(uuid.uuid4()),
  86. name=name,
  87. type=type,
  88. created_by=current_user.id,
  89. tenant_id=tenant_id,
  90. )
  91. db.session.add(tag)
  92. db.session.commit()
  93. return tag
  94. @staticmethod
  95. def update_tags(args: dict, tag_id: str) -> Tag:
  96. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  97. if not tag:
  98. raise NotFound("Tag not found")
  99. tag.name = args["name"]
  100. db.session.commit()
  101. return tag
  102. @staticmethod
  103. def get_tag_binding_count(tag_id: str) -> int:
  104. count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
  105. return count
  106. @staticmethod
  107. def delete_tag(tag_id: str):
  108. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  109. if not tag:
  110. raise NotFound("Tag not found")
  111. db.session.delete(tag)
  112. # delete tag binding
  113. tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
  114. if tag_bindings:
  115. for tag_binding in tag_bindings:
  116. db.session.delete(tag_binding)
  117. db.session.commit()
  118. @staticmethod
  119. def save_tag_binding(args):
  120. # check if target exists
  121. TagService.check_target_exists(args["type"], args["target_id"])
  122. # save tag binding
  123. for tag_id in args["tag_ids"]:
  124. tag_binding = (
  125. db.session.query(TagBinding)
  126. .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
  127. .first()
  128. )
  129. if tag_binding:
  130. continue
  131. new_tag_binding = TagBinding(
  132. tag_id=tag_id,
  133. target_id=args["target_id"],
  134. tenant_id=current_user.current_tenant_id,
  135. created_by=current_user.id,
  136. )
  137. db.session.add(new_tag_binding)
  138. db.session.commit()
  139. @staticmethod
  140. def delete_tag_binding(args):
  141. # check if target exists
  142. TagService.check_target_exists(args["type"], args["target_id"])
  143. # delete tag binding
  144. tag_bindings = (
  145. db.session.query(TagBinding)
  146. .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
  147. .first()
  148. )
  149. if tag_bindings:
  150. db.session.delete(tag_bindings)
  151. db.session.commit()
  152. @staticmethod
  153. def check_target_exists(type: str, target_id: str):
  154. if type in {"knowledge", "knowledge_category"}:
  155. dataset = (
  156. db.session.query(Dataset)
  157. .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
  158. .first()
  159. )
  160. if not dataset:
  161. raise NotFound("Dataset not found")
  162. elif type == "app":
  163. app = (
  164. db.session.query(App)
  165. .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
  166. .first()
  167. )
  168. if not app:
  169. raise NotFound("App not found")
  170. else:
  171. raise NotFound("Invalid binding type")