Przeglądaj źródła

Merge branch '1.1.3-lxg' into 1.1.3-master

# Conflicts:
#	api/controllers/console/intention.py
#	api/fields/intention_fields.py
#	api/models/intention.py
#	api/services/errors/intention.py
#	api/services/intention_service.py
liangxunge 2 miesięcy temu
rodzic
commit
820192e1e7

+ 231 - 41
api/controllers/console/intention.py

@@ -1,10 +1,15 @@
+import json
 import logging
+import os
+import zipfile
 
-from flask import request
+from flask import request, send_file
 from flask_restful import Resource, marshal, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound
 
+import services
 from controllers.console import api
+from controllers.console.error import FileTooLargeError, UnsupportedFileTypeError
 from controllers.console.wraps import account_initialization_required, setup_required
 from fields.intention_fields import (
     intention_corpus_detail_fields,
@@ -13,20 +18,32 @@ from fields.intention_fields import (
     intention_keyword_detail_fields,
     intention_keyword_fields,
     intention_page_fields,
+    intention_train_file_binding_fields,
+    intention_train_file_fields,
+    intention_train_task_fields,
     intention_type_detail_fields,
     intention_type_page_fields,
 )
-from libs.login import login_required
+from libs.login import current_user, login_required
+from models import UploadFile
+from models.intention import IntentionTrainTask
+from services.errors.intention import IntentionTrainFileDuplicateError
+from services.file_service import FileService
 from services.intention_service import (
     IntentionCorpusService,
     IntentionCorpusSimilarityQuestionService,
     IntentionKeywordService,
     IntentionService,
+    IntentionTrainFileBindingService,
+    IntentionTrainFileService,
+    IntentionTrainTaskService,
     IntentionTypeService,
 )
+from services.upload_file_service import UploadFileService
 
 
 class IntentionListApi(Resource):
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -37,7 +54,8 @@ class IntentionListApi(Resource):
         name_search = request.args.get("name_search", default=None, type=str)
         intentions, total = IntentionService.get_intentions(page, limit, type_id, name_search)
         data = marshal(intentions, intention_page_fields)
-        response = {"data": data, "has_more": len(intentions) == limit, "limit": limit, "total": total, "page": page}
+        response = {"data": data, "has_more": len(intentions) == limit, "limit": limit,
+                    "total": total, "page": page}
         return response, 200
 
     @setup_required
@@ -62,7 +80,6 @@ class IntentionListApi(Resource):
         response = marshal(intention, intention_detail_fields)
         return response, 200
 
-
 class IntentionApi(Resource):
     @setup_required
     @login_required
@@ -100,8 +117,8 @@ class IntentionApi(Resource):
         IntentionService.delete_intention(intention_id)
         return 200
 
-
 class IntentionTypeListApi(Resource):
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -111,13 +128,8 @@ class IntentionTypeListApi(Resource):
         search = request.args.get("search", default=None, type=str)
         intention_types, total = IntentionTypeService.get_intention_types(page, limit, search)
         data = marshal(intention_types, intention_type_page_fields)
-        response = {
-            "data": data,
-            "has_more": len(intention_types) == limit,
-            "limit": limit,
-            "total": total,
-            "page": page,
-        }
+        response = {"data": data, "has_more": len(intention_types) == limit, "limit": limit,
+                    "total": total, "page": page}
         return response, 200
 
     @setup_required
@@ -136,7 +148,6 @@ class IntentionTypeListApi(Resource):
         response = marshal(intention_type, intention_type_detail_fields)
         return response, 200
 
-
 class IntentionTypeApi(Resource):
     @setup_required
     @login_required
@@ -167,7 +178,6 @@ class IntentionTypeApi(Resource):
         IntentionTypeService.delete_intention_type(intention_type_id)
         return 200
 
-
 class IntentionKeywordListApi(Resource):
     @setup_required
     @login_required
@@ -211,8 +221,8 @@ class IntentionKeywordListApi(Resource):
         IntentionKeywordService.delete_intention_keywords_by_intention_id(intention_id)
         return 200
 
-
 class IntentionKeywordApi(Resource):
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -251,8 +261,8 @@ class IntentionKeywordApi(Resource):
         IntentionKeywordService.delete_intention_keyword(intention_keyword_id)
         return 200
 
-
 class IntentionKeywordBatchApi(Resource):
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -286,7 +296,6 @@ class IntentionKeywordBatchApi(Resource):
         else:
             raise NotFound(f"method with name {method} not found")
 
-
 class IntentionCorpusListApi(Resource):
     @setup_required
     @login_required
@@ -297,16 +306,10 @@ class IntentionCorpusListApi(Resource):
         question_search = request.args.get("question_search", default=None, type=str)
         intention_id = request.args.get("intention_id", default=None, type=str)
         intention_corpus, total = IntentionCorpusService.get_page_intention_corpus(
-            page, limit, question_search, intention_id
-        )
+            page, limit, question_search, intention_id)
         data = marshal(intention_corpus, intention_corpus_detail_fields)
-        response = {
-            "data": data,
-            "has_more": len(intention_corpus) == limit,
-            "limit": limit,
-            "total": total,
-            "page": page,
-        }
+        response = {"data": data, "has_more": len(intention_corpus) == limit, "limit": limit,
+                    "total": total, "page": page}
         return response, 200
 
     @setup_required
@@ -336,7 +339,6 @@ class IntentionCorpusListApi(Resource):
         intention_corpus = IntentionCorpusService.save_intention_corpus(args)
         return marshal(intention_corpus, intention_corpus_detail_fields), 200
 
