|
@@ -1,10 +1,15 @@
|
|
|
+import json
|
|
|
import logging
|
|
|
+import os
|
|
|
+import zipfile
|
|
|
|
|
|
-from flask import request
|
|
|
+from flask import request, send_file
|
|
|
from flask_restful import Resource, marshal, marshal_with, reqparse
|
|
|
from werkzeug.exceptions import Forbidden, NotFound
|
|
|
|
|
|
+import services
|
|
|
from controllers.console import api
|
|
|
+from controllers.console.error import FileTooLargeError, UnsupportedFileTypeError
|
|
|
from controllers.console.wraps import account_initialization_required, setup_required
|
|
|
from fields.intention_fields import (
|
|
|
intention_corpus_detail_fields,
|
|
@@ -13,17 +18,28 @@ from fields.intention_fields import (
|
|
|
intention_keyword_detail_fields,
|
|
|
intention_keyword_fields,
|
|
|
intention_page_fields,
|
|
|
+ intention_train_file_binding_fields,
|
|
|
+ intention_train_file_fields,
|
|
|
+ intention_train_task_fields,
|
|
|
intention_type_detail_fields,
|
|
|
intention_type_page_fields,
|
|
|
)
|
|
|
-from libs.login import login_required
|
|
|
+from libs.login import current_user, login_required
|
|
|
+from models import UploadFile
|
|
|
+from models.intention import IntentionTrainTask
|
|
|
+from services.errors.intention import IntentionTrainFileDuplicateError
|
|
|
+from services.file_service import FileService
|
|
|
from services.intention_service import (
|
|
|
IntentionCorpusService,
|
|
|
IntentionCorpusSimilarityQuestionService,
|
|
|
IntentionKeywordService,
|
|
|
IntentionService,
|
|
|
+ IntentionTrainFileBindingService,
|
|
|
+ IntentionTrainFileService,
|
|
|
+ IntentionTrainTaskService,
|
|
|
IntentionTypeService,
|
|
|
)
|
|
|
+from services.upload_file_service import UploadFileService
|
|
|
|
|
|
|
|
|
class IntentionListApi(Resource):
|
|
@@ -493,6 +509,188 @@ class IntentionCorpusSimilarityQuestionBatchApi(Resource):
|
|
|
else:
|
|
|
raise NotFound(f"method with name {method} not found")
|
|
|
|
|
|
+class IntentionTrainTaskListApi(Resource):
|
|
|
+ @setup_required
|
|
|
+ @login_required
|
|
|
+ @account_initialization_required
|
|
|
+ def get(self):
|
|
|
+ page = request.args.get("page", default=1, type=int)
|
|
|
+ limit = request.args.get("limit", default=20, type=int)
|
|
|
+ search = request.args.get("search", default=None, type=str)
|
|
|
+ intention_train_tasks, total = IntentionTrainTaskService.get_page_intention_train_tasks(
|
|
|
+ page, limit, search)
|
|
|
+ data = marshal(intention_train_tasks, intention_train_task_fields)
|
|
|
+ response = {"data": data, "has_more": len(intention_train_tasks) == limit, "limit": limit,
|
|
|
+ "total": total, "page": page}
|
|
|
+ return response, 200
|
|
|
+
|
|
|
+ @setup_required
|
|
|
+ @login_required
|
|
|
+ @account_initialization_required
|
|
|
+ def post(self):
|
|
|
+ parser = reqparse.RequestParser()
|
|
|
+ parser.add_argument(
|
|
|
+ "name",
|
|
|
+ nullable=False,
|
|
|
+ required=True,
|
|
|
+ help="name is required.",
|
|
|
+ location="json",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "status",
|
|
|
+ nullable=False,
|
|
|
+ required=True,
|
|
|
+ help="status is required.",
|
|
|
+ choices=IntentionTrainTask.STATUS_LIST,
|
|
|
+ location="json",
|
|
|
+ )
|
|
|
+ args = parser.parse_args()
|
|
|
+ train_task = IntentionTrainTaskService.save_train_task(args)
|
|
|
+ return marshal(train_task, intention_train_task_fields), 200
|
|
|
+
|
|
|
+class IntentionTrainTaskApi(Resource):
|
|
|
+ @setup_required
|
|
|
+ @login_required
|
|
|
+ @account_initialization_required
|
|
|
+ def patch(self, task_id):
|
|
|
+ parser = reqparse.RequestParser()
|
|
|
+ parser.add_argument(
|
|
|
+ "name",
|
|
|
+ nullable=False,
|
|
|
+ required=True,
|
|
|
+ help="name is required.",
|
|
|
+ location="json",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "status",
|
|
|
+ nullable=False,
|
|
|
+ required=True,
|
|
|
+ help="status is required.",
|
|
|
+ choices=IntentionTrainTask.STATUS_LIST,
|
|
|
+ location="json",
|
|
|
+ )
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ train_task = IntentionTrainTaskService.update_train_task(task_id, args)
|
|
|
+ return marshal(train_task, intention_train_task_fields), 200
|
|
|
+
|
|
|
+class IntentionTrainTaskDownloadApi(Resource):
|
|
|
+ @setup_required
|
|
|
+ @login_required
|
|
|
+ @account_initialization_required
|
|
|
+ def get(self, task_id):
|
|
|
+ train_task = IntentionTrainTaskService.get_train_task(task_id)
|
|
|
+ if train_task.status != "COMPLETED":
|
|
|
+ raise Forbidden(f"Task with id {task_id} not completed")
|
|
|
+
|
|
|
+ dataset_info = train_task.dataset_info
|
|
|
+ dataset_source_info = json.loads(dataset_info.data_source_info)
|
|
|
+ dataset_file_id = dataset_source_info["upload_file_id"]
|
|
|
+ dataset_upload_file: UploadFile = UploadFileService.get_upload_file(dataset_file_id)
|
|
|
+
|
|
|
+ model_info = train_task.model_info
|
|
|
+ model_source_info = json.loads(model_info.data_source_info)
|
|
|
+ model_file_id = model_source_info["upload_file_id"]
|
|
|
+ model_upload_file: UploadFile = UploadFileService.get_upload_file(model_file_id)
|
|
|
+
|
|
|
+ def file2zip(zip_filename: str, upload_files: list[UploadFile]):
|
|
|
+ with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
|
|
|
+ for upload_file in upload_files:
|
|
|
+ filename = f"storage/{dataset_upload_file.key}"
|
|
|
+ zip_file.write(filename, arcname=upload_file.name)
|
|
|
+
|
|
|
+ # 生成待下载的zip包
|
|
|
+ zip_filename = f"{train_task.name}.zip"
|
|
|
+ upload_files: list[UploadFile] = [dataset_upload_file, model_upload_file]
|
|
|
+ file2zip(zip_filename, upload_files)
|
|
|
+
|
|
|
+ # 下载zip包
|
|
|
+ response = send_file(zip_filename, as_attachment=True, download_name=zip_filename)
|
|
|
+ os.remove(zip_filename)
|
|
|
+ return response
|
|
|
+
|
|
|
+class IntentionTrainFileApi(Resource):
|
|
|
+ @setup_required
|
|
|
+ @login_required
|
|
|
+ @account_initialization_required
|
|
|
+ def get(self):
|
|
|
+ name = request.args.get("name", default=None, type=str)
|
|
|
+ version = request.args.get("version", default=None, type=str)
|
|
|
+ type = request.args.get("type", default=None, type=str)
|
|
|
+ train_files = IntentionTrainFileService.get_train_files(name, version, type)
|
|
|
+ return marshal(train_files, intention_train_file_fields), 200
|
|
|
+
|
|
|
+ @setup_required
|
|
|
+ @login_required
|
|
|
+ @account_initialization_required
|
|
|
+ def post(self):
|
|
|
+ name = request.form.get("name")
|
|
|
+ version = request.form.get("version")
|
|
|
+ type = request.form.get("type")
|
|
|
+ train_file = IntentionTrainFileService.get_train_file(name, version, type)
|
|
|
+ if train_file:
|
|
|
+ raise IntentionTrainFileDuplicateError(f"IntentionTrainFile with name-version-type "
|
|
|
+ f"{name}-{version}-{type} already exists.")
|
|
|
+
|
|
|
+ data_source_type = request.form.get("data_source_type")
|
|
|
+
|
|
|
+ # get file from request
|
|
|
+ file = request.files["file"]
|
|
|
+ filename = file.filename
|
|
|
+ mimetype = file.mimetype
|
|
|
+ if not filename or not mimetype:
|
|
|
+ raise Forbidden("Invalid request.")
|
|
|
+
|
|
|
+ try:
|
|
|
+ upload_file = FileService.upload_file(
|
|
|
+ filename=filename,
|
|
|
+ content=file.read(),
|
|
|
+ mimetype=mimetype,
|
|
|
+ user=current_user,
|
|
|
+ source=None,
|
|
|
+ )
|
|
|
+
|
|
|
+ args = {
|
|
|
+ "name": name,
|
|
|
+ "version": version,
|
|
|
+ "type": type,
|
|
|
+ "data_source_type": data_source_type,
|
|
|
+ "data_source_info": {
|
|
|
+ "upload_file_id": upload_file.id
|
|
|
+ }
|
|
|
+ }
|
|
|
+ intention_train_file = IntentionTrainFileService.save_train_file(args)
|
|
|
+ return marshal(intention_train_file, intention_train_file_fields), 200
|
|
|
+ except services.errors.file.FileTooLargeError as file_too_large_error:
|
|
|
+ raise FileTooLargeError(file_too_large_error.description)
|
|
|
+ except services.errors.file.UnsupportedFileTypeError:
|
|
|
+ raise UnsupportedFileTypeError()
|
|
|
+
|
|
|
+class IntentionTrainFileBindingApi(Resource):
|
|
|
+ @setup_required
|
|
|
+ @login_required
|
|
|
+ @account_initialization_required
|
|
|
+ def post(self):
|
|
|
+ parser = reqparse.RequestParser()
|
|
|
+ parser.add_argument(
|
|
|
+ "file_id",
|
|
|
+ nullable=False,
|
|
|
+ required=True,
|
|
|
+ help="file_id is required.",
|
|
|
+ location="json",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "task_id",
|
|
|
+ nullable=False,
|
|
|
+ required=True,
|
|
|
+ help="task_id is required.",
|
|
|
+ location="json",
|
|
|
+ )
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ train_file_binding = IntentionTrainFileBindingService.save_train_file_binding(args)
|
|
|
+ return marshal(train_file_binding, intention_train_file_binding_fields), 200
|
|
|
+
|
|
|
api.add_resource(IntentionListApi, "/intentions")
|
|
|
api.add_resource(IntentionApi, "/intentions/<uuid:intention_id>")
|
|
|
api.add_resource(IntentionTypeListApi, "/intentions/types")
|
|
@@ -509,6 +707,13 @@ api.add_resource(IntentionCorpusSimilarityQuestionUpdateAndDeleteApi,
|
|
|
"/intentions/similarity_questions/<uuid:similarity_question_id>")
|
|
|
api.add_resource(IntentionCorpusSimilarityQuestionBatchApi, "/intentions/similarity_questions/batch")
|
|
|
|
|
|
+api.add_resource(IntentionTrainTaskListApi, "/intentions/train_tasks")
|
|
|
+api.add_resource(IntentionTrainTaskApi, "/intentions/train_tasks/<uuid:task_id>")
|
|
|
+api.add_resource(IntentionTrainTaskDownloadApi, "/intentions/train_tasks/download/<uuid:task_id>")
|
|
|
+
|
|
|
+api.add_resource(IntentionTrainFileApi, "/intentions/train_files")
|
|
|
+
|
|
|
+api.add_resource(IntentionTrainFileBindingApi, "/intentions/train_file_bindings")
|
|
|
|
|
|
|
|
|
|