tag_service.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import uuid
  2. from typing import Optional
  3. from flask_login import current_user # type: ignore
  4. from sqlalchemy import func
  5. from werkzeug.exceptions import NotFound
  6. from extensions.ext_database import db
  7. from models.dataset import Dataset
  8. from models.model import App, Tag, TagBinding
  9. from services.errors.tag import TagNameDuplicateError
  10. class TagService:
  11. @staticmethod
  12. def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
  13. query = (
  14. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  15. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  16. .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
  17. )
  18. if keyword:
  19. query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
  20. query = query.group_by(Tag.id, Tag.type, Tag.name)
  21. results: list = query.order_by(Tag.created_at.desc()).all()
  22. return results
  23. @staticmethod
  24. def get_tag_by_tag_name(tag_type: str, tenant_id: str, tag_name: str) -> Optional[Tag]:
  25. tag: Optional[Tag] = (
  26. db.session.query(Tag).filter(Tag.type == tag_type, Tag.tenant_id == tenant_id, Tag.name == tag_name).first()
  27. )
  28. return tag
  29. @staticmethod
  30. def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
  31. tags = (
  32. db.session.query(Tag)
  33. .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
  34. .all()
  35. )
  36. if not tags:
  37. return []
  38. tag_ids = [tag.id for tag in tags]
  39. tag_bindings = (
  40. db.session.query(TagBinding.target_id)
  41. .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
  42. .all()
  43. )
  44. if not tag_bindings:
  45. return []
  46. results = [tag_binding.target_id for tag_binding in tag_bindings]
  47. return results
  48. @staticmethod
  49. def get_tags_count(tenant_id: str, keyword: Optional[str] = None) -> int:
  50. query = db.session.query(Tag).filter(Tag.type == "knowledge")
  51. if tenant_id:
  52. query = query.filter(Tag.tenant_id == tenant_id)
  53. if keyword:
  54. query = query.filter(Tag.name.ilike(f"%{keyword}%"))
  55. return query.count()
  56. @staticmethod
  57. def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
  58. tags = (
  59. db.session.query(Tag)
  60. .join(TagBinding, Tag.id == TagBinding.tag_id)
  61. .filter(
  62. TagBinding.target_id == target_id,
  63. TagBinding.tenant_id == current_tenant_id,
  64. Tag.tenant_id == current_tenant_id,
  65. Tag.type == tag_type,
  66. )
  67. .all()
  68. )
  69. return tags or []
  70. @staticmethod
  71. def get_tag(tag_id: str):
  72. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  73. if not tag:
  74. raise NotFound("Tag not found")
  75. return tag
  76. @staticmethod
  77. def get_page_tags(page, per_page, tag_type, tenant_id):
  78. query = (
  79. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  80. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  81. .filter(Tag.type == tag_type, Tag.tenant_id == tenant_id)
  82. )
  83. query = query.group_by(Tag.id, Tag.type, Tag.name)
  84. tags = query.paginate(page=page, per_page=per_page, error_out=False)
  85. return tags.items, tags.total
  86. @staticmethod
  87. def save_tags(args: dict) -> Tag:
  88. name = args["name"]
  89. type = args["type"]
  90. tenant_id = current_user.current_tenant_id
  91. tag = TagService.get_tag_by_tag_name(type, tenant_id, name)
  92. if tag:
  93. raise TagNameDuplicateError(f"Tag with name {name} already exists.")
  94. tag = Tag(
  95. id=str(uuid.uuid4()),
  96. name=name,
  97. type=type,
  98. created_by=current_user.id,
  99. tenant_id=tenant_id,
  100. )
  101. db.session.add(tag)
  102. db.session.commit()
  103. return tag
  104. @staticmethod
  105. def update_tags(args: dict, tag_id: str) -> Tag:
  106. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  107. if not tag:
  108. raise NotFound("Tag not found")
  109. tag.name = args["name"]
  110. db.session.commit()
  111. return tag
  112. @staticmethod
  113. def get_tag_binding_count(tag_id: str) -> int:
  114. count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
  115. return count
  116. @staticmethod
  117. def delete_tag(tag_id: str):
  118. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  119. if not tag:
  120. raise NotFound("Tag not found")
  121. db.session.delete(tag)
  122. # delete tag binding
  123. tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
  124. if tag_bindings:
  125. for tag_binding in tag_bindings:
  126. db.session.delete(tag_binding)
  127. db.session.commit()
  128. @staticmethod
  129. def save_tag_binding(args):
  130. # check if target exists
  131. TagService.check_target_exists(args["type"], args["target_id"])
  132. # save tag binding
  133. for tag_id in args["tag_ids"]:
  134. tag_binding = (
  135. db.session.query(TagBinding)
  136. .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
  137. .first()
  138. )
  139. if tag_binding:
  140. continue
  141. new_tag_binding = TagBinding(
  142. tag_id=tag_id,
  143. target_id=args["target_id"],
  144. tenant_id=current_user.current_tenant_id,
  145. created_by=current_user.id,
  146. )
  147. db.session.add(new_tag_binding)
  148. db.session.commit()
  149. @staticmethod
  150. def delete_tag_binding(args):
  151. # check if target exists
  152. TagService.check_target_exists(args["type"], args["target_id"])
  153. # delete tag binding
  154. tag_bindings = (
  155. db.session.query(TagBinding)
  156. .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
  157. .first()
  158. )
  159. if tag_bindings:
  160. db.session.delete(tag_bindings)
  161. db.session.commit()
  162. @staticmethod
  163. def check_target_exists(type: str, target_id: str):
  164. if type in {"knowledge", "knowledge_category"}:
  165. dataset = (
  166. db.session.query(Dataset)
  167. .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
  168. .first()
  169. )
  170. if not dataset:
  171. raise NotFound("Dataset not found")
  172. elif type == "app":
  173. app = (
  174. db.session.query(App)
  175. .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
  176. .first()
  177. )
  178. if not app:
  179. raise NotFound("App not found")
  180. else:
  181. raise NotFound("Invalid binding type")