-
 class IntentionCorpusApi(Resource):
     @setup_required
     @login_required
@@ -391,7 +393,6 @@ class IntentionCorpusApi(Resource):
         IntentionCorpusService.delete_intention_corpus(intention_corpus)
         return 200
 
-
 class IntentionCorpusSimilarityQuestionApi(Resource):
     @setup_required
     @login_required
@@ -399,9 +400,8 @@ class IntentionCorpusSimilarityQuestionApi(Resource):
     def get(self, corpus_id):
         search = request.args.get("search", default=None, type=str)
         similarity_questions = (
-            IntentionCorpusSimilarityQuestionService.get_similarity_questions_by_corpus_id_like_question(
-                corpus_id, search
-            )
+            IntentionCorpusSimilarityQuestionService
+            .get_similarity_questions_by_corpus_id_like_question(corpus_id, search)
         )
         return marshal(similarity_questions, intention_corpus_similarity_question_fields), 200
 
@@ -424,8 +424,8 @@ class IntentionCorpusSimilarityQuestionApi(Resource):
             location="json",
         )
         args = parser.parse_args()
-        intention_corpus_similarity_question = IntentionCorpusSimilarityQuestionService.save_similarity_question(
-            corpus_id, args
+        intention_corpus_similarity_question = (
+            IntentionCorpusSimilarityQuestionService.save_similarity_question(corpus_id, args)
         )
         return marshal(intention_corpus_similarity_question, intention_corpus_similarity_question_fields), 200
 
@@ -436,7 +436,6 @@ class IntentionCorpusSimilarityQuestionApi(Resource):
         IntentionCorpusSimilarityQuestionService.delete_similarity_question_by_corpus_id(corpus_id)
         return 200
 
-
 class IntentionCorpusSimilarityQuestionUpdateAndDeleteApi(Resource):
     @setup_required
     @login_required
@@ -464,8 +463,8 @@ class IntentionCorpusSimilarityQuestionUpdateAndDeleteApi(Resource):
         )
         args = parser.parse_args()
 
