liangxunge 2 miesięcy temu
rodzic
commit
47e65995f7

+ 207 - 2
api/controllers/console/intention.py

@@ -1,10 +1,15 @@
+import json
 import logging
+import os
+import zipfile
 
-from flask import request
+from flask import request, send_file
 from flask_restful import Resource, marshal, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound
 
+import services
 from controllers.console import api
+from controllers.console.error import FileTooLargeError, UnsupportedFileTypeError
 from controllers.console.wraps import account_initialization_required, setup_required
 from fields.intention_fields import (
     intention_corpus_detail_fields,
@@ -13,17 +18,28 @@ from fields.intention_fields import (
     intention_keyword_detail_fields,
     intention_keyword_fields,
     intention_page_fields,
+    intention_train_file_binding_fields,
+    intention_train_file_fields,
+    intention_train_task_fields,
     intention_type_detail_fields,
     intention_type_page_fields,
 )
-from libs.login import login_required
+from libs.login import current_user, login_required
+from models import UploadFile
+from models.intention import IntentionTrainTask
+from services.errors.intention import IntentionTrainFileDuplicateError
+from services.file_service import FileService
 from services.intention_service import (
     IntentionCorpusService,
     IntentionCorpusSimilarityQuestionService,
     IntentionKeywordService,
     IntentionService,
+    IntentionTrainFileBindingService,
+    IntentionTrainFileService,
+    IntentionTrainTaskService,
     IntentionTypeService,
 )
+from services.upload_file_service import UploadFileService
 
 
 class IntentionListApi(Resource):
@@ -493,6 +509,188 @@ class IntentionCorpusSimilarityQuestionBatchApi(Resource):
         else:
             raise NotFound(f"method with name {method} not found")
 
+class IntentionTrainTaskListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        search = request.args.get("search", default=None, type=str)
+        intention_train_tasks, total = IntentionTrainTaskService.get_page_intention_train_tasks(
+            page, limit, search)
+        data = marshal(intention_train_tasks, intention_train_task_fields)
+        response = {"data": data, "has_more": len(intention_train_tasks) == limit, "limit": limit,
+                    "total": total, "page": page}
+        return response, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="name is required.",
+            location="json",
+        )
+        parser.add_argument(
+            "status",
+            nullable=False,
+            required=True,
+            help="status is required.",
+            choices=IntentionTrainTask.STATUS_LIST,
+            location="json",
+        )
+        args = parser.parse_args()
+        train_task = IntentionTrainTaskService.save_train_task(args)
+        return marshal(train_task, intention_train_task_fields), 200
+
+class IntentionTrainTaskApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, task_id):
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="name is required.",
+            location="json",
+        )
+        parser.add_argument(
+            "status",
+            nullable=False,
+            required=True,
+            help="status is required.",
+            choices=IntentionTrainTask.STATUS_LIST,
+            location="json",
+        )
+        args = parser.parse_args()
+
+        train_task = IntentionTrainTaskService.update_train_task(task_id, args)
+        return marshal(train_task, intention_train_task_fields), 200
+
+class IntentionTrainTaskDownloadApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, task_id):
+        train_task = IntentionTrainTaskService.get_train_task(task_id)
+        if train_task.status != "COMPLETED":
+            raise Forbidden(f"Task with id {task_id} not completed")
+
+        dataset_info = train_task.dataset_info
+        dataset_source_info = json.loads(dataset_info.data_source_info)
+        dataset_file_id = dataset_source_info["upload_file_id"]
+        dataset_upload_file: UploadFile = UploadFileService.get_upload_file(dataset_file_id)
+
+        model_info = train_task.model_info
+        model_source_info = json.loads(model_info.data_source_info)
+        model_file_id = model_source_info["upload_file_id"]
+        model_upload_file: UploadFile = UploadFileService.get_upload_file(model_file_id)
+
+        def file2zip(zip_filename: str, upload_files: list[UploadFile]):
+            with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
+                for upload_file in upload_files:
+                    filename = f"storage/{dataset_upload_file.key}"
+                    zip_file.write(filename, arcname=upload_file.name)
+
+        # 生成待下载的zip包
+        zip_filename = f"{train_task.name}.zip"
+        upload_files: list[UploadFile] = [dataset_upload_file, model_upload_file]
+        file2zip(zip_filename, upload_files)
+
+        # 下载zip包
+        response = send_file(zip_filename, as_attachment=True, download_name=zip_filename)
+        os.remove(zip_filename)
+        return response
+
+class IntentionTrainFileApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        name = request.args.get("name", default=None, type=str)
+        version = request.args.get("version", default=None, type=str)
+        type = request.args.get("type", default=None, type=str)
+        train_files = IntentionTrainFileService.get_train_files(name, version, type)
+        return marshal(train_files, intention_train_file_fields), 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        name = request.form.get("name")
+        version = request.form.get("version")
+        type = request.form.get("type")
+        train_file = IntentionTrainFileService.get_train_file(name, version, type)
+        if train_file:
+            raise IntentionTrainFileDuplicateError(f"IntentionTrainFile with name-version-type "
+                                                   f"{name}-{version}-{type} already exists.")
+
+        data_source_type = request.form.get("data_source_type")
+
+        # get file from request
+        file = request.files["file"]
+        filename = file.filename
+        mimetype = file.mimetype
+        if not filename or not mimetype:
+            raise Forbidden("Invalid request.")
+
+        try:
+            upload_file = FileService.upload_file(
+                filename=filename,
+                content=file.read(),
+                mimetype=mimetype,
+                user=current_user,
+                source=None,
+            )
+
+            args = {
+                "name": name,
+                "version": version,
+                "type": type,
+                "data_source_type": data_source_type,
+                "data_source_info": {
+                    "upload_file_id": upload_file.id
+                }
+            }
+            intention_train_file = IntentionTrainFileService.save_train_file(args)
+            return marshal(intention_train_file, intention_train_file_fields), 200
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+class IntentionTrainFileBindingApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "file_id",
+            nullable=False,
+            required=True,
+            help="file_id is required.",
+            location="json",
+        )
+        parser.add_argument(
+            "task_id",
+            nullable=False,
+            required=True,
+            help="task_id is required.",
+            location="json",
+        )
+        args = parser.parse_args()
+
+        train_file_binding = IntentionTrainFileBindingService.save_train_file_binding(args)
+        return marshal(train_file_binding, intention_train_file_binding_fields), 200
+
 api.add_resource(IntentionListApi, "/intentions")
 api.add_resource(IntentionApi, "/intentions/<uuid:intention_id>")
 api.add_resource(IntentionTypeListApi, "/intentions/types")
