tag_service.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import uuid
  2. from typing import Optional
  3. from flask_login import current_user # type: ignore
  4. from sqlalchemy import func
  5. from werkzeug.exceptions import NotFound
  6. from extensions.ext_database import db
  7. from models.dataset import Dataset
  8. from models.model import App, Tag, TagBinding
  9. from services.errors.tag import TagNameDuplicateError
  10. class TagService:
  11. @staticmethod
  12. def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
  13. query = (
  14. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  15. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  16. .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
  17. )
  18. if keyword:
  19. query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
  20. query = query.group_by(Tag.id, Tag.type, Tag.name)
  21. results: list = query.order_by(Tag.created_at.desc()).all()
  22. return results
  23. @staticmethod
  24. def get_tag_by_tag_name(tag_type: str, tenant_id: str, tag_name: str) -> Optional[Tag]:
  25. tag: Optional[Tag] = db.session.query(Tag).filter(Tag.type == tag_type, Tag.tenant_id == tenant_id,
  26. Tag.name == tag_name).first()
  27. return tag
  28. @staticmethod
  29. def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
  30. tags = (
  31. db.session.query(Tag)
  32. .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
  33. .all()
  34. )
  35. if not tags:
  36. return []
  37. tag_ids = [tag.id for tag in tags]
  38. tag_bindings = (
  39. db.session.query(TagBinding.target_id)
  40. .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
  41. .all()
  42. )
  43. if not tag_bindings:
  44. return []
  45. results = [tag_binding.target_id for tag_binding in tag_bindings]
  46. return results
  47. @staticmethod
  48. def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
  49. tags = (
  50. db.session.query(Tag)
  51. .join(TagBinding, Tag.id == TagBinding.tag_id)
  52. .filter(
  53. TagBinding.target_id == target_id,
  54. TagBinding.tenant_id == current_tenant_id,
  55. Tag.tenant_id == current_tenant_id,
  56. Tag.type == tag_type,
  57. )
  58. .all()
  59. )
  60. return tags or []
  61. @staticmethod
  62. def get_tag(tag_id: str):
  63. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  64. if not tag:
  65. raise NotFound("Tag not found")
  66. return tag
  67. @staticmethod
  68. def get_page_tags(page, per_page, tag_type, tenant_id):
  69. query = (
  70. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  71. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  72. .filter(Tag.type == tag_type, Tag.tenant_id == tenant_id)
  73. )
  74. query = query.group_by(Tag.id, Tag.type, Tag.name)
  75. tags = query.paginate(page=page, per_page=per_page, error_out=False)
  76. return tags.items, tags.total
  77. @staticmethod
  78. def save_tags(args: dict) -> Tag:
  79. name = args["name"]
  80. type = args["type"]
  81. tenant_id = current_user.current_tenant_id
  82. tag = TagService.get_tag_by_tag_name(type, tenant_id, name)
  83. if tag:
  84. raise TagNameDuplicateError(f"Tag with name {name} already exists.")
  85. tag = Tag(
  86. id=str(uuid.uuid4()),
  87. name=name,
  88. type=type,
  89. created_by=current_user.id,
  90. tenant_id=tenant_id,
  91. )
  92. db.session.add(tag)
  93. db.session.commit()
  94. return tag
  95. @staticmethod
  96. def update_tags(args: dict, tag_id: str) -> Tag:
  97. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  98. if not tag:
  99. raise NotFound("Tag not found")
  100. tag.name = args["name"]
  101. db.session.commit()
  102. return tag
  103. @staticmethod
  104. def get_tag_binding_count(tag_id: str) -> int:
  105. count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
  106. return count
  107. @staticmethod
  108. def delete_tag(tag_id: str):
  109. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  110. if not tag:
  111. raise NotFound("Tag not found")
  112. db.session.delete(tag)
  113. # delete tag binding
  114. tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
  115. if tag_bindings:
  116. for tag_binding in tag_bindings:
  117. db.session.delete(tag_binding)
  118. db.session.commit()
  119. @staticmethod
  120. def save_tag_binding(args):
  121. # check if target exists
  122. TagService.check_target_exists(args["type"], args["target_id"])
  123. # save tag binding
  124. for tag_id in args["tag_ids"]:
  125. tag_binding = (
  126. db.session.query(TagBinding)
  127. .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
  128. .first()
  129. )
  130. if tag_binding:
  131. continue
  132. new_tag_binding = TagBinding(
  133. tag_id=tag_id,
  134. target_id=args["target_id"],
  135. tenant_id=current_user.current_tenant_id,
  136. created_by=current_user.id,
  137. )
  138. db.session.add(new_tag_binding)
  139. db.session.commit()
  140. @staticmethod
  141. def delete_tag_binding(args):
  142. # check if target exists
  143. TagService.check_target_exists(args["type"], args["target_id"])
  144. # delete tag binding
  145. tag_bindings = (
  146. db.session.query(TagBinding)
  147. .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
  148. .first()
  149. )
  150. if tag_bindings:
  151. db.session.delete(tag_bindings)
  152. db.session.commit()
  153. @staticmethod
  154. def check_target_exists(type: str, target_id: str):
  155. if type in {"knowledge", "knowledge_category"}:
  156. dataset = (
  157. db.session.query(Dataset)
  158. .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
  159. .first()
  160. )
  161. if not dataset:
  162. raise NotFound("Dataset not found")
  163. elif type == "app":
  164. app = (
  165. db.session.query(App)
  166. .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
  167. .first()
  168. )
  169. if not app:
  170. raise NotFound("App not found")
  171. else:
  172. raise NotFound("Invalid binding type")