-        similarity_question = IntentionCorpusSimilarityQuestionService.update_similarity_question(
-            similarity_question_id, args
+        similarity_question = (
+            IntentionCorpusSimilarityQuestionService.update_similarity_question(similarity_question_id, args)
         )
         return marshal(similarity_question, intention_corpus_similarity_question_fields), 200
 
@@ -476,7 +475,6 @@ class IntentionCorpusSimilarityQuestionUpdateAndDeleteApi(Resource):
         IntentionCorpusSimilarityQuestionService.delete_similarity_question_by_id(similarity_question_id)
         return 200
 
-
 class IntentionCorpusSimilarityQuestionBatchApi(Resource):
     @setup_required
     @login_required
@@ -511,6 +509,187 @@ class IntentionCorpusSimilarityQuestionBatchApi(Resource):
         else:
             raise NotFound(f"method with name {method} not found")
 
+class IntentionTrainTaskListApi(Resource):
+    """Collection endpoint for intention train tasks: paginated listing and creation."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        """Return one page of train tasks.
+
+        Query parameters: ``page`` (default 1), ``limit`` (default 20) and an
+        optional ``search`` string that is forwarded to the service unchanged.
+        """
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        search = request.args.get("search", default=None, type=str)
+        intention_train_tasks, total = IntentionTrainTaskService.get_page_intention_train_tasks(
+            page, limit, search)
+        data = marshal(intention_train_tasks, intention_train_task_fields)
+        # has_more is reported as true whenever the returned page is exactly full.
+        response = {"data": data, "has_more": len(intention_train_tasks) == limit, "limit": limit,
+                    "total": total, "page": page}
+        return response, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        """Create a train task from a JSON body with required ``name`` and ``status``.
+
+        ``status`` must be one of ``IntentionTrainTask.STATUS_LIST``.
+        """
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="name is required.",
+            location="json",
+        )
+        parser.add_argument(
+            "status",
+            nullable=False,
+            required=True,
+            help="status is required.",
+            choices=IntentionTrainTask.STATUS_LIST,
+            location="json",
+        )
+        args = parser.parse_args()
+        train_task = IntentionTrainTaskService.save_train_task(args)
+        return marshal(train_task, intention_train_task_fields), 200
+
+class IntentionTrainTaskApi(Resource):
+    """Item endpoint for a single intention train task."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, task_id):
+        """Update the train task identified by ``task_id``.
+
+        The JSON body must carry both ``name`` and ``status``; ``status`` must be
+        one of ``IntentionTrainTask.STATUS_LIST``.
+        """
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "name",
+            nullable=False,
+            required=True,
+            help="name is required.",
+            location="json",
+        )
+        parser.add_argument(
+            "status",
+            nullable=False,
+            required=True,
+            help="status is required.",
+            choices=IntentionTrainTask.STATUS_LIST,
+            location="json",
+        )
+        args = parser.parse_args()
+
+        train_task = IntentionTrainTaskService.update_train_task(task_id, args)
+        return marshal(train_task, intention_train_task_fields), 200
+
+class IntentionTrainTaskDownloadApi(Resource):
+    """Serve the dataset and model files of a completed train task as one zip archive."""
+
+    @staticmethod
+    def _resolve_upload_file(train_file, label, task_id) -> UploadFile:
+        """Return the UploadFile referenced by ``train_file.data_source_info``.
+
+        Raises:
+            NotFound: if the task has no bound file of this kind, or the
+                referenced upload file cannot be found.
+        """
+        if train_file is None:
+            raise NotFound(f"Task with id {task_id} has no {label} file bound")
+
+        source_info = train_file.data_source_info
+        # The value may arrive either as a JSON string or as an already decoded
+        # mapping; only strings need parsing.
+        if isinstance(source_info, (str, bytes, bytearray)):
+            source_info = json.loads(source_info)
+
+        upload_file = UploadFileService.get_upload_file(source_info["upload_file_id"])
+        if upload_file is None:
+            raise NotFound(f"Upload file for {label} of task {task_id} not found")
+        return upload_file
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, task_id):
+        """Build and send ``<task name>.zip`` containing the task's dataset and model.
+
+        Raises:
+            Forbidden: if the task status is not ``COMPLETED``.
+            NotFound: if the dataset or model file of the task cannot be resolved.
+        """
+        train_task = IntentionTrainTaskService.get_train_task(task_id)
+        if train_task.status != "COMPLETED":
+            raise Forbidden(f"Task with id {task_id} not completed")
+
+        upload_files: list[UploadFile] = [
+            self._resolve_upload_file(train_task.dataset_info, "dataset", task_id),
+            self._resolve_upload_file(train_task.model_info, "model", task_id),
+        ]
+
+        # NOTE(review): the archive name is derived from the task name; confirm task
+        # names cannot contain path separators.
+        download_name = f"{train_task.name}.zip"
+        # Use one absolute path for both writing and sending so the two calls cannot
+        # disagree on the base directory of a relative path.
+        zip_path = os.path.abspath(download_name)
+
+        try:
+            # Build the zip archive to download.
+            with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
+                for upload_file in upload_files:
+                    # Each entry must be read from its own storage key; using the
+                    # dataset key for every entry would duplicate the dataset.
+                    zip_file.write(f"storage/{upload_file.key}", arcname=upload_file.name)
+
+            # Send the zip archive.
+            response = send_file(zip_path, as_attachment=True, download_name=download_name)
+        finally:
+            # Remove the temporary archive even when building or sending failed.
+            if os.path.exists(zip_path):
+                os.remove(zip_path)
+        return response
+
+class IntentionTrainFileApi(Resource):
+    """Collection endpoint for intention train files: filtered listing and upload."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        """Return the train files matching the optional ``name``/``version``/``type`` filters."""
+        name = request.args.get("name", default=None, type=str)
+        version = request.args.get("version", default=None, type=str)
+        # Bound to file_type so the builtin ``type`` is not shadowed.
+        file_type = request.args.get("type", default=None, type=str)
+        train_files = IntentionTrainFileService.get_train_files(name, version, file_type)
+        return marshal(train_files, intention_train_file_fields), 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        """Upload a file and register it as a train file.
+
+        Reads ``name``, ``version``, ``type`` and ``data_source_type`` from the
+        multipart form and the payload from the ``file`` part.
+
+        Raises:
+            IntentionTrainFileDuplicateError: a train file with the same
+                name, version and type already exists.
+            Forbidden: the uploaded part has no filename or mimetype.
+            FileTooLargeError / UnsupportedFileTypeError: the upload was
+                rejected by the file service.
+        """
+        name = request.form.get("name")
+        version = request.form.get("version")
+        # Bound to file_type so the builtin ``type`` is not shadowed.
+        file_type = request.form.get("type")
+        train_file = IntentionTrainFileService.get_train_file(name, version, file_type)
+        if train_file:
+            raise IntentionTrainFileDuplicateError(f"IntentionTrainFile with name-version-type "
+                                                   f"{name}-{version}-{file_type} already exists.")
+
+        data_source_type = request.form.get("data_source_type")
+
+        # get file from request
+        file = request.files["file"]
+        filename = file.filename
+        mimetype = file.mimetype
+        if not filename or not mimetype:
+            raise Forbidden("Invalid request.")
+
+        try:
+            upload_file = FileService.upload_file(
+                filename=filename,
+                content=file.read(),
+                mimetype=mimetype,
+                user=current_user,
+                source=None,
+            )
+
+            args = {
+                "name": name,
+                "version": version,
+                "type": file_type,
+                "data_source_type": data_source_type,
+                "data_source_info": {
+                    "upload_file_id": upload_file.id
+                }
+            }
+            intention_train_file = IntentionTrainFileService.save_train_file(args)
+            return marshal(intention_train_file, intention_train_file_fields), 200
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            # Translate the service error into its HTTP counterpart, keeping the cause.
+            raise FileTooLargeError(file_too_large_error.description) from file_too_large_error
+        except services.errors.file.UnsupportedFileTypeError as unsupported_error:
+            raise UnsupportedFileTypeError() from unsupported_error
+
+class IntentionTrainFileBindingApi(Resource):
+    """Endpoint that binds a train file to a train task."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        """Create a binding from a JSON body with required ``file_id`` and ``task_id``."""
+        parser = reqparse.RequestParser()
+        parser.add_argument(
+            "file_id",
+            nullable=False,
+            required=True,
+            help="file_id is required.",
+            location="json",
+        )
+        parser.add_argument(
+            "task_id",
+            nullable=False,
+            required=True,
+            help="task_id is required.",
+            location="json",
+        )
+        args = parser.parse_args()
+
+        train_file_binding = IntentionTrainFileBindingService.save_train_file_binding(args)
+        return marshal(train_file_binding, intention_train_file_binding_fields), 200
 
 api.add_resource(IntentionListApi, "/intentions")
 api.add_resource(IntentionApi, "/intentions/<uuid:intention_id>")
@@ -524,8 +703,19 @@ api.add_resource(IntentionCorpusListApi, "/intentions/corpus")
 api.add_resource(IntentionCorpusApi, "/intentions/corpus/<uuid:corpus_id>")
 api.add_resource(IntentionCorpusSimilarityQuestionApi, "/intentions/corpus/<uuid:corpus_id>/similarity_questions")
 
-api.add_resource(
-    IntentionCorpusSimilarityQuestionUpdateAndDeleteApi,
-    "/intentions/similarity_questions/<uuid:similarity_question_id>",
-)
+api.add_resource(IntentionCorpusSimilarityQuestionUpdateAndDeleteApi,
+                 "/intentions/similarity_questions/<uuid:similarity_question_id>")
 api.add_resource(IntentionCorpusSimilarityQuestionBatchApi, "/intentions/similarity_questions/batch")
+
+api.add_resource(IntentionTrainTaskListApi, "/intentions/train_tasks")
+api.add_resource(IntentionTrainTaskApi, "/intentions/train_tasks/<uuid:task_id>")
+api.add_resource(IntentionTrainTaskDownloadApi, "/intentions/train_tasks/download/<uuid:task_id>")
+
+api.add_resource(IntentionTrainFileApi, "/intentions/train_files")
+
+api.add_resource(IntentionTrainFileBindingApi, "/intentions/train_file_bindings")
+
+
+
+
+

+ 29 - 1
api/fields/intention_fields.py

@@ -40,7 +40,7 @@ intention_keyword_fields = {
 intention_keyword_detail_fields = {
     "id": fields.String,
     "name": fields.String,
-    "intention": fields.Nested(intention_fields),
+    "intention": fields.Nested(intention_fields)
 }
 
 intention_corpus_fields = {
@@ -93,3 +93,31 @@ intention_detail_fields = {
     "updated_by": fields.String,
     "updated_at": TimestampField,
 }
+
+# Marshal schema for a single train file.
+intention_train_file_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "version": fields.String,
+    "type": fields.String,
+    "data_source_type": fields.String,
+    # NOTE(review): marshalled as a string; confirm this matches how the
+    # data_source_info value is stored and what clients expect.
+    "data_source_info": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+}
+
+# Marshal schema for a file-to-task binding.
+intention_train_file_binding_fields = {
+    "id": fields.String,
+    "file_id": fields.String,
+    "task_id": fields.String,
+}
+
+# Marshal schema for a train task; dataset_info and model_info are nested
+# train-file objects and may be null.
+intention_train_task_fields = {
+    "id": fields.String,
+    "name": fields.String,
+    "status": fields.String,
+    "dataset_info": fields.Nested(intention_train_file_fields, allow_null=True),
+    "model_info": fields.Nested(intention_train_file_fields, allow_null=True),
+    "created_by": fields.String,
+    "created_at": TimestampField,
+}
+

+ 130 - 21
api/models/intention.py

@@ -5,8 +5,10 @@ from .types import StringUUID
 
 
 class IntentionType(db.Model):
-    __tablename__ = "intention_types"
-    __table_args__ = (db.PrimaryKeyConstraint("id", name="intention_type_id_pkey"),)
+    __tablename__ = 'intention_types'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='intention_type_id_pkey'),
+    )
 
     id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
     name = db.Column(db.String(255), nullable=False)
@@ -17,12 +19,19 @@ class IntentionType(db.Model):
 
     @property
     def intention_count(self):
-        return db.session.query(func.count(Intention.id)).filter(Intention.type_id == self.id).scalar()
+        return (
+            db.session.query(func.count(Intention.id))
+            .filter(Intention.type_id==self.id)
+            .scalar()
+        )
 
     @property
     def intentions(self):
-        return db.session.query(Intention).filter(Intention.type_id == self.id).all()
-
+        return (
+            db.session.query(Intention)
+            .filter(Intention.type_id == self.id)
+            .all()
+        )
 
 class Intention(db.Model):
     __tablename__ = "intentions"
@@ -41,34 +50,57 @@ class Intention(db.Model):
 
     @property
     def type_name(self):
-        return db.session.query(IntentionType.name).filter(IntentionType.id == self.type_id).first().name
+        return (
+            db.session.query(IntentionType.name)
+            .filter(IntentionType.id==self.type_id)
+            .first().name
+        )
 
     @property
     def type(self):
-        return db.session.query(IntentionType).filter(IntentionType.id == self.type_id).first()
+        return (
+            db.session.query(IntentionType)
+            .filter(IntentionType.id==self.type_id)
+            .first()
+        )
 
     @property
     def corpus(self):
-        return db.session.query(IntentionCorpus).filter(IntentionCorpus.intention_id == self.id).all()
+        return (
+            db.session.query(IntentionCorpus)
+            .filter(IntentionCorpus.intention_id==self.id)
+            .all()
+        )
 
     @property
     def keywords(self):
-        return db.session.query(IntentionKeyword).filter(IntentionKeyword.intention_id == self.id).all()
+        return (
+            db.session.query(IntentionKeyword)
+            .filter(IntentionKeyword.intention_id==self.id)
+            .all()
+        )
 
     @property
     def corpus_count(self):
-        return db.session.query(func.count(IntentionCorpus.id)).filter(IntentionCorpus.intention_id == self.id).scalar()
+        return (
+            db.session.query(func.count(IntentionCorpus.id))
+            .filter(IntentionCorpus.intention_id == self.id)
+            .scalar()
+        )
 
     @property
     def keywords_count(self):
         return (
-            db.session.query(func.count(IntentionKeyword.id)).filter(IntentionKeyword.intention_id == self.id).scalar()
+            db.session.query(func.count(IntentionKeyword.id))
+            .filter(IntentionKeyword.intention_id == self.id)
+            .scalar()
         )
 
-
 class IntentionKeyword(db.Model):
     __tablename__ = "intention_keywords"
-    __table_args__ = (db.PrimaryKeyConstraint("id", name="intention_keyword_pkey"),)
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='intention_keyword_pkey'),
+    )
 
     id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
     name = db.Column(db.String(255), nullable=False)
@@ -80,13 +112,14 @@ class IntentionKeyword(db.Model):
 
     @property
     def intention(self):
-        return db.session.query(Intention).filter(Intention.id == self.intention_id).first()
-
+        return (
+            db.session.query(Intention).filter(Intention.id==self.intention_id).first()
+        )
 
 class IntentionCorpus(db.Model):
     __tablename__ = "intention_corpus"
     __table_args__ = (
-        db.PrimaryKeyConstraint("id", name="intention_corpus_pkey"),
+        db.PrimaryKeyConstraint('id', name='intention_corpus_pkey'),
         db.Index("intention_corpus_idx", "intention_id"),
     )
 
@@ -101,21 +134,24 @@ class IntentionCorpus(db.Model):
 
     @property
     def intention(self):
-        return db.session.query(Intention).filter(Intention.id == self.intention_id).first()
+        return (
+            db.session.query(Intention)
+            .filter(Intention.id==self.intention_id)
+            .first()
+        )
 
     @property
     def similarity_questions(self):
         return (
             db.session.query(IntentionCorpusSimilarityQuestion)
-            .filter(IntentionCorpusSimilarityQuestion.corpus_id == self.id)
+            .filter(IntentionCorpusSimilarityQuestion.corpus_id==self.id)
             .all()
         )
 
-
 class IntentionCorpusSimilarityQuestion(db.Model):
     __tablename__ = "intention_corpus_similarity_questions"
     __table_args__ = (
-        db.PrimaryKeyConstraint("id", name="intention_corpus_similarity_question_pkey"),
+        db.PrimaryKeyConstraint('id', name='intention_corpus_similarity_question_pkey'),
         db.Index("intention_corpus_similarity_question_idx", "corpus_id"),
     )
 
@@ -130,4 +166,77 @@ class IntentionCorpusSimilarityQuestion(db.Model):
 
     @property
     def corpus(self):
-        return db.session.query(IntentionCorpus).filter(IntentionCorpus.id == self.corpus_id).first()
+        return (
+            db.session.query(IntentionCorpus)
+            .filter(IntentionCorpus.id==self.corpus_id)
+            .first()
+        )
+
+class IntentionTrainTask(db.Model):
+    """Train task row; its dataset and model files are resolved through file bindings."""
+
+    __tablename__ = "intention_train_tasks"
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='intention_train_log_pkey'),
+    )
+
+    STATUS_LIST = ["CREATED", "TRAINING", "COMPLETED"]
+
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+    name = db.Column(db.String(255), nullable=False)
+    status = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+    def _first_bound_file(self, file_type):
+        """Return the first train file of ``file_type`` bound to this task, or None."""
+        on_binding = IntentionTrainFileBinding.file_id == IntentionTrainFile.id
+        bound_files = db.session.query(IntentionTrainFile).join(IntentionTrainFileBinding, on_binding)
+        matching = bound_files.filter(
+            IntentionTrainFile.type == file_type,
+            IntentionTrainFileBinding.task_id == self.id,
+        )
+        return matching.first()
+
+    @property
+    def dataset_info(self):
+        """First bound train file whose type is "DATASET"."""
+        return self._first_bound_file("DATASET")
+
+    @property
+    def model_info(self):
+        """First bound train file whose type is "MODEL"."""
+        return self._first_bound_file("MODEL")
+
+class IntentionTrainFile(db.Model):
+    """Train file row, described by name, version and type."""
+
+    __tablename__ = "intention_train_files"
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='intention_train_file_pkey'),
+    )
+
+    # Values that the ``type`` column is compared against elsewhere in this module.
+    TRAIN_FILE_TYPE_LIST = ["DATASET", "MODEL"]
+
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+    name = db.Column(db.String(255), nullable=False)
+    version = db.Column(db.String(255), nullable=False)
+    type = db.Column(db.String(255), nullable=False)
+    data_source_type = db.Column(db.String(255), nullable=False)
+    # The upload endpoint passes {"upload_file_id": ...} for this field.
+    data_source_info = db.Column(db.JSON, nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+class IntentionTrainFileBinding(db.Model):
+    """Association row linking one train file (``file_id``) to one train task (``task_id``)."""
+
+    __tablename__ = "intention_train_file_bindings"
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='intention_train_file_binding_pkey'),
+        # Separate indexes on each side of the association.
+        db.Index("intention_train_file_binding_idx", "file_id"),
+        db.Index("intention_train_file_binding_task_idx", "task_id"),
+    )
+
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+    file_id = db.Column(StringUUID, nullable=False)
+    task_id = db.Column(StringUUID, nullable=False)

+ 3 - 4
api/services/errors/intention.py

@@ -4,18 +4,17 @@ from services.errors.base import BaseServiceError
 class IntentionNameDuplicateError(BaseServiceError):
     pass
 
-
 class IntentionTypeNameDuplicateError(BaseServiceError):
     pass
 
-
 class IntentionKeywordNameDuplicateError(BaseServiceError):
     pass
 
-
 class IntentionCorpusQuestionDuplicateError(BaseServiceError):
     pass
 
-
 class IntentionCorpusSimilarityQuestionDuplicateError(BaseServiceError):
     pass
+
+class IntentionTrainFileDuplicateError(BaseServiceError):
+    """Raised when a train file with the same name, version and type already exists."""
+    pass

+ 179 - 46
api/services/intention_service.py

@@ -12,12 +12,16 @@ from models.intention import (
     IntentionCorpus,
     IntentionCorpusSimilarityQuestion,
     IntentionKeyword,
+    IntentionTrainFile,
+    IntentionTrainFileBinding,
+    IntentionTrainTask,
     IntentionType,
 )
 from services.errors.intention import (
     IntentionCorpusQuestionDuplicateError,
     IntentionKeywordNameDuplicateError,
     IntentionNameDuplicateError,
+    IntentionTrainFileDuplicateError,
     IntentionTypeNameDuplicateError,
 )
 
@@ -33,7 +37,7 @@ class IntentionTypeService:
 
     @staticmethod
     def save_intention_type(args: dict) -> IntentionType:
-        name = args["name"]
+        name = args['name']
         intention_type = IntentionTypeService.get_intention_type_by_name(name)
         if intention_type:
             raise IntentionTypeNameDuplicateError(f"IntentionType with name {name} already exists.")
@@ -53,11 +57,13 @@ class IntentionTypeService:
         if not intention_type:
             raise NotFound("IntentionType not found")
 
-        name = args["name"]
-        intention_type_new = IntentionType.query.filter(
-            IntentionType.id != intention_type.id,
-            IntentionType.name == name,
-        ).first()
+        name = args['name']
+        intention_type_new = (
+            IntentionType.query.filter(
+                IntentionType.id != intention_type.id,
+                IntentionType.name == name,
+            ).first()
+        )
         if intention_type_new:
             raise IntentionTypeNameDuplicateError(f"IntentionType with name {name} already exists.")
         intention_type.name = name
@@ -87,7 +93,6 @@ class IntentionTypeService:
         intention_type: Optional[IntentionType] = IntentionType.query.filter_by(name=name).first()
         return intention_type
 
-
 class IntentionService:
     @staticmethod
     def get_intentions(page, per_page, type_id=None, name_search=None):
@@ -132,7 +137,12 @@ class IntentionService:
             raise NotFound("Intention not found")
 
         name = args["name"]
-        intention_new = Intention.query.filter(Intention.id != intention.id, Intention.name == name).first()
+        intention_new = (
+            Intention.query.filter(
+                Intention.id != intention.id,
+                Intention.name == name
+            ).first()
+        )
         if intention_new:
             raise IntentionNameDuplicateError(f"Intention with name {name} already exists.")
 
@@ -157,15 +167,13 @@ class IntentionService:
 
         # 1.若存在关键词,则无法删除
         if intention.keywords_count > 0:
-            raise Forbidden(
-                f"You are not allowed to delete intention, because {intention.keywords_count} keywords were found."
-            )
+            raise Forbidden(f"You are not allowed to delete intention, "
+                            f"because {intention.keywords_count} keywords were found.")
 
         # 2.若关联训练预料,则无法删除
         if intention.corpus_count > 0:
-            raise Forbidden(
-                f"You are not allowed to delete intention, because {intention.corpus_count} corpus were found."
-            )
+            raise Forbidden(f"You are not allowed to delete intention, "
+                            f"because {intention.corpus_count} corpus were found.")
 
         db.session.delete(intention)
         db.session.commit()
@@ -181,7 +189,6 @@ class IntentionService:
         intention: Optional[Intention] = Intention.query.filter(Intention.name == name).first()
         return intention
 
-
 class IntentionKeywordService:
     @staticmethod
     def get_intention_keywords(intention_id: str, search=None):
@@ -215,9 +222,12 @@ class IntentionKeywordService:
             raise NotFound("IntentionKeyword not found")
 
         name = args["name"]
-        intention_keyword_new = IntentionKeyword.query.filter(
-            IntentionKeyword.id != intention_keyword.id, IntentionKeyword.name == name
-        ).first()
+        intention_keyword_new = (
+            IntentionKeyword.query.filter(
+                IntentionKeyword.id != intention_keyword.id,
+                IntentionKeyword.name == name)
+            .first()
+        )
         if intention_keyword_new:
             raise IntentionKeywordNameDuplicateError(f"IntentionKeyword with name {name} already exists.")
         intention_keyword.name = name
@@ -258,26 +268,27 @@ class IntentionKeywordService:
 
     @staticmethod
     def get_intention_keyword(intention_keyword_id: str) -> Optional[IntentionKeyword]:
-        intention_keyword: Optional[IntentionKeyword] = IntentionKeyword.query.filter_by(
-            id=intention_keyword_id
-        ).first()
+        intention_keyword: Optional[IntentionKeyword] = (
+            IntentionKeyword.query.filter_by(id=intention_keyword_id).first()
+        )
         return intention_keyword
 
     @staticmethod
     def get_intention_keyword_by_name(intention_id, name: str) -> Optional[IntentionKeyword]:
-        intention_keyword: Optional[IntentionKeyword] = IntentionKeyword.query.filter(
-            IntentionKeyword.intention_id == intention_id,
-            IntentionKeyword.name == name,
-        ).first()
+        intention_keyword: Optional[IntentionKeyword] = (
+            IntentionKeyword.query.filter(
+                IntentionKeyword.intention_id == intention_id,
+                IntentionKeyword.name == name,
+            ).first()
+        )
         return intention_keyword
 
-
 class IntentionCorpusService:
     @staticmethod
     def get_intention_corpus(corpus_id: str) -> Optional[IntentionCorpus]:
-        intention_corpus: Optional[IntentionCorpus] = IntentionCorpus.query.filter(
-            IntentionCorpus.id == corpus_id
-        ).first()
+        intention_corpus: Optional[IntentionCorpus] = (
+            IntentionCorpus.query.filter(IntentionCorpus.id == corpus_id).first()
+        )
         return intention_corpus
 
     @staticmethod
@@ -294,9 +305,9 @@ class IntentionCorpusService:
 
     @staticmethod
     def get_intention_corpus_by_question(question: str) -> Optional[IntentionCorpus]:
-        intention_corpus: Optional[IntentionCorpus] = IntentionCorpus.query.filter(
-            IntentionCorpus.question == question
-        ).first()
+        intention_corpus: Optional[IntentionCorpus] = (
+            IntentionCorpus.query.filter(IntentionCorpus.question == question).first()
+        )
         return intention_corpus
 
     @staticmethod
@@ -334,9 +345,12 @@ class IntentionCorpusService:
 
         if "question" in args:
             question = args["question"]
-            intention_corpus_new = IntentionCorpus.query.filter(
-                IntentionCorpus.id != corpus_id, IntentionCorpus.question == question
-            ).first()
+            intention_corpus_new = (
+                IntentionCorpus.query.filter(
+                    IntentionCorpus.id != corpus_id,
+                    IntentionCorpus.question == question
+                ).first()
+            )
             if intention_corpus_new:
                 raise IntentionCorpusQuestionDuplicateError(f"IntentionCorpus with question {question} already exists.")
             intention_corpus.question = question
@@ -373,7 +387,6 @@ class IntentionCorpusService:
         db.session.delete(intention_corpus)
         db.session.commit()
 
-
 class IntentionCorpusSimilarityQuestionService:
     @staticmethod
     def save_similarity_question(corpus_id: str, args: dict):
@@ -427,9 +440,9 @@ class IntentionCorpusSimilarityQuestionService:
 
     @staticmethod
     def get_similarity_question(similarity_question_id: str) -> Optional[IntentionCorpusSimilarityQuestion]:
+        """Return the similarity question with the given id, or ``None`` if no row matches."""
-        similarity_question: Optional[IntentionCorpus] = IntentionCorpusSimilarityQuestion.query.filter_by(
-            id=similarity_question_id
-        ).first()
+        # The annotation names the model that is actually queried and returned
+        # (it previously said IntentionCorpus, which contradicted the return type).
+        similarity_question: Optional[IntentionCorpusSimilarityQuestion] = (
+            IntentionCorpusSimilarityQuestion.query.filter_by(id=similarity_question_id).first()
+        )
         return similarity_question
 
     @staticmethod
@@ -440,10 +453,12 @@ class IntentionCorpusSimilarityQuestionService:
         return similarity_question
 
     @staticmethod
-    def get_similarity_questions_by_corpus_id_like_question(corpus_id, search=None):
-        query = IntentionCorpusSimilarityQuestion.query.filter(
-            IntentionCorpusSimilarityQuestion.corpus_id == corpus_id
-        ).order_by(IntentionCorpusSimilarityQuestion.created_at.desc())
+    def get_similarity_questions_by_corpus_id_like_question(corpus_id, search = None):
+        query = (
+            IntentionCorpusSimilarityQuestion.query
+            .filter(IntentionCorpusSimilarityQuestion.corpus_id==corpus_id)
+            .order_by(IntentionCorpusSimilarityQuestion.created_at.desc())
+        )
         if search:
             query = query.filter(IntentionCorpusSimilarityQuestion.question.ilike(f"%{search}%"))
 
@@ -473,11 +488,14 @@ class IntentionCorpusSimilarityQuestionService:
 
     @staticmethod
     def delete_similarity_questions_by_ids(similarity_question_ids: list[str]):
-        similarity_questions = IntentionCorpusSimilarityQuestion.query.filter(
-            IntentionCorpusSimilarityQuestion.id.in_(similarity_question_ids)
-        ).all()
+        similarity_questions = (
+            IntentionCorpusSimilarityQuestion.query
+            .filter(IntentionCorpusSimilarityQuestion.id.in_(similarity_question_ids))
+            .all()
+        )
         IntentionCorpusSimilarityQuestionService.delete_similarity_questions(similarity_questions)
 
+
     @staticmethod
     def delete_similarity_questions(similarity_questions: list[IntentionCorpusSimilarityQuestion]):
         if not similarity_questions:
@@ -486,3 +504,118 @@ class IntentionCorpusSimilarityQuestionService:
         for similarity_question in similarity_questions:
             db.session.delete(similarity_question)
         db.session.commit()
+
+class IntentionTrainTaskService:
+    """Query and persistence helpers for ``IntentionTrainTask`` records."""
+
+    @staticmethod
+    def get_page_intention_train_tasks(page, per_page, search=None):
+        """Return ``(items, total)`` for one page of train tasks, newest first.
+
+        When ``search`` is truthy, only tasks whose name contains it
+        (matched with a case-insensitive ``ILIKE``) are included.
+        """
+        task_query = IntentionTrainTask.query.order_by(IntentionTrainTask.created_at.desc())
+        if search:
+            task_query = task_query.filter(IntentionTrainTask.name.ilike(f"%{search}%"))
+
+        pagination = task_query.paginate(page=page, per_page=per_page, error_out=False)
+        return pagination.items, pagination.total
+
+    @staticmethod
+    def get_train_task(train_task_id: str) -> Optional[IntentionTrainTask]:
+        """Return the train task with the given id, or ``None`` if no row matches."""
+        return IntentionTrainTask.query.filter_by(id=train_task_id).first()
+
+    @staticmethod
+    def save_train_task(args: dict):
+        """Create, commit and return a new train task.
+
+        ``args`` must provide ``"name"`` and ``"status"``; the creator is read
+        from ``current_user`` and the creation time from ``datetime.now()``.
+        """
+        new_task = IntentionTrainTask(
+            id=str(uuid.uuid4()),
+            name=args["name"],
+            status=args["status"],
+            created_by=current_user.id,
+            created_at=datetime.now(),
+        )
+        session = db.session
+        session.add(new_task)
+        session.commit()
+        return new_task
+
+    @staticmethod
+    def update_train_task(task_id: str, args: dict):
+        """Update ``name`` and/or ``status`` of an existing train task and commit.
+
+        Only the keys present in ``args`` are written. Raises ``NotFound`` when
+        no task with ``task_id`` exists. Returns the updated task.
+        """
+        existing_task = IntentionTrainTaskService.get_train_task(task_id)
+        if not existing_task:
+            raise NotFound(f"IntentionTrainTask with id {task_id} not found")
+
+        # Copy over just the fields the caller supplied.
+        for field_name in ("name", "status"):
+            if field_name in args:
+                setattr(existing_task, field_name, args[field_name])
+
+        session = db.session
+        session.add(existing_task)
+        session.commit()
+        return existing_task
+
+
+class IntentionTrainFileService:
+    """Lookup and creation helpers for ``IntentionTrainFile`` records."""
+
+    @staticmethod
+    def get_train_file(name: str, version: str, type: str) -> Optional[IntentionTrainFile]:
+        """Return the first train file matching name, version and type, or ``None``."""
+        criteria = {"name": name, "version": version, "type": type}
+        return IntentionTrainFile.query.filter_by(**criteria).first()
+
+    @staticmethod
+    def get_train_files(name=None, version=None, type=None):
+        """Return all train files, newest first, narrowed by every truthy argument."""
+        file_query = IntentionTrainFile.query.order_by(IntentionTrainFile.created_at.desc())
+        # Falsy values (None, "") mean "do not filter on this column".
+        for column, wanted in (("name", name), ("version", version), ("type", type)):
+            if wanted:
+                file_query = file_query.filter_by(**{column: wanted})
+        return file_query.all()
+
+    @staticmethod
+    def save_train_file(args: dict):
+        """Create, commit and return a new train file.
+
+        ``args`` must provide ``"name"``, ``"version"``, ``"type"``,
+        ``"data_source_type"`` and ``"data_source_info"``. Raises
+        ``IntentionTrainFileDuplicateError`` when a file with the same
+        name/version/type combination already exists.
+        """
+        name, version, type = args["name"], args["version"], args["type"]
+
+        if IntentionTrainFileService.get_train_file(name, version, type):
+            raise IntentionTrainFileDuplicateError(f"IntentionTrainFile with name-version-type "
+                                                   f"{name}-{version}-{type} already exists.")
+
+        new_file = IntentionTrainFile(
+            id=str(uuid.uuid4()),
+            name=name,
+            version=version,
+            type=type,
+            data_source_type=args["data_source_type"],
+            data_source_info=args["data_source_info"],
+            created_by=current_user.id,
+            created_at=datetime.now(),
+        )
+
+        session = db.session
+        session.add(new_file)
+        session.commit()
+        return new_file
+
+class IntentionTrainFileBindingService:
+    """Persistence helper for ``IntentionTrainFileBinding`` records."""
+
+    @staticmethod
+    def save_train_file_binding(args: dict):
+        """Create, commit and return a binding between a train file and a train task.
+
+        ``args`` must provide ``"file_id"`` and ``"task_id"``.
+        """
+        bound_file_id = args["file_id"]
+        bound_task_id = args["task_id"]
+        binding = IntentionTrainFileBinding(
+            id=str(uuid.uuid4()),
+            file_id=bound_file_id,
+            task_id=bound_task_id,
+        )
+        session = db.session
+        session.add(binding)
+        session.commit()
+        return binding
+

+ 13 - 0
api/services/upload_file_service.py

@@ -0,0 +1,13 @@
+from typing import Optional
+
+from models import UploadFile, db
+
+
+class UploadFileService:
+    """Read-only lookup helper for ``UploadFile`` records."""
+
+    @staticmethod
+    def get_upload_file(upload_file_id) -> Optional[UploadFile]:
+        """Return the upload file with the given id, or ``None`` if no row matches."""
+        matching_files = db.session.query(UploadFile).filter_by(id=upload_file_id)
+        return matching_files.first()
+