tag_service.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import uuid
  2. from typing import Optional
  3. from flask_login import current_user # type: ignore
  4. from sqlalchemy import func
  5. from werkzeug.exceptions import NotFound
  6. from extensions.ext_database import db
  7. from models.dataset import Dataset
  8. from models.model import App, Tag, TagBinding
  9. from services.errors.tag import TagNameDuplicateError
  10. class TagService:
  11. @staticmethod
  12. def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
  13. query = (
  14. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  15. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  16. .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
  17. )
  18. if keyword:
  19. query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
  20. query = query.group_by(Tag.id, Tag.type, Tag.name)
  21. results: list = query.order_by(Tag.created_at.desc()).all()
  22. return results
  23. @staticmethod
  24. def get_tag_by_tag_name(tag_type: str, tenant_id: str, tag_name: str) -> Optional[Tag]:
  25. tag: Optional[Tag] = db.session.query(Tag).filter(Tag.type == tag_type, Tag.tenant_id == tenant_id,
  26. Tag.name == tag_name).first()
  27. return tag
  28. @staticmethod
  29. def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
  30. tags = (
  31. db.session.query(Tag)
  32. .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
  33. .all()
  34. )
  35. if not tags:
  36. return []
  37. tag_ids = [tag.id for tag in tags]
  38. tag_bindings = (
  39. db.session.query(TagBinding.target_id)
  40. .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
  41. .all()
  42. )
  43. if not tag_bindings:
  44. return []
  45. results = [tag_binding.target_id for tag_binding in tag_bindings]
  46. return results
  47. @staticmethod
  48. def get_tags_count(tenant_id: str, keyword: Optional[str] = None) -> int:
  49. query = db.session.query(Tag).filter(Tag.type == "knowledge")
  50. if tenant_id:
  51. query = query.filter(Tag.tenant_id == tenant_id)
  52. if keyword:
  53. query = query.filter(Tag.name.ilike(f"%{keyword}%"))
  54. return query.count()
  55. @staticmethod
  56. def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
  57. tags = (
  58. db.session.query(Tag)
  59. .join(TagBinding, Tag.id == TagBinding.tag_id)
  60. .filter(
  61. TagBinding.target_id == target_id,
  62. TagBinding.tenant_id == current_tenant_id,
  63. Tag.tenant_id == current_tenant_id,
  64. Tag.type == tag_type,
  65. )
  66. .all()
  67. )
  68. return tags or []
  69. @staticmethod
  70. def get_tag(tag_id: str):
  71. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  72. if not tag:
  73. raise NotFound("Tag not found")
  74. return tag
  75. @staticmethod
  76. def get_page_tags(page, per_page, tag_type, tenant_id):
  77. query = (
  78. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  79. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  80. .filter(Tag.type == tag_type, Tag.tenant_id == tenant_id)
  81. )
  82. query = query.group_by(Tag.id, Tag.type, Tag.name)
  83. tags = query.paginate(page=page, per_page=per_page, error_out=False)
  84. return tags.items, tags.total
  85. @staticmethod
  86. def save_tags(args: dict) -> Tag:
  87. name = args["name"]
  88. type = args["type"]
  89. tenant_id = current_user.current_tenant_id
  90. tag = TagService.get_tag_by_tag_name(type, tenant_id, name)
  91. if tag:
  92. raise TagNameDuplicateError(f"Tag with name {name} already exists.")
  93. tag = Tag(
  94. id=str(uuid.uuid4()),
  95. name=name,
  96. type=type,
  97. created_by=current_user.id,
  98. tenant_id=tenant_id,
  99. )
  100. db.session.add(tag)
  101. db.session.commit()
  102. return tag
  103. @staticmethod
  104. def update_tags(args: dict, tag_id: str) -> Tag:
  105. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  106. if not tag:
  107. raise NotFound("Tag not found")
  108. tag.name = args["name"]
  109. db.session.commit()
  110. return tag
  111. @staticmethod
  112. def get_tag_binding_count(tag_id: str) -> int:
  113. count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
  114. return count
  115. @staticmethod
  116. def delete_tag(tag_id: str):
  117. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  118. if not tag:
  119. raise NotFound("Tag not found")
  120. db.session.delete(tag)
  121. # delete tag binding
  122. tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
  123. if tag_bindings:
  124. for tag_binding in tag_bindings:
  125. db.session.delete(tag_binding)
  126. db.session.commit()
  127. @staticmethod
  128. def save_tag_binding(args):
  129. # check if target exists
  130. TagService.check_target_exists(args["type"], args["target_id"])
  131. # save tag binding
  132. for tag_id in args["tag_ids"]:
  133. tag_binding = (
  134. db.session.query(TagBinding)
  135. .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
  136. .first()
  137. )
  138. if tag_binding:
  139. continue
  140. new_tag_binding = TagBinding(
  141. tag_id=tag_id,
  142. target_id=args["target_id"],
  143. tenant_id=current_user.current_tenant_id,
  144. created_by=current_user.id,
  145. )
  146. db.session.add(new_tag_binding)
  147. db.session.commit()
  148. @staticmethod
  149. def delete_tag_binding(args):
  150. # check if target exists
  151. TagService.check_target_exists(args["type"], args["target_id"])
  152. # delete tag binding
  153. tag_bindings = (
  154. db.session.query(TagBinding)
  155. .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
  156. .first()
  157. )
  158. if tag_bindings:
  159. db.session.delete(tag_bindings)
  160. db.session.commit()
  161. @staticmethod
  162. def check_target_exists(type: str, target_id: str):
  163. if type in {"knowledge", "knowledge_category"}:
  164. dataset = (
  165. db.session.query(Dataset)
  166. .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
  167. .first()
  168. )
  169. if not dataset:
  170. raise NotFound("Dataset not found")
  171. elif type == "app":
  172. app = (
  173. db.session.query(App)
  174. .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
  175. .first()
  176. )
  177. if not app:
  178. raise NotFound("App not found")
  179. else:
  180. raise NotFound("Invalid binding type")