from typing import Literal

from flask import request, send_file
from flask_login import current_user  # type: ignore
from flask_restful import Resource, fields, marshal, marshal_with  # type: ignore
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.common.errors import FilenameNotExistsError
from controllers.console import api
from controllers.console.app.error import (
    ProviderModelCurrentlyNotSupportError,
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
)
from controllers.console.error import (
    FileTooLargeError,
    NoFileUploadedError,
    TooManyFilesError,
    UnsupportedFileTypeError,
)
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_resource_check,
    setup_required,
)
from core.errors.error import (
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from fields.template_fields import (
    template_fields,
)
from libs.login import login_required
from models import Template
from services.dataset_service import DatasetService, TemplateService
from services.file_service import FileService


class DatasetTemplateListApi(Resource):
    """List, upload and bulk-delete the templates that belong to one dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return one page of the dataset's templates in the current tenant.

        Query args:
            page: 1-based page number (default 1).
            limit: requested page size (default 20); ``paginate`` caps it at 100.
            keyword: optional substring matched against ``Template.name`` via LIKE.

        Raises:
            NotFound: the dataset does not exist.
            Forbidden: the current user has no permission on the dataset.
        """
        dataset_id = str(dataset_id)
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        search = request.args.get("keyword", default=None, type=str)

        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        query = Template.query.filter_by(dataset_id=dataset_id, tenant_id=current_user.current_tenant_id)
        if search:
            query = query.filter(Template.name.like(f"%{search}%"))

        paginated_templates = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
        templates = paginated_templates.items
        return {
            "data": marshal(templates, template_fields),
            # Ask the paginator directly. Comparing len(items) with `limit` is wrong
            # when `limit` exceeds max_per_page, and claims an extra page whenever
            # the final page happens to be exactly full.
            "has_more": paginated_templates.has_next,
            "limit": limit,
            "total": paginated_templates.total,
            "page": page,
        }

    templates_and_batch_fields = {"templates": fields.List(fields.Nested(template_fields)), "batch": fields.String}

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(templates_and_batch_fields)
    @cloud_edition_billing_resource_check("vector_space")
    def post(self, dataset_id):
        """Upload a single template file and attach it to the dataset.

        Expects a multipart form with exactly one file under ``file``.
        ``source`` is optional; only the value ``"datasets"`` is honoured.

        Returns:
            ``{"templates": [...], "batch": str}`` as produced by
            ``TemplateService.save_template_with_dataset_id``.

        Raises:
            Forbidden: the user is not an admin / dataset editor, or lacks
                permission on the dataset.
            NotFound: the dataset does not exist.
            NoFileUploadedError, TooManyFilesError, FilenameNotExistsError:
                malformed upload.
            FileTooLargeError, UnsupportedFileTypeError: rejected by FileService.
            ProviderNotInitializeError, ProviderQuotaExceededError,
            ProviderModelCurrentlyNotSupportError: model-provider failures
                while saving the template.
        """
        if not current_user.is_admin:
            raise Forbidden()
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
        # This check also covers the "datasets" source below, so no second check is needed.
        if not current_user.is_dataset_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # Validate the upload before indexing request.files["file"]. Indexing first
        # raised a bare 400 KeyError and made NoFileUploadedError unreachable.
        if "file" not in request.files:
            raise NoFileUploadedError()
        if len(request.files) > 1:
            raise TooManyFilesError()
        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError

        # Any value other than "datasets" is normalised to None.
        source_str = request.form.get("source")
        source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None

        try:
            upload_file = FileService.upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
                user=current_user,
                source=source,
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        try:
            templates, batch = TemplateService.save_template_with_dataset_id(upload_file, dataset, current_user)
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()

        return {"templates": templates, "batch": batch}

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        """Delete the templates listed in the repeated ``template_id`` query arg.

        Raises:
            Forbidden: the current user is not an admin.
            NotFound: the dataset does not exist.
        """
        if not current_user.is_admin:
            raise Forbidden()
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")
        DatasetService.check_dataset_model_setting(dataset)
        template_ids = request.args.getlist("template_id")
        TemplateService.delete_templates(dataset, template_ids)
        return {"result": "success"}, 204


class DatasetTemplateApi(Resource):
    """Download or delete a single template."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, template_id):
        """Send the template's stored file to the client as an attachment.

        Raises:
            Forbidden: the current user is not an editor.
            NotFound: no template with this id exists in the current tenant.
        """
        # The current user's role in the tenant-account ("ta") table must be
        # admin, owner or editor.
        if not current_user.is_editor:
            raise Forbidden()
        template_id = str(template_id)
        # Scope by tenant, as the list endpoint does, so one tenant cannot
        # download another tenant's template by guessing its id.
        template = Template.query.filter_by(id=template_id, tenant_id=current_user.current_tenant_id).first()
        if template is None:
            raise NotFound("Template not found.")
        # as_attachment=True makes the client download the file instead of rendering it inline.
        return send_file(template.file_url, as_attachment=True)

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, template_id):
        """Delete one template by id.

        Raises:
            Forbidden: the current user is not an admin.
            NotFound: ``TemplateService.get_templates`` found no template.
        """
        if not current_user.is_admin:
            raise Forbidden()
        template_id = str(template_id)
        template = TemplateService.get_templates(template_id)
        if template is None:
            raise NotFound("Template not found.")
        TemplateService.delete_template(template)
        return {"result": "success"}, 204


# The handlers receive the ids as URL variables; the `uuid` converter matches
# the str(...) conversions performed inside them.
api.add_resource(DatasetTemplateListApi, "/datasets/<uuid:dataset_id>/templates")
api.add_resource(DatasetTemplateApi, "/datasets/template/<uuid:template_id>")