tag_service.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import uuid
  2. from typing import Optional
  3. from flask_login import current_user # type: ignore
  4. from sqlalchemy import func
  5. from werkzeug.exceptions import NotFound
  6. from extensions.ext_database import db
  7. from models.account import TenantAccountRole
  8. from models.dataset import Dataset
  9. from models.model import App, AppPermissionAll, Tag, TagBinding
  10. from services.errors.account import NoPermissionError
  11. from services.errors.tag import TagNameDuplicateError
  12. class TagService:
  13. @staticmethod
  14. def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None,
  15. label_map: Optional[str] = None) -> list:
  16. query = (
  17. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  18. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  19. .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
  20. )
  21. if keyword:
  22. query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
  23. query = query.group_by(Tag.id, Tag.type, Tag.name)
  24. if TagService.str2bool(label_map):
  25. query = query.having(func.count(TagBinding.id) >= 1)
  26. results: list = query.order_by(Tag.created_at.desc()).all()
  27. return results
  28. @staticmethod
  29. def str2bool(v: Optional[str]) -> bool:
  30. return str(v).lower() in ("true", "1", "yes")
  31. @staticmethod
  32. def get_tag_by_tag_name(tag_type: str, tenant_id: str, tag_name: str) -> Optional[Tag]:
  33. tag: Optional[Tag] = (
  34. db.session.query(Tag).filter(Tag.type == tag_type, Tag.tenant_id == tenant_id, Tag.name == tag_name).first()
  35. )
  36. return tag
  37. @staticmethod
  38. def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
  39. tags = (
  40. db.session.query(Tag)
  41. .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
  42. .all()
  43. )
  44. if not tags:
  45. return []
  46. tag_ids = [tag.id for tag in tags]
  47. tag_bindings = (
  48. db.session.query(TagBinding.target_id)
  49. .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
  50. .all()
  51. )
  52. if not tag_bindings:
  53. return []
  54. results = [tag_binding.target_id for tag_binding in tag_bindings]
  55. return results
  56. @staticmethod
  57. def get_tags_count(tenant_id: str, keyword: Optional[str] = None) -> int:
  58. query = db.session.query(Tag).filter(Tag.type == "knowledge")
  59. if tenant_id:
  60. query = query.filter(Tag.tenant_id == tenant_id)
  61. if keyword:
  62. query = query.filter(Tag.name.ilike(f"%{keyword}%"))
  63. return query.count()
  64. @staticmethod
  65. def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
  66. tags = (
  67. db.session.query(Tag)
  68. .join(TagBinding, Tag.id == TagBinding.tag_id)
  69. .filter(
  70. TagBinding.target_id == target_id,
  71. TagBinding.tenant_id == current_tenant_id,
  72. Tag.tenant_id == current_tenant_id,
  73. Tag.type == tag_type,
  74. )
  75. .all()
  76. )
  77. return tags or []
  78. @staticmethod
  79. def get_tag(tag_id: str):
  80. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  81. if not tag:
  82. raise NotFound("Tag not found")
  83. return tag
  84. @staticmethod
  85. def get_page_tags(page, per_page, tag_type, tenant_id):
  86. query = (
  87. db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
  88. .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
  89. .filter(Tag.type == tag_type, Tag.tenant_id == tenant_id)
  90. )
  91. query = query.group_by(Tag.id, Tag.type, Tag.name)
  92. tags = query.paginate(page=page, per_page=per_page, error_out=False)
  93. return tags.items, tags.total
  94. @staticmethod
  95. def save_tags(args: dict) -> Tag:
  96. name = args["name"]
  97. type = args["type"]
  98. tenant_id = current_user.current_tenant_id
  99. tag = TagService.get_tag_by_tag_name(type, tenant_id, name)
  100. if tag:
  101. raise TagNameDuplicateError(f"Tag with name {name} already exists.")
  102. tag = Tag(
  103. id=str(uuid.uuid4()),
  104. name=name,
  105. type=type,
  106. created_by=current_user.id,
  107. tenant_id=tenant_id,
  108. )
  109. db.session.add(tag)
  110. db.session.commit()
  111. return tag
  112. @staticmethod
  113. def update_tags(args: dict, tag_id: str) -> Tag:
  114. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  115. if not tag:
  116. raise NotFound("Tag not found")
  117. tag.name = args["name"]
  118. db.session.commit()
  119. return tag
  120. @staticmethod
  121. def get_tag_binding_count(tag_id: str) -> int:
  122. count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
  123. return count
  124. @staticmethod
  125. def delete_tag(tag_id: str):
  126. tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
  127. if not tag:
  128. raise NotFound("Tag not found")
  129. db.session.delete(tag)
  130. # delete tag binding
  131. tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
  132. if tag_bindings:
  133. for tag_binding in tag_bindings:
  134. db.session.delete(tag_binding)
  135. db.session.commit()
  136. @staticmethod
  137. def save_tag_binding(args):
  138. # 1.智能体设置可见授权的编辑权限一致,2.知识库的标签都能设置--修改为随设置权限一致
  139. TagService.check_target_edit_auth(args["type"], args["target_id"])
  140. # check if target exists
  141. TagService.check_target_exists(args["type"], args["target_id"])
  142. # save tag binding
  143. for tag_id in args["tag_ids"]:
  144. tag_binding = (
  145. db.session.query(TagBinding)
  146. .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
  147. .first()
  148. )
  149. if tag_binding:
  150. continue
  151. new_tag_binding = TagBinding(
  152. tag_id=tag_id,
  153. target_id=args["target_id"],
  154. tenant_id=current_user.current_tenant_id,
  155. created_by=current_user.id,
  156. )
  157. db.session.add(new_tag_binding)
  158. db.session.commit()
  159. @staticmethod
  160. def delete_tag_binding(args):
  161. TagService.check_target_edit_auth(args["type"], args["target_id"])
  162. # check if target exists
  163. TagService.check_target_exists(args["type"], args["target_id"])
  164. # delete tag binding
  165. tag_bindings = (
  166. db.session.query(TagBinding)
  167. .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
  168. .first()
  169. )
  170. if tag_bindings:
  171. db.session.delete(tag_bindings)
  172. db.session.commit()
  173. @staticmethod
  174. def check_target_exists(type: str, target_id: str):
  175. if type in {"knowledge", "knowledge_category"}:
  176. dataset = (
  177. db.session.query(Dataset)
  178. .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
  179. .first()
  180. )
  181. if not dataset:
  182. raise NotFound("Dataset not found")
  183. elif type == "app":
  184. app = (
  185. db.session.query(App)
  186. .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
  187. .first()
  188. )
  189. if not app:
  190. raise NotFound("App not found")
  191. else:
  192. raise NotFound("Invalid binding type")
  193. @staticmethod
  194. def check_target_edit_auth(type: str, target_id: str):
  195. if type in {"knowledge", "knowledge_category"}:
  196. dataset = (
  197. db.session.query(Dataset)
  198. .filter(Dataset.id == target_id)
  199. .first()
  200. )
  201. if (
  202. current_user.current_role not in [TenantAccountRole.ADMIN, TenantAccountRole.OWNER]
  203. and dataset.created_by != current_user.id
  204. ):
  205. raise NoPermissionError("You do not have permission to operate this dataset.")
  206. elif type == "app":
  207. app = (
  208. db.session.query(AppPermissionAll)
  209. .filter(AppPermissionAll.has_read_permission == True,
  210. AppPermissionAll.account_id == current_user.id,
  211. AppPermissionAll.app_id == target_id)
  212. .first()
  213. )
  214. if not app:
  215. raise NoPermissionError("You do not have permission to operate this app.")
  216. else:
  217. raise NotFound("Invalid binding type")