intention.py 25 KB


  1. import json
  2. import logging
  3. import os
  4. import zipfile
  5. from flask import request, send_file
  6. from flask_restful import Resource, marshal, marshal_with, reqparse
  7. from werkzeug.exceptions import Forbidden, NotFound
  8. import services
  9. from controllers.console import api
  10. from controllers.console.error import FileTooLargeError, UnsupportedFileTypeError
  11. from controllers.console.wraps import account_initialization_required, setup_required
  12. from fields.intention_fields import (
  13. intention_corpus_detail_fields,
  14. intention_corpus_similarity_question_fields,
  15. intention_detail_fields,
  16. intention_keyword_detail_fields,
  17. intention_keyword_fields,
  18. intention_page_fields,
  19. intention_train_file_binding_fields,
  20. intention_train_file_fields,
  21. intention_train_task_fields,
  22. intention_type_detail_fields,
  23. intention_type_page_fields,
  24. )
  25. from libs.login import current_user, login_required
  26. from models import UploadFile
  27. from models.intention import IntentionTrainTask
  28. from services.errors.intention import IntentionTrainFileDuplicateError
  29. from services.file_service import FileService
  30. from services.intention_service import (
  31. IntentionCorpusService,
  32. IntentionCorpusSimilarityQuestionService,
  33. IntentionKeywordService,
  34. IntentionService,
  35. IntentionTrainFileBindingService,
  36. IntentionTrainFileService,
  37. IntentionTrainTaskService,
  38. IntentionTypeService,
  39. )
  40. from services.upload_file_service import UploadFileService
  class IntentionListApi(Resource):
      """Collection endpoint for intentions: paginated listing and creation."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return one page of intentions.

          Query parameters:
              page: page number (default 1).
              limit: page size (default 20).
              type_id: optional intention type filter.
              name_search: optional name search string.

          Returns:
              ``(payload, 200)`` where payload holds ``data``, ``has_more``,
              ``limit``, ``total`` and ``page``.
          """
          page = request.args.get("page", default=1, type=int)
          limit = request.args.get("limit", default=20, type=int)
          type_id = request.args.get("type_id", default=None, type=str)
          name_search = request.args.get("name_search", default=None, type=str)
          intentions, total = IntentionService.get_intentions(page, limit, type_id, name_search)
          response = {
              "data": marshal(intentions, intention_page_fields),
              # Derived from the total count: comparing the page length to
              # ``limit`` wrongly reports more data when the last page is
              # exactly full. Assumes ``page`` is 1-based (default 1).
              "has_more": page * limit < total,
              "limit": limit,
              "total": total,
              "page": page,
          }
          return response, 200

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Create an intention from the request's ``name`` and ``type_id``.

          Returns:
              ``(marshalled intention, 200)``.
          """
          parser = reqparse.RequestParser()
          parser.add_argument(
              "name",
              nullable=False,
              required=True,
              help="name is required. Name must be between 1 and 40 characters.",
          )
          parser.add_argument(
              "type_id",
              nullable=False,
              required=True,
              help="type_id is required.",
          )
          args = parser.parse_args()
          intention = IntentionService.save_intention(args)
          return marshal(intention, intention_detail_fields), 200
  class IntentionApi(Resource):
      """Single-intention endpoint: read, update and delete by id."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, intention_id):
          """Return the intention with ``intention_id``.

          Raises:
              NotFound: if no such intention exists (same check as the
                  keyword endpoints use for their parent intention).
          """
          intention = IntentionService.get_intention(intention_id)
          if not intention:
              raise NotFound("Intention not found")
          return marshal(intention, intention_detail_fields)

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, intention_id):
          """Update the intention's ``name`` and ``type_id``.

          Returns:
              ``(marshalled intention, 200)``.
          """
          parser = reqparse.RequestParser()
          parser.add_argument(
              "name",
              nullable=False,
              required=True,
              help="name is required. Name must be between 1 and 40 characters.",
          )
          parser.add_argument(
              "type_id",
              nullable=False,
              required=True,
              help="type_id is required.",
          )
          args = parser.parse_args()
          intention = IntentionService.update_intention(intention_id, args)
          return marshal(intention, intention_detail_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, intention_id):
          """Delete the intention with ``intention_id``; the response body is ``200``."""
          IntentionService.delete_intention(intention_id)
          return 200
  class IntentionTypeListApi(Resource):
      """Collection endpoint for intention types: paginated listing and creation."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return one page of intention types.

          Query parameters:
              page: page number (default 1).
              limit: page size (default 20).
              search: optional search string.

          Returns:
              ``(payload, 200)`` where payload holds ``data``, ``has_more``,
              ``limit``, ``total`` and ``page``.
          """
          page = request.args.get("page", default=1, type=int)
          limit = request.args.get("limit", default=20, type=int)
          search = request.args.get("search", default=None, type=str)
          intention_types, total = IntentionTypeService.get_intention_types(page, limit, search)
          response = {
              "data": marshal(intention_types, intention_type_page_fields),
              # Use the total count so an exactly-full last page does not
              # report more data. Assumes ``page`` is 1-based (default 1).
              "has_more": page * limit < total,
              "limit": limit,
              "total": total,
              "page": page,
          }
          return response, 200

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Create an intention type from the request's ``name``."""
          parser = reqparse.RequestParser()
          parser.add_argument(
              "name",
              nullable=False,
              required=True,
              help="name is required. Name must be between 1 and 40 characters.",
          )
          args = parser.parse_args()
          intention_type = IntentionTypeService.save_intention_type(args)
          return marshal(intention_type, intention_type_detail_fields), 200
  class IntentionTypeApi(Resource):
      """Single intention-type endpoint: read, update and delete by id."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, intention_type_id):
          """Return the intention type with ``intention_type_id``.

          Raises:
              NotFound: if no such intention type exists.
          """
          intention_type = IntentionTypeService.get_intention_type(intention_type_id)
          if not intention_type:
              raise NotFound("Intention type not found")
          return marshal(intention_type, intention_type_detail_fields)

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, intention_type_id):
          """Rename the intention type using the request's ``name``."""
          parser = reqparse.RequestParser()
          parser.add_argument(
              "name",
              nullable=False,
              required=True,
              help="name is required. Name must be between 1 and 40 characters.",
          )
          args = parser.parse_args()
          intention_type = IntentionTypeService.update_intention_type(intention_type_id, args)
          return marshal(intention_type, intention_type_detail_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, intention_type_id):
          """Delete the intention type; the response body is ``200``."""
          IntentionTypeService.delete_intention_type(intention_type_id)
          return 200
  class IntentionKeywordListApi(Resource):
      """Keywords belonging to one intention: list, create, and delete all."""

      @setup_required
      @login_required
      @account_initialization_required
      @marshal_with(intention_keyword_fields)
      def get(self, intention_id):
          """List the intention's keywords, optionally filtered by ``search``.

          Raises:
              NotFound: if the intention does not exist.
          """
          search = request.args.get("search", default=None, type=str)
          intention = IntentionService.get_intention(intention_id)
          if not intention:
              raise NotFound("Intention not found")
          intention_keywords = IntentionKeywordService.get_intention_keywords(intention_id, search)
          return intention_keywords, 200

      @setup_required
      @login_required
      @account_initialization_required
      @marshal_with(intention_keyword_detail_fields)
      def post(self, intention_id):
          """Add a keyword (``name``) to the intention.

          Raises:
              NotFound: if the intention does not exist.
          """
          parser = reqparse.RequestParser()
          parser.add_argument(
              "name",
              nullable=False,
              required=True,
              help="name is required. Name must be between 1 and 40 characters.",
          )
          args = parser.parse_args()
          intention = IntentionService.get_intention(intention_id)
          if not intention:
              raise NotFound("Intention not found")
          intention_keyword = IntentionKeywordService.save_intention_keyword(intention_id, args)
          return intention_keyword, 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, intention_id):
          """Delete every keyword of the intention; the response body is ``200``.

          Raises:
              NotFound: if the intention does not exist.
          """
          intention = IntentionService.get_intention(intention_id)
          if not intention:
              raise NotFound("Intention not found")
          IntentionKeywordService.delete_intention_keywords_by_intention_id(intention_id)
          return 200
  class IntentionKeywordApi(Resource):
      """Single-keyword endpoint: read, update and delete by keyword id."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, intention_keyword_id):
          """Return the keyword, or ``({}, 200)`` when it does not exist."""
          intention_keyword = IntentionKeywordService.get_intention_keyword(intention_keyword_id)
          if not intention_keyword:
              # NOTE(review): other endpoints raise NotFound here; the empty
              # 200 response is kept because clients may depend on it.
              return {}, 200
          return marshal(intention_keyword, intention_keyword_detail_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      @marshal_with(intention_keyword_detail_fields)
      def patch(self, intention_keyword_id):
          """Update the keyword's ``name`` and owning ``intention_id``."""
          parser = reqparse.RequestParser()
          parser.add_argument(
              "name",
              nullable=False,
              required=True,
              help="name is required. Name must be between 1 and 40 characters.",
          )
          parser.add_argument(
              "intention_id",
              nullable=False,
              required=True,
              help="intention_id is required.",
          )
          args = parser.parse_args()
          intention_keyword = IntentionKeywordService.update_intention_keyword(intention_keyword_id, args)
          return intention_keyword, 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, intention_keyword_id):
          """Delete the keyword; the response body is ``200``."""
          IntentionKeywordService.delete_intention_keyword(intention_keyword_id)
          return 200
  class IntentionKeywordBatchApi(Resource):
      """Batch operations on intention keywords.

      Only ``delete`` is implemented. ``create`` and ``update`` pass argument
      validation but are answered with ``NotFound``.
      """

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Apply the JSON body's ``method`` to the keyword ids in ``delete_data``.

          Raises:
              NotFound: when ``method`` is anything other than ``"delete"``.
          """
          batch_parser = reqparse.RequestParser()
          batch_parser.add_argument(
              "method",
              nullable=False,
              required=True,
              help="method is required.",
              choices=["create", "update", "delete"],
              type=str,
              location="json",
          )
          batch_parser.add_argument(
              "delete_data",
              nullable=False,
              required=True,
              help="delete_data is required.",
              type=list,
              location="json",
          )
          parsed = batch_parser.parse_args()
          logging.info(parsed)
          requested_method = parsed["method"]
          # Guard clause: everything except "delete" is unsupported.
          if requested_method != "delete":
              raise NotFound(f"method with name {requested_method} not found")
          IntentionKeywordService.delete_intention_keywords(parsed["delete_data"])
          return 200
  class IntentionCorpusListApi(Resource):
      """Collection endpoint for training corpus entries: paginated listing and creation."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return one page of corpus entries.

          Query parameters:
              page: page number (default 1).
              limit: page size (default 20).
              question_search: optional question search string.
              intention_id: optional intention filter.

          Returns:
              ``(payload, 200)`` where payload holds ``data``, ``has_more``,
              ``limit``, ``total`` and ``page``.
          """
          page = request.args.get("page", default=1, type=int)
          limit = request.args.get("limit", default=20, type=int)
          question_search = request.args.get("question_search", default=None, type=str)
          intention_id = request.args.get("intention_id", default=None, type=str)
          intention_corpus, total = IntentionCorpusService.get_page_intention_corpus(
              page, limit, question_search, intention_id
          )
          response = {
              "data": marshal(intention_corpus, intention_corpus_detail_fields),
              # Use the total count so an exactly-full last page does not
              # report more data. Assumes ``page`` is 1-based (default 1).
              "has_more": page * limit < total,
              "limit": limit,
              "total": total,
              "page": page,
          }
          return response, 200

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Create a corpus entry from ``question``, ``question_config`` and ``intention_id``."""
          parser = reqparse.RequestParser()
          parser.add_argument(
              "question",
              nullable=False,
              required=True,
              help="question is required. Question must be between 1 and 40 characters.",
          )
          parser.add_argument(
              "question_config",
              nullable=True,
              required=False,
              location="json",
          )
          parser.add_argument(
              "intention_id",
              nullable=False,
              required=True,
              help="intention_id is required.",
          )
          args = parser.parse_args()
          intention_corpus = IntentionCorpusService.save_intention_corpus(args)
          return marshal(intention_corpus, intention_corpus_detail_fields), 200
  class IntentionCorpusApi(Resource):
      """Single corpus-entry endpoint: read, update and delete by id."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, corpus_id):
          """Return the corpus entry.

          Raises:
              NotFound: if the entry does not exist.
          """
          intention_corpus = IntentionCorpusService.get_intention_corpus(corpus_id)
          if not intention_corpus:
              raise NotFound(f"IntentionCorpus with id {corpus_id} not found")
          return marshal(intention_corpus, intention_corpus_detail_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, corpus_id):
          """Partially update the entry; every JSON field is optional."""
          parser = reqparse.RequestParser()
          parser.add_argument(
              "question",
              nullable=True,
              required=False,
              type=str,
              location="json",
          )
          parser.add_argument(
              "question_config",
              nullable=True,
              required=False,
              location="json",
          )
          parser.add_argument(
              "intention_id",
              nullable=True,
              required=False,
              type=str,
              location="json",
          )
          args = parser.parse_args()
          intention_corpus = IntentionCorpusService.update_intention_corpus(corpus_id, args)
          return marshal(intention_corpus, intention_corpus_detail_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, corpus_id):
          """Delete the entry unless similarity questions still reference it.

          Raises:
              NotFound: if the entry does not exist.
              Forbidden: if the entry still has similarity questions.
          """
          intention_corpus = IntentionCorpusService.get_intention_corpus(corpus_id)
          if not intention_corpus:
              # Message fixed: it previously read "Id未" (typo) instead of "Id为".
              raise NotFound(f"未发现Id为{corpus_id}的训练语料")
          if intention_corpus.similarity_questions:
              raise Forbidden(f"存在与其关联的相似问题,无法删除Id为{corpus_id}训练语料")
          IntentionCorpusService.delete_intention_corpus(intention_corpus)
          return 200
  class IntentionCorpusSimilarityQuestionApi(Resource):
      """Similarity questions of one corpus entry: list, create, and delete all."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, corpus_id):
          """List the corpus entry's similarity questions, optionally filtered by ``search``."""
          search = request.args.get("search", default=None, type=str)
          similarity_questions = (
              IntentionCorpusSimilarityQuestionService
              .get_similarity_questions_by_corpus_id_like_question(corpus_id, search)
          )
          return marshal(similarity_questions, intention_corpus_similarity_question_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, corpus_id):
          """Add a similarity question (JSON ``question`` and optional ``question_config``)."""
          parser = reqparse.RequestParser()
          parser.add_argument(
              "question",
              nullable=False,
              required=True,
              help="question is required. Question must be between 1 and 40 characters.",
              location="json",
          )
          parser.add_argument(
              "question_config",
              nullable=True,
              required=False,
              location="json",
          )
          args = parser.parse_args()
          similarity_question = IntentionCorpusSimilarityQuestionService.save_similarity_question(corpus_id, args)
          return marshal(similarity_question, intention_corpus_similarity_question_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, corpus_id):
          """Delete every similarity question of the corpus entry; the response body is ``200``."""
          IntentionCorpusSimilarityQuestionService.delete_similarity_question_by_corpus_id(corpus_id)
          return 200
  class IntentionCorpusSimilarityQuestionUpdateAndDeleteApi(Resource):
      """Single similarity-question endpoint: update and delete by id."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, similarity_question_id):
          """Partially update a similarity question; every JSON field is optional."""
          parser = reqparse.RequestParser()
          parser.add_argument(
              "question",
              nullable=True,
              required=False,
              # The argument is optional, so the help must not claim it is required.
              help="Question must be between 1 and 40 characters.",
              location="json",
          )
          parser.add_argument(
              "question_config",
              nullable=True,
              required=False,
              location="json",
          )
          parser.add_argument(
              "corpus_id",
              nullable=True,
              required=False,
              location="json",
          )
          args = parser.parse_args()
          similarity_question = IntentionCorpusSimilarityQuestionService.update_similarity_question(
              similarity_question_id, args
          )
          return marshal(similarity_question, intention_corpus_similarity_question_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, similarity_question_id):
          """Delete the similarity question; the response body is ``200``."""
          IntentionCorpusSimilarityQuestionService.delete_similarity_question_by_id(similarity_question_id)
          return 200
  class IntentionCorpusSimilarityQuestionBatchApi(Resource):
      """Batch operations on similarity questions.

      Only ``delete`` is implemented. ``create`` and ``update`` pass argument
      validation but are answered with ``NotFound``.
      """

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Apply the JSON body's ``method`` to the similarity-question ids in ``data``.

          Raises:
              NotFound: when ``method`` is anything other than ``"delete"``.
          """
          batch_parser = reqparse.RequestParser()
          batch_parser.add_argument(
              "method",
              nullable=False,
              required=True,
              help="method is required.",
              choices=["create", "update", "delete"],
              type=str,
              location="json",
          )
          batch_parser.add_argument(
              "data",
              nullable=False,
              required=True,
              help="data is required.",
              type=list,
              location="json",
          )
          parsed = batch_parser.parse_args()
          logging.info(parsed)
          requested_method = parsed["method"]
          if requested_method != "delete":
              raise NotFound(f"method with name {requested_method} not found")
          IntentionCorpusSimilarityQuestionService.delete_similarity_questions_by_ids(parsed["data"])
          return 200
  class IntentionTrainTaskListApi(Resource):
      """Collection endpoint for training tasks: paginated listing and creation."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return one page of training tasks.

          Query parameters:
              page: page number (default 1).
              limit: page size (default 20).
              search: optional search string.

          Returns:
              ``(payload, 200)`` where payload holds ``data``, ``has_more``,
              ``limit``, ``total`` and ``page``.
          """
          page = request.args.get("page", default=1, type=int)
          limit = request.args.get("limit", default=20, type=int)
          search = request.args.get("search", default=None, type=str)
          intention_train_tasks, total = IntentionTrainTaskService.get_page_intention_train_tasks(
              page, limit, search
          )
          response = {
              "data": marshal(intention_train_tasks, intention_train_task_fields),
              # Use the total count so an exactly-full last page does not
              # report more data. Assumes ``page`` is 1-based (default 1).
              "has_more": page * limit < total,
              "limit": limit,
              "total": total,
              "page": page,
          }
          return response, 200

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Create a training task from JSON ``name`` and ``status``.

          ``status`` must be one of ``IntentionTrainTask.STATUS_LIST``.
          """
          parser = reqparse.RequestParser()
          parser.add_argument(
              "name",
              nullable=False,
              required=True,
              help="name is required.",
              location="json",
          )
          parser.add_argument(
              "status",
              nullable=False,
              required=True,
              help="status is required.",
              choices=IntentionTrainTask.STATUS_LIST,
              location="json",
          )
          args = parser.parse_args()
          train_task = IntentionTrainTaskService.save_train_task(args)
          return marshal(train_task, intention_train_task_fields), 200
  class IntentionTrainTaskApi(Resource):
      """Single training-task endpoint: update by id."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, task_id):
          """Replace the task's ``name`` and ``status`` (both required, JSON body).

          ``status`` must be one of ``IntentionTrainTask.STATUS_LIST``.
          """
          # Options shared by both arguments.
          common_options = {"nullable": False, "required": True, "location": "json"}
          task_parser = reqparse.RequestParser()
          task_parser.add_argument("name", help="name is required.", **common_options)
          task_parser.add_argument(
              "status",
              help="status is required.",
              choices=IntentionTrainTask.STATUS_LIST,
              **common_options,
          )
          task_args = task_parser.parse_args()
          updated_task = IntentionTrainTaskService.update_train_task(task_id, task_args)
          return marshal(updated_task, intention_train_task_fields), 200
  class IntentionTrainTaskDownloadApi(Resource):
      """Download a completed training task's dataset file and model file as one zip."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, task_id):
          """Send a zip archive containing the task's dataset file and model file.

          The archive is built in memory. As a result:
          - no temporary file is created in (or left behind in) the working directory;
          - concurrent downloads of equally named tasks cannot clobber each other;
          - the task name is used only as the download filename, never as a filesystem path.

          Raises:
              NotFound: if the task does not exist.
              Forbidden: if the task status is not ``"COMPLETED"``.
          """
          import io  # local import: only this endpoint needs an in-memory buffer

          train_task = IntentionTrainTaskService.get_train_task(task_id)
          if not train_task:
              raise NotFound(f"Task with id {task_id} not found")
          if train_task.status != "COMPLETED":
              raise Forbidden(f"Task with id {task_id} not completed")

          upload_files: list[UploadFile] = [
              self._get_upload_file(train_task.dataset_info),
              self._get_upload_file(train_task.model_info),
          ]

          buffer = io.BytesIO()
          with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
              for upload_file in upload_files:
                  # Each entry is read from its own storage key. The previous code
                  # used the dataset key for every file, so the model was never
                  # packaged. NOTE(review): assumes files live under a local
                  # "storage/" directory relative to the working directory.
                  zip_file.write(f"storage/{upload_file.key}", arcname=upload_file.name)
          buffer.seek(0)

          return send_file(
              buffer,
              as_attachment=True,
              download_name=f"{train_task.name}.zip",
              mimetype="application/zip",
          )

      @staticmethod
      def _get_upload_file(source) -> UploadFile:
          """Resolve the UploadFile referenced by ``source.data_source_info``.

          ``data_source_info`` is a JSON string containing an ``upload_file_id`` key.
          """
          source_info = json.loads(source.data_source_info)
          return UploadFileService.get_upload_file(source_info["upload_file_id"])
  class IntentionTrainFileApi(Resource):
      """List and upload intention training files."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return train files filtered by the optional ``name``, ``version`` and ``type`` query args."""
          name = request.args.get("name", default=None, type=str)
          version = request.args.get("version", default=None, type=str)
          # ``file_type`` avoids shadowing the built-in ``type``.
          file_type = request.args.get("type", default=None, type=str)
          train_files = IntentionTrainFileService.get_train_files(name, version, file_type)
          return marshal(train_files, intention_train_file_fields), 200

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Upload a file and register it as an intention train file.

          Form fields: ``name``, ``version``, ``type``, ``data_source_type``,
          plus a multipart ``file``.

          Raises:
              IntentionTrainFileDuplicateError: if a train file with the same
                  name/version/type already exists.
              Forbidden: if the uploaded file has no filename or mimetype.
              FileTooLargeError, UnsupportedFileTypeError: translated from the
                  file service errors of the same names.
          """
          name = request.form.get("name")
          version = request.form.get("version")
          file_type = request.form.get("type")
          if IntentionTrainFileService.get_train_file(name, version, file_type):
              raise IntentionTrainFileDuplicateError(
                  f"IntentionTrainFile with name-version-type "
                  f"{name}-{version}-{file_type} already exists."
              )
          data_source_type = request.form.get("data_source_type")

          file = request.files["file"]
          filename = file.filename
          mimetype = file.mimetype
          if not filename or not mimetype:
              raise Forbidden("Invalid request.")

          # Only the upload raises the file-service errors, so keep the try body minimal.
          try:
              upload_file = FileService.upload_file(
                  filename=filename,
                  content=file.read(),
                  mimetype=mimetype,
                  user=current_user,
                  source=None,
              )
          except services.errors.file.FileTooLargeError as file_too_large_error:
              raise FileTooLargeError(file_too_large_error.description) from file_too_large_error
          except services.errors.file.UnsupportedFileTypeError as unsupported_error:
              raise UnsupportedFileTypeError() from unsupported_error

          args = {
              "name": name,
              "version": version,
              "type": file_type,
              "data_source_type": data_source_type,
              "data_source_info": {"upload_file_id": upload_file.id},
          }
          intention_train_file = IntentionTrainFileService.save_train_file(args)
          return marshal(intention_train_file, intention_train_file_fields), 200
  class IntentionTrainFileBindingApi(Resource):
      """Bind an uploaded train file to a training task."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Create a binding from the JSON body's required ``file_id`` and ``task_id``."""
          binding_parser = reqparse.RequestParser()
          # Both arguments share identical options; only the name and help differ.
          for field_name in ("file_id", "task_id"):
              binding_parser.add_argument(
                  field_name,
                  nullable=False,
                  required=True,
                  help=f"{field_name} is required.",
                  location="json",
              )
          binding_args = binding_parser.parse_args()
          binding = IntentionTrainFileBindingService.save_train_file_binding(binding_args)
          return marshal(binding, intention_train_file_binding_fields), 200
  # URL routing table for the intention console API, registered in this order.
  _INTENTION_ROUTES = (
      (IntentionListApi, "/intentions"),
      (IntentionApi, "/intentions/<uuid:intention_id>"),
      (IntentionTypeListApi, "/intentions/types"),
      (IntentionTypeApi, "/intentions/types/<uuid:intention_type_id>"),
      (IntentionKeywordListApi, "/intentions/<uuid:intention_id>/keywords"),
      (IntentionKeywordApi, "/intentions/keywords/<uuid:intention_keyword_id>"),
      (IntentionKeywordBatchApi, "/intentions/keywords/batch"),
      (IntentionCorpusListApi, "/intentions/corpus"),
      (IntentionCorpusApi, "/intentions/corpus/<uuid:corpus_id>"),
      (IntentionCorpusSimilarityQuestionApi, "/intentions/corpus/<uuid:corpus_id>/similarity_questions"),
      (IntentionCorpusSimilarityQuestionUpdateAndDeleteApi,
       "/intentions/similarity_questions/<uuid:similarity_question_id>"),
      (IntentionCorpusSimilarityQuestionBatchApi, "/intentions/similarity_questions/batch"),
      (IntentionTrainTaskListApi, "/intentions/train_tasks"),
      (IntentionTrainTaskApi, "/intentions/train_tasks/<uuid:task_id>"),
      (IntentionTrainTaskDownloadApi, "/intentions/train_tasks/download/<uuid:task_id>"),
      (IntentionTrainFileApi, "/intentions/train_files"),
      (IntentionTrainFileBindingApi, "/intentions/train_file_bindings"),
  )
  for _resource_cls, _url in _INTENTION_ROUTES:
      api.add_resource(_resource_cls, _url)
  del _resource_cls, _url