@@ -509,6 +707,13 @@ api.add_resource(IntentionCorpusSimilarityQuestionUpdateAndDeleteApi,
                  "/intentions/similarity_questions/<uuid:similarity_question_id>")
 api.add_resource(IntentionCorpusSimilarityQuestionBatchApi, "/intentions/similarity_questions/batch")
 
+api.add_resource(IntentionTrainTaskListApi, "/intentions/train_tasks")
+api.add_resource(IntentionTrainTaskApi, "/intentions/train_tasks/<uuid:task_id>")
+api.add_resource(IntentionTrainTaskDownloadApi, "/intentions/train_tasks/download/<uuid:task_id>")
+
+api.add_resource(IntentionTrainFileApi, "/intentions/train_files")
+
+api.add_resource(IntentionTrainFileBindingApi, "/intentions/train_file_bindings")
 
 
 

+ 26 - 0
api/fields/intention_fields.py

@@ -94,4 +94,30 @@ intention_detail_fields = {
     "updated_at": TimestampField,
 }
 
+intention_train_file_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "version": fields.String,
+    "type": fields.String,
+    "data_source_type": fields.String,
+    "data_source_info": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+}
+
+intention_train_file_binding_fields = {
+    "id": fields.String,
+    "file_id": fields.String,
+    "task_id": fields.String,
+}
+
+intention_train_task_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "status": fields.String,
+    "dataset_info": fields.Nested(intention_train_file_fields, allow_null=True),
+    "model_info": fields.Nested(intention_train_file_fields, allow_null=True),
+    "created_by": fields.String,
+    "created_at": TimestampField,
+}
 

