123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- from typing import Literal
- from flask import request, send_file
- from flask_login import current_user # type: ignore
- from flask_restful import Resource, fields, marshal, marshal_with # type: ignore
- from werkzeug.exceptions import Forbidden, NotFound
- import services
- from controllers.common.errors import FilenameNotExistsError
- from controllers.console import api
- from controllers.console.app.error import (
- ProviderModelCurrentlyNotSupportError,
- ProviderNotInitializeError,
- ProviderQuotaExceededError,
- )
- from controllers.console.error import (
- FileTooLargeError,
- NoFileUploadedError,
- TooManyFilesError,
- UnsupportedFileTypeError,
- )
- from controllers.console.wraps import (
- account_initialization_required,
- cloud_edition_billing_resource_check,
- setup_required,
- )
- from core.errors.error import (
- ModelCurrentlyNotSupportError,
- ProviderTokenNotInitError,
- QuotaExceededError,
- )
- from fields.template_fields import (
- template_fields,
- )
- from libs.login import login_required
- from models import Template
- from services.dataset_service import DatasetService, TemplateService
- from services.file_service import FileService
class DatasetTemplateListApi(Resource):
    """Console endpoints for the collection of templates attached to a dataset.

    ``GET`` lists templates (paginated, with an optional keyword filter),
    ``POST`` uploads a file and saves it as template(s) for the dataset, and
    ``DELETE`` removes the templates named by the ``template_id`` query
    arguments.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return one page of the dataset's templates.

        Query arguments:
            page: page number, defaults to 1.
            limit: page size, defaults to 20 (the query caps it at 100).
            keyword: optional substring matched against the template name.

        Raises:
            NotFound: the dataset does not exist.
            Forbidden: the current user fails the dataset permission check.
        """
        dataset_id = str(dataset_id)
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        search = request.args.get("keyword", default=None, type=str)

        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        # Only templates of this dataset that belong to the caller's tenant.
        query = Template.query.filter_by(dataset_id=dataset_id, tenant_id=current_user.current_tenant_id)
        if search:
            search = f"%{search}%"
            query = query.filter(Template.name.like(search))

        paginated_templates = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
        templates = paginated_templates.items
        data = marshal(templates, template_fields)
        response = {
            "data": data,
            "has_more": len(templates) == limit,
            "limit": limit,
            "total": paginated_templates.total,
            "page": page,
        }
        return response

    # Response schema for POST; must be defined before the decorator below uses it.
    templates_and_batch_fields = {"templates": fields.List(fields.Nested(template_fields)), "batch": fields.String}

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(templates_and_batch_fields)
    @cloud_edition_billing_resource_check("vector_space")
    def post(self, dataset_id):
        """Upload a single file and save it as template(s) for the dataset.

        Expects a multipart form with exactly one file under the ``file`` key
        and an optional ``source`` field (only the value ``"datasets"`` is
        honoured; anything else is treated as no source).

        Returns:
            A dict with ``templates`` and ``batch``, marshalled with
            ``templates_and_batch_fields``.

        Raises:
            Forbidden: the user is not an admin, is not a dataset editor, or
                fails the dataset permission check.
            NotFound: the dataset does not exist.
            NoFileUploadedError: no ``file`` part was sent.
            TooManyFilesError: more than one file part was sent.
            FilenameNotExistsError: the uploaded file has no filename.
            FileTooLargeError / UnsupportedFileTypeError: rejected by the file service.
            ProviderNotInitializeError / ProviderQuotaExceededError /
            ProviderModelCurrentlyNotSupportError: translated from the
                corresponding core errors raised while saving the template.
        """
        if not current_user.is_admin:
            raise Forbidden()

        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")

        if not current_user.is_dataset_editor:
            raise Forbidden()

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        # Validate the upload BEFORE indexing request.files: looking up a
        # missing key would raise a generic 400 and the dedicated
        # NoFileUploadedError below would never be reached.
        if "file" not in request.files:
            raise NoFileUploadedError()
        if len(request.files) > 1:
            raise TooManyFilesError()

        file = request.files["file"]
        if not file.filename:
            raise FilenameNotExistsError()

        # Only "datasets" is a recognised source; every other value maps to None.
        # The dataset-editor requirement for this source is already enforced above.
        source_str = request.form.get("source")
        source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None

        try:
            upload_file = FileService.upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
                user=current_user,
                source=source,
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description) from file_too_large_error
        except services.errors.file.UnsupportedFileTypeError as e:
            raise UnsupportedFileTypeError() from e

        try:
            templates, batch = TemplateService.save_template_with_dataset_id(upload_file, dataset, current_user)
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as ex:
            raise ProviderQuotaExceededError() from ex
        except ModelCurrentlyNotSupportError as ex:
            raise ProviderModelCurrentlyNotSupportError() from ex

        return {"templates": templates, "batch": batch}

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        """Delete the templates listed in the ``template_id`` query arguments.

        Raises:
            Forbidden: the current user is not an admin.
            NotFound: the dataset does not exist.
        """
        if not current_user.is_admin:
            raise Forbidden()

        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")

        DatasetService.check_dataset_model_setting(dataset)

        # ``template_id`` may be repeated in the query string; collect all values.
        template_ids = request.args.getlist("template_id")
        TemplateService.delete_templates(dataset, template_ids)
        return {"result": "success"}, 204
class DatasetTemplateApi(Resource):
    """Console endpoints for a single template addressed by its id.

    ``GET`` sends the template's file as a download; ``DELETE`` removes the
    template.
    """

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, template_id):
        """Send the template's file to the client as an attachment.

        Raises:
            Forbidden: ``current_user.is_editor`` is false.
            NotFound: no template exists with the given id.
        """
        # The current user's role must be admin, owner, or editor
        # (presumably what ``is_editor`` encodes; verify against the account model).
        if not current_user.is_editor:
            raise Forbidden()

        template_id = str(template_id)
        # NOTE(review): this lookup is not scoped to the caller's tenant, unlike
        # the list endpoint; confirm whether cross-tenant access is intended.
        template = Template.query.filter_by(id=template_id).first()
        if template is None:
            # Without this guard an unknown id would fail with AttributeError
            # on ``template.file_url`` and surface as a 500.
            raise NotFound("Template not found.")

        # as_attachment=True makes the client download the file rather than display it.
        # NOTE(review): send_file expects a filesystem path or file object;
        # confirm ``file_url`` holds a local path.
        return send_file(template.file_url, as_attachment=True)

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, template_id):
        """Delete the template with the given id.

        Raises:
            Forbidden: the current user is not an admin.
            NotFound: no template exists with the given id.
        """
        if not current_user.is_admin:
            raise Forbidden()

        template_id = str(template_id)
        template = TemplateService.get_templates(template_id)
        if template is None:
            # The missing entity here is the template, not the dataset.
            raise NotFound("Template not found.")

        TemplateService.delete_template(template)
        return {"result": "success"}, 204
# Route registration: bind each resource class to its console API URL rule.
api.add_resource(
    DatasetTemplateListApi,
    "/datasets/<uuid:dataset_id>/templates",
)
api.add_resource(
    DatasetTemplateApi,
    "/datasets/template/<uuid:template_id>",
)
|