# tag_service.py

import uuid
from typing import Optional

from flask_login import current_user
from sqlalchemy import func
from werkzeug.exceptions import NotFound

from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding
  8. class TagService:
  9. @staticmethod
  10. def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
  11. query = db.session.query(
  12. Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
  13. ).outerjoin(
  14. TagBinding, Tag.id == TagBinding.tag_id
  15. ).filter(
  16. Tag.type == tag_type,
  17. Tag.tenant_id == current_tenant_id
  18. )
  19. if keyword:
  20. query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
  21. query = query.group_by(
  22. Tag.id
  23. )
  24. results = query.order_by(Tag.created_at.desc()).all()
  25. return results
  26. @staticmethod
  27. def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
  28. tags = db.session.query(Tag).filter(
  29. Tag.id.in_(tag_ids),
  30. Tag.tenant_id == current_tenant_id,
  31. Tag.type == tag_type
  32. ).all()
  33. if not tags:
  34. return []
  35. tag_ids = [tag.id for tag in tags]
  36. tag_bindings = db.session.query(
  37. TagBinding.target_id
  38. ).filter(
  39. TagBinding.tag_id.in_(tag_ids),
  40. TagBinding.tenant_id == current_tenant_id
  41. ).all()
  42. if not tag_bindings:
  43. return []
  44. results = [tag_binding.target_id for tag_binding in tag_bindings]
  45. return results
  46. @staticmethod
  47. def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
  48. tags = db.session.query(Tag).join(
  49. TagBinding,
  50. Tag.id == TagBinding.tag_id
  51. ).filter(
  52. TagBinding.target_id == target_id,
  53. TagBinding.tenant_id == current_tenant_id,
  54. Tag.tenant_id == current_tenant_id,
  55. Tag.type == tag_type
  56. ).all()
  57. return tags if tags else []
  58. @staticmethod
  59. def save_tags(args: dict) -> Tag:
  60. tag = Tag(
  61. id=str(uuid.uuid4()),
  62. name=args['name'],
  63. type=args['type'],
  64. created_by=current_user.id,
  65. tenant_id=current_user.current_tenant_id
  66. )
  67. db.session.add(tag)
  68. db.session.commit()
  69. return tag
  70. @staticmethod
  71. def update_tags(args: dict, tag_id: str) -> Tag:
  72. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  73. if not tag:
  74. raise NotFound("Tag not found")
  75. tag.name = args['name']
  76. db.session.commit()
  77. return tag
  78. @staticmethod
  79. def get_tag_binding_count(tag_id: str) -> int:
  80. count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
  81. return count
  82. @staticmethod
  83. def delete_tag(tag_id: str):
  84. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  85. if not tag:
  86. raise NotFound("Tag not found")
  87. db.session.delete(tag)
  88. # delete tag binding
  89. tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
  90. if tag_bindings:
  91. for tag_binding in tag_bindings:
  92. db.session.delete(tag_binding)
  93. db.session.commit()
  94. @staticmethod
  95. def save_tag_binding(args):
  96. # check if target exists
  97. TagService.check_target_exists(args['type'], args['target_id'])
  98. # save tag binding
  99. for tag_id in args['tag_ids']:
  100. tag_binding = db.session.query(TagBinding).filter(
  101. TagBinding.tag_id == tag_id,
  102. TagBinding.target_id == args['target_id']
  103. ).first()
  104. if tag_binding:
  105. continue
  106. new_tag_binding = TagBinding(
  107. tag_id=tag_id,
  108. target_id=args['target_id'],
  109. tenant_id=current_user.current_tenant_id,
  110. created_by=current_user.id
  111. )
  112. db.session.add(new_tag_binding)
  113. db.session.commit()
  114. @staticmethod
  115. def delete_tag_binding(args):
  116. # check if target exists
  117. TagService.check_target_exists(args['type'], args['target_id'])
  118. # delete tag binding
  119. tag_bindings = db.session.query(TagBinding).filter(
  120. TagBinding.target_id == args['target_id'],
  121. TagBinding.tag_id == (args['tag_id'])
  122. ).first()
  123. if tag_bindings:
  124. db.session.delete(tag_bindings)
  125. db.session.commit()
  126. @staticmethod
  127. def check_target_exists(type: str, target_id: str):
  128. if type == 'knowledge':
  129. dataset = db.session.query(Dataset).filter(
  130. Dataset.tenant_id == current_user.current_tenant_id,
  131. Dataset.id == target_id
  132. ).first()
  133. if not dataset:
  134. raise NotFound("Dataset not found")
  135. elif type == 'app':
  136. app = db.session.query(App).filter(
  137. App.tenant_id == current_user.current_tenant_id,
  138. App.id == target_id
  139. ).first()
  140. if not app:
  141. raise NotFound("App not found")
  142. else:
  143. raise NotFound("Invalid binding type")