+ 66 - 0
api/models/intention.py

@@ -172,5 +172,71 @@ class IntentionCorpusSimilarityQuestion(db.Model):
             .first()
         )
 
+class IntentionTrainTask(db.Model):
+    __tablename__ = "intention_train_tasks"
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='intention_train_log_pkey'),
+    )
+
+    STATUS_LIST = ["CREATED", "TRAINING", "COMPLETED"]
+
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+    name = db.Column(db.String(255), nullable=False)
+    status = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+    @property
+    def dataset_info(self):
+        dataset_info = (
+            db.session.query(IntentionTrainFile)
+            .join(IntentionTrainFileBinding, IntentionTrainFileBinding.file_id == IntentionTrainFile.id)
+            .filter(
+                IntentionTrainFile.type == "DATASET",
+                IntentionTrainFileBinding.task_id == self.id
+            )
+            .first()
+        )
+        return dataset_info
+
+    @property
+    def model_info(self):
+        model_info = (
+            db.session.query(IntentionTrainFile)
+            .join(IntentionTrainFileBinding, IntentionTrainFileBinding.file_id == IntentionTrainFile.id)
+            .filter(
+                IntentionTrainFile.type == "MODEL",
+                IntentionTrainFileBinding.task_id == self.id
+            )
+            .first()
+        )
+        return model_info
+
+class IntentionTrainFile(db.Model):
+    __tablename__ = "intention_train_files"
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='intention_train_file_pkey'),
+    )
+
+    TRAIN_FILE_TYPE_LIST = ["DATASET", "MODEL"]
+
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+    name = db.Column(db.String(255), nullable=False)
+    version = db.Column(db.String(255), nullable=False)
+    type = db.Column(db.String(255), nullable=False)
+    data_source_type = db.Column(db.String(255), nullable=False)
+    data_source_info = db.Column(db.JSON, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
+class IntentionTrainFileBinding(db.Model):
+    __tablename__ = "intention_train_file_bindings"
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='intention_train_file_binding_pkey'),
+        db.Index("intention_train_file_binding_idx", "file_id"),
+        db.Index("intention_train_file_binding_task_idx", "task_id"),
+    )
 
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+    file_id = db.Column(StringUUID, nullable=False)
+    task_id = db.Column(StringUUID, nullable=False)

+ 3 - 0
api/services/errors/intention.py

@@ -14,4 +14,7 @@ class IntentionCorpusQuestionDuplicateError(BaseServiceError):
     pass
 
 class IntentionCorpusSimilarityQuestionDuplicateError(BaseServiceError):
+    pass
+
+class IntentionTrainFileDuplicateError(BaseServiceError):
     pass

+ 120 - 1
api/services/intention_service.py

@@ -12,12 +12,16 @@ from models.intention import (
     IntentionCorpus,
     IntentionCorpusSimilarityQuestion,
     IntentionKeyword,
+    IntentionTrainFile,
+    IntentionTrainFileBinding,
+    IntentionTrainTask,
     IntentionType,
 )
 from services.errors.intention import (
     IntentionCorpusQuestionDuplicateError,
     IntentionKeywordNameDuplicateError,
     IntentionNameDuplicateError,
+    IntentionTrainFileDuplicateError,
     IntentionTypeNameDuplicateError,
 )
 
@@ -499,4 +503,119 @@ class IntentionCorpusSimilarityQuestionService:
 
         for similarity_question in similarity_questions:
             db.session.delete(similarity_question)
