瀏覽代碼

训练日志[优化下载训练日志接口,使其支持从文件服务器下载]

liangxunge 5 月之前
父節點
當前提交
7770a542ae
共有 1 個文件被更改,包括 14 次插入18 次删除
  1. 14 18
      api/controllers/console/intention.py

+ 14 - 18
api/controllers/console/intention.py

@@ -11,6 +11,7 @@ import services
 from controllers.console import api
 from controllers.console.error import FileTooLargeError, UnsupportedFileTypeError
 from controllers.console.wraps import account_initialization_required, setup_required
+from extensions.ext_storage import storage
 from fields.intention_fields import (
     intention_corpus_detail_fields,
     intention_corpus_similarity_question_fields,
@@ -26,7 +27,7 @@ from fields.intention_fields import (
 )
 from libs.login import current_user, login_required
 from models import UploadFile
-from models.intention import IntentionTrainTask
+from models.intention import IntentionTrainTask, IntentionTrainFile
 from services.errors.intention import IntentionTrainFileDuplicateError
 from services.file_service import FileService
 from services.intention_service import (
@@ -583,29 +584,24 @@ class IntentionTrainTaskDownloadApi(Resource):
         if train_task.status != "COMPLETED":
             raise Forbidden(f"Task with id {task_id} not completed")
 
-        dataset_info = train_task.dataset_info
-        dataset_source_info = json.loads(dataset_info.data_source_info)
-        dataset_file_id = dataset_source_info["upload_file_id"]
-        dataset_upload_file: UploadFile = UploadFileService.get_upload_file(dataset_file_id)
-
-        model_info = train_task.model_info
-        model_source_info = json.loads(model_info.data_source_info)
-        model_file_id = model_source_info["upload_file_id"]
-        model_upload_file: UploadFile = UploadFileService.get_upload_file(model_file_id)
-
-        def file2zip(zip_filename: str, upload_files: list[UploadFile]):
-            with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
-                for upload_file in upload_files:
-                    filename = f"storage/{dataset_upload_file.key}"
-                    zip_file.write(filename, arcname=upload_file.name)
+        dataset_info: IntentionTrainFile = train_task.dataset_info
+        model_info: IntentionTrainFile = train_task.model_info
 
         # 生成待下载的zip包
         zip_filename = f"{train_task.name}.zip"
-        upload_files: list[UploadFile] = [dataset_upload_file, model_upload_file]
-        file2zip(zip_filename, upload_files)
+        with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
+            for train_file in [dataset_info, model_info]:
+                source_info = json.loads(train_file.data_source_info)
+                upload_file_id = source_info["upload_file_id"]
+                upload_file: UploadFile = UploadFileService.get_upload_file(upload_file_id)
+                storage.download(upload_file.key, upload_file.name)
+                zip_file.write(upload_file.name, upload_file.name)
+                os.remove(upload_file.name)
 
         # 下载zip包
         response = send_file(zip_filename, as_attachment=True, download_name=zip_filename)
+
+        # 清除临时文件
         os.remove(zip_filename)
         return response