datasets_templates.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. from typing import Literal
  2. from flask import request, send_file
  3. from flask_login import current_user # type: ignore
  4. from flask_restful import Resource, fields, marshal, marshal_with # type: ignore
  5. from werkzeug.exceptions import Forbidden, NotFound
  6. import services
  7. from controllers.common.errors import FilenameNotExistsError
  8. from controllers.console import api
  9. from controllers.console.app.error import (
  10. ProviderModelCurrentlyNotSupportError,
  11. ProviderNotInitializeError,
  12. ProviderQuotaExceededError,
  13. )
  14. from controllers.console.error import (
  15. FileTooLargeError,
  16. NoFileUploadedError,
  17. TooManyFilesError,
  18. UnsupportedFileTypeError,
  19. )
  20. from controllers.console.wraps import (
  21. account_initialization_required,
  22. cloud_edition_billing_resource_check,
  23. setup_required,
  24. )
  25. from core.errors.error import (
  26. ModelCurrentlyNotSupportError,
  27. ProviderTokenNotInitError,
  28. QuotaExceededError,
  29. )
  30. from fields.template_fields import (
  31. template_fields,
  32. )
  33. from libs.login import login_required
  34. from models import Template
  35. from services.dataset_service import DatasetService, TemplateService
  36. from services.file_service import FileService
  37. class DatasetTemplateListApi(Resource):
  38. @setup_required
  39. @login_required
  40. @account_initialization_required
  41. def get(self, dataset_id):
  42. dataset_id = str(dataset_id)
  43. page = request.args.get("page", default=1, type=int)
  44. limit = request.args.get("limit", default=20, type=int)
  45. search = request.args.get("keyword", default=None, type=str)
  46. dataset = DatasetService.get_dataset(dataset_id)
  47. if not dataset:
  48. raise NotFound("Dataset not found.")
  49. try:
  50. DatasetService.check_dataset_permission(dataset, current_user)
  51. except services.errors.account.NoPermissionError as e:
  52. raise Forbidden(str(e))
  53. query = Template.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
  54. if search:
  55. search = f"%{search}%"
  56. query = query.filter(Template.name.like(search))
  57. paginated_templates = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
  58. templates = paginated_templates.items
  59. data = marshal(templates, template_fields)
  60. response = {
  61. "data": data,
  62. "has_more": len(templates) == limit,
  63. "limit": limit,
  64. "total": paginated_templates.total,
  65. "page": page,
  66. }
  67. return response
  68. templates_and_batch_fields = {"templates": fields.List(fields.Nested(template_fields)), "batch": fields.String}
  69. @setup_required
  70. @login_required
  71. @account_initialization_required
  72. @marshal_with(templates_and_batch_fields)
  73. @cloud_edition_billing_resource_check("vector_space")
  74. def post(self, dataset_id):
  75. if not current_user.is_admin:
  76. raise Forbidden()
  77. dataset_id = str(dataset_id)
  78. dataset = DatasetService.get_dataset(dataset_id)
  79. if not dataset:
  80. raise NotFound("Dataset not found.")
  81. if not current_user.is_dataset_editor:
  82. raise Forbidden()
  83. try:
  84. DatasetService.check_dataset_permission(dataset, current_user)
  85. except services.errors.account.NoPermissionError as e:
  86. raise Forbidden(str(e))
  87. file = request.files["file"]
  88. source_str = request.form.get("source")
  89. source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
  90. if "file" not in request.files:
  91. raise NoFileUploadedError()
  92. if len(request.files) > 1:
  93. raise TooManyFilesError()
  94. if not file.filename:
  95. raise FilenameNotExistsError
  96. if source == "datasets" and not current_user.is_dataset_editor:
  97. raise Forbidden()
  98. if source not in ("datasets", None):
  99. source = None
  100. try:
  101. upload_file = FileService.upload_file(
  102. filename=file.filename,
  103. content=file.read(),
  104. mimetype=file.mimetype,
  105. user=current_user,
  106. source=source,
  107. )
  108. except services.errors.file.FileTooLargeError as file_too_large_error:
  109. raise FileTooLargeError(file_too_large_error.description)
  110. except services.errors.file.UnsupportedFileTypeError:
  111. raise UnsupportedFileTypeError()
  112. try:
  113. templates, batch = TemplateService.save_template_with_dataset_id(upload_file, dataset, current_user)
  114. except ProviderTokenNotInitError as ex:
  115. raise ProviderNotInitializeError(ex.description)
  116. except QuotaExceededError:
  117. raise ProviderQuotaExceededError()
  118. except ModelCurrentlyNotSupportError:
  119. raise ProviderModelCurrentlyNotSupportError()
  120. return {"templates": templates, "batch": batch}
  121. @setup_required
  122. @login_required
  123. @account_initialization_required
  124. def delete(self, dataset_id):
  125. if not current_user.is_admin:
  126. raise Forbidden()
  127. dataset_id = str(dataset_id)
  128. dataset = DatasetService.get_dataset(dataset_id)
  129. if dataset is None:
  130. raise NotFound("Dataset not found.")
  131. DatasetService.check_dataset_model_setting(dataset)
  132. template_ids = request.args.getlist("template_id")
  133. TemplateService.delete_templates(dataset, template_ids)
  134. return {"result": "success"}, 204
  135. class DatasetTemplateApi(Resource):
  136. @setup_required
  137. @login_required
  138. @account_initialization_required
  139. def get(self, template_id):
  140. # The role of the current user in the ta table must be admin, owner, or editor
  141. if not current_user.is_editor:
  142. raise Forbidden()
  143. template_id = str(template_id)
  144. template = Template.query.filter_by(id=template_id).first()
  145. # as_attachment下载作为附件下载
  146. return send_file(template.file_url, as_attachment=True)
  147. @setup_required
  148. @login_required
  149. @account_initialization_required
  150. def delete(self, template_id):
  151. if not current_user.is_admin:
  152. raise Forbidden()
  153. template_id = str(template_id)
  154. template = TemplateService.get_templates(template_id)
  155. if template is None:
  156. raise NotFound("Dataset not found.")
  157. TemplateService.delete_template(template)
  158. return {"result": "success"}, 204
  159. api.add_resource(DatasetTemplateListApi, "/datasets/<uuid:dataset_id>/templates")
  160. api.add_resource(DatasetTemplateApi, "/datasets/template/<uuid:template_id>")