-        db.session.commit()
+        db.session.commit()
+
+class IntentionTrainTaskService:
+
+    @staticmethod
+    def get_page_intention_train_tasks(page, per_page, search=None):
+        query = (
+            IntentionTrainTask.query.order_by(IntentionTrainTask.created_at.desc())
+        )
+        if search:
+            query = query.filter(IntentionTrainTask.name.ilike(f"%{search}%"))
+
+        intention_train_tasks = query.paginate(page=page, per_page=per_page, error_out=False)
+        return intention_train_tasks.items, intention_train_tasks.total
+
+    @staticmethod
+    def get_train_task(train_task_id: str) -> Optional[IntentionTrainTask]:
+        train_task: Optional[IntentionTrainTask] = (
+            IntentionTrainTask.query.filter_by(id=train_task_id).first()
+        )
+        return train_task
+
+    @staticmethod
+    def save_train_task(args: dict):
+        train_task = IntentionTrainTask(
+            id=str(uuid.uuid4()),
+            name=args["name"],
+            status=args["status"],
+            created_by=current_user.id,
+            created_at=datetime.now(),
+        )
+        db.session.add(train_task)
+        db.session.commit()
+        return train_task
+
+    @staticmethod
+    def update_train_task(task_id: str, args: dict):
+        train_task = IntentionTrainTaskService.get_train_task(task_id)
+        if not train_task:
+            raise NotFound(f"IntentionTrainTask with id {task_id} not found")
+
+        if "name" in args:
+            train_task.name = args["name"]
+        if "status" in args:
+            train_task.status = args["status"]
+
+        db.session.add(train_task)
+        db.session.commit()
+        return train_task
+
+
+class IntentionTrainFileService:
+    @staticmethod
+    def get_train_file(name: str, version: str, type: str) -> Optional[IntentionTrainFile]:
+        train_file = (
+            IntentionTrainFile.query
+            .filter_by(
+                name=name,
+                version=version,
+                type=type,
+            )
+            .first()
+        )
+        return train_file
+
+    @staticmethod
+    def get_train_files(name=None, version=None, type=None):
+        query = IntentionTrainFile.query.order_by(IntentionTrainFile.created_at.desc())
+        if name:
+            query = query.filter_by(name=name)
+        if version:
+            query = query.filter_by(version=version)
+        if type:
+            query = query.filter_by(type=type)
+        train_files = query.all()
+        return train_files
+
+    @staticmethod
+    def save_train_file(args: dict):
+        name = args["name"]
+        version = args["version"]
+        type = args["type"]
+        train_file = IntentionTrainFileService.get_train_file(name, version, type)
+        if train_file:
+            raise IntentionTrainFileDuplicateError(f"IntentionTrainFile with name-version-type "
+                                                   f"{name}-{version}-{type} already exists.")
+
+        intention_train_file = IntentionTrainFile(
+            id=str(uuid.uuid4()),
+            name=name,
+            version=version,
+            type=type,
+            data_source_type=args["data_source_type"],
+            data_source_info=args["data_source_info"],
+            created_by=current_user.id,
+            created_at=datetime.now(),
+        )
+
+        db.session.add(intention_train_file)
+        db.session.commit()
+        return intention_train_file
+
+class IntentionTrainFileBindingService:
+    @staticmethod
+    def save_train_file_binding(args: dict):
+        file_id = args["file_id"]
+        task_id = args["task_id"]
+        train_file_binding = IntentionTrainFileBinding(
+            id=str(uuid.uuid4()),
+            file_id=file_id,
+            task_id=task_id,
+        )
+        db.session.add(train_file_binding)
+        db.session.commit()
+        return train_file_binding
+

+ 13 - 0
api/services/upload_file_service.py

@@ -0,0 +1,13 @@
+from typing import Optional
+
+from models import UploadFile, db
+
+
+class UploadFileService:
+    @staticmethod
+    def get_upload_file(upload_file_id) -> Optional[UploadFile]:
+        upload_file: Optional[UploadFile] = (
+            db.session.query(UploadFile).filter_by(id=upload_file_id).first()
+        )
+        return upload_file
+