|
@@ -11,6 +11,7 @@ import services
|
|
|
from controllers.console import api
|
|
|
from controllers.console.error import FileTooLargeError, UnsupportedFileTypeError
|
|
|
from controllers.console.wraps import account_initialization_required, setup_required
|
|
|
+from extensions.ext_storage import storage
|
|
|
from fields.intention_fields import (
|
|
|
intention_corpus_detail_fields,
|
|
|
intention_corpus_similarity_question_fields,
|
|
@@ -26,7 +27,7 @@ from fields.intention_fields import (
|
|
|
)
|
|
|
from libs.login import current_user, login_required
|
|
|
from models import UploadFile
|
|
|
-from models.intention import IntentionTrainTask
|
|
|
+from models.intention import IntentionTrainTask, IntentionTrainFile
|
|
|
from services.errors.intention import IntentionTrainFileDuplicateError
|
|
|
from services.file_service import FileService
|
|
|
from services.intention_service import (
|
|
@@ -583,29 +584,24 @@ class IntentionTrainTaskDownloadApi(Resource):
|
|
|
if train_task.status != "COMPLETED":
|
|
|
raise Forbidden(f"Task with id {task_id} not completed")
|
|
|
|
|
|
- dataset_info = train_task.dataset_info
|
|
|
- dataset_source_info = json.loads(dataset_info.data_source_info)
|
|
|
- dataset_file_id = dataset_source_info["upload_file_id"]
|
|
|
- dataset_upload_file: UploadFile = UploadFileService.get_upload_file(dataset_file_id)
|
|
|
-
|
|
|
- model_info = train_task.model_info
|
|
|
- model_source_info = json.loads(model_info.data_source_info)
|
|
|
- model_file_id = model_source_info["upload_file_id"]
|
|
|
- model_upload_file: UploadFile = UploadFileService.get_upload_file(model_file_id)
|
|
|
-
|
|
|
- def file2zip(zip_filename: str, upload_files: list[UploadFile]):
|
|
|
- with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
|
|
|
- for upload_file in upload_files:
|
|
|
- filename = f"storage/{dataset_upload_file.key}"
|
|
|
- zip_file.write(filename, arcname=upload_file.name)
|
|
|
+ dataset_info: IntentionTrainFile = train_task.dataset_info
|
|
|
+ model_info: IntentionTrainFile = train_task.model_info
|
|
|
|
|
|
# 生成待下载的zip包
|
|
|
zip_filename = f"{train_task.name}.zip"
|
|
|
- upload_files: list[UploadFile] = [dataset_upload_file, model_upload_file]
|
|
|
- file2zip(zip_filename, upload_files)
|
|
|
+ with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
|
|
|
+ for train_file in [dataset_info, model_info]:
|
|
|
+ source_info = json.loads(train_file.data_source_info)
|
|
|
+ upload_file_id = source_info["upload_file_id"]
|
|
|
+ upload_file: UploadFile = UploadFileService.get_upload_file(upload_file_id)
|
|
|
+ storage.download(upload_file.key, upload_file.name)
|
|
|
+ zip_file.write(upload_file.name, upload_file.name)
|
|
|
+ os.remove(upload_file.name)
|
|
|
|
|
|
# 下载zip包
|
|
|
response = send_file(zip_filename, as_attachment=True, download_name=zip_filename)
|
|
|
+
|
|
|
+ # 清除临时文件
|
|
|
os.remove(zip_filename)
|
|
|
return response
|
|
|
|