intention.py 25 KB


import json
import logging
import os
import tempfile
import zipfile

from flask import request, send_file
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.console import api
from controllers.console.error import FileTooLargeError, UnsupportedFileTypeError
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_storage import storage
from fields.intention_fields import (
    intention_corpus_detail_fields,
    intention_corpus_similarity_question_fields,
    intention_detail_fields,
    intention_keyword_detail_fields,
    intention_keyword_fields,
    intention_page_fields,
    intention_train_file_binding_fields,
    intention_train_file_fields,
    intention_train_task_fields,
    intention_type_detail_fields,
    intention_type_page_fields,
)
from libs.login import current_user, login_required
from models import UploadFile
from models.intention import IntentionTrainFile, IntentionTrainTask
from services.errors.intention import IntentionTrainFileDuplicateError
from services.file_service import FileService
from services.intention_service import (
    IntentionCorpusService,
    IntentionCorpusSimilarityQuestionService,
    IntentionKeywordService,
    IntentionService,
    IntentionTrainFileBindingService,
    IntentionTrainFileService,
    IntentionTrainTaskService,
    IntentionTypeService,
)
from services.upload_file_service import UploadFileService
  42. class IntentionListApi(Resource):
  43. @setup_required
  44. @login_required
  45. @account_initialization_required
  46. def get(self):
  47. page = request.args.get("page", default=1, type=int)
  48. limit = request.args.get("limit", default=20, type=int)
  49. type_id = request.args.get("type_id", default=None, type=str)
  50. name_search = request.args.get("name_search", default=None, type=str)
  51. intentions, total = IntentionService.get_intentions(page, limit, type_id, name_search)
  52. data = marshal(intentions, intention_page_fields)
  53. response = {"data": data, "has_more": len(intentions) == limit, "limit": limit,
  54. "total": total, "page": page}
  55. return response, 200
  56. @setup_required
  57. @login_required
  58. @account_initialization_required
  59. def post(self):
  60. parser = reqparse.RequestParser()
  61. parser.add_argument(
  62. "name",
  63. nullable=False,
  64. required=True,
  65. help="type is required. Name must be between 1 to 40 characters.",
  66. )
  67. parser.add_argument(
  68. "type_id",
  69. nullable=False,
  70. required=True,
  71. help="type is required.",
  72. )
  73. args = parser.parse_args()
  74. intention = IntentionService.save_intention(args)
  75. response = marshal(intention, intention_detail_fields)
  76. return response, 200
  77. class IntentionApi(Resource):
  78. @setup_required
  79. @login_required
  80. @account_initialization_required
  81. def get(self, intention_id):
  82. intention = IntentionService.get_intention(intention_id)
  83. return marshal(intention, intention_detail_fields)
  84. @setup_required
  85. @login_required
  86. @account_initialization_required
  87. def patch(self, intention_id):
  88. parser = reqparse.RequestParser()
  89. parser.add_argument(
  90. "name",
  91. nullable=False,
  92. required=True,
  93. help="type is required. Name must be between 1 to 40 characters.",
  94. )
  95. parser.add_argument(
  96. "type_id",
  97. nullable=False,
  98. required=True,
  99. help="type is required.",
  100. )
  101. args = parser.parse_args()
  102. intention = IntentionService.update_intention(intention_id, args)
  103. response = marshal(intention, intention_detail_fields)
  104. return response, 200
  105. @setup_required
  106. @login_required
  107. @account_initialization_required
  108. def delete(self, intention_id):
  109. IntentionService.delete_intention(intention_id)
  110. return 200
  111. class IntentionTypeListApi(Resource):
  112. @setup_required
  113. @login_required
  114. @account_initialization_required
  115. def get(self):
  116. page = request.args.get("page", default=1, type=int)
  117. limit = request.args.get("limit", default=20, type=int)
  118. search = request.args.get("search", default=None, type=str)
  119. intention_types, total = IntentionTypeService.get_intention_types(page, limit, search)
  120. data = marshal(intention_types, intention_type_page_fields)
  121. response = {"data": data, "has_more": len(intention_types) == limit, "limit": limit,
  122. "total": total, "page": page}
  123. return response, 200
  124. @setup_required
  125. @login_required
  126. @account_initialization_required
  127. def post(self):
  128. parser = reqparse.RequestParser()
  129. parser.add_argument(
  130. "name",
  131. nullable=False,
  132. required=True,
  133. help="type is required. Name must be between 1 to 40 characters.",
  134. )
  135. args = parser.parse_args()
  136. intention_type = IntentionTypeService.save_intention_type(args)
  137. response = marshal(intention_type, intention_type_detail_fields)
  138. return response, 200
  139. class IntentionTypeApi(Resource):
  140. @setup_required
  141. @login_required
  142. @account_initialization_required
  143. def get(self, intention_type_id):
  144. intention_type = IntentionTypeService.get_intention_type(intention_type_id)
  145. return marshal(intention_type, intention_type_detail_fields)
  146. @setup_required
  147. @login_required
  148. @account_initialization_required
  149. def patch(self, intention_type_id):
  150. parser = reqparse.RequestParser()
  151. parser.add_argument(
  152. "name",
  153. nullable=False,
  154. required=True,
  155. help="type is required. Name must be between 1 to 40 characters.",
  156. )
  157. args = parser.parse_args()
  158. intention_type = IntentionTypeService.update_intention_type(intention_type_id, args)
  159. return marshal(intention_type, intention_type_detail_fields), 200
  160. @setup_required
  161. @login_required
  162. @account_initialization_required
  163. def delete(self, intention_type_id):
  164. IntentionTypeService.delete_intention_type(intention_type_id)
  165. return 200
  166. class IntentionKeywordListApi(Resource):
  167. @setup_required
  168. @login_required
  169. @account_initialization_required
  170. @marshal_with(intention_keyword_fields)
  171. def get(self, intention_id):
  172. search = request.args.get("search", default=None, type=str)
  173. intention = IntentionService.get_intention(intention_id)
  174. if not intention:
  175. raise NotFound("Intention not found")
  176. intention_keywords = IntentionKeywordService.get_intention_keywords(intention_id, search)
  177. return intention_keywords, 200
  178. @setup_required
  179. @login_required
  180. @account_initialization_required
  181. @marshal_with(intention_keyword_detail_fields)
  182. def post(self, intention_id):
  183. parser = reqparse.RequestParser()
  184. parser.add_argument(
  185. "name",
  186. nullable=False,
  187. required=True,
  188. help="type is required. Name must be between 1 to 40 characters.",
  189. )
  190. args = parser.parse_args()
  191. intention = IntentionService.get_intention(intention_id)
  192. if not intention:
  193. raise NotFound("Intention not found")
  194. intention_keyword = IntentionKeywordService.save_intention_keyword(intention_id, args)
  195. return intention_keyword, 200
  196. @setup_required
  197. @login_required
  198. @account_initialization_required
  199. def delete(self, intention_id):
  200. intention = IntentionService.get_intention(intention_id)
  201. if not intention:
  202. raise NotFound("Intention not found")
  203. IntentionKeywordService.delete_intention_keywords_by_intention_id(intention_id)
  204. return 200
  205. class IntentionKeywordApi(Resource):
  206. @setup_required
  207. @login_required
  208. @account_initialization_required
  209. def get(self, intention_keyword_id):
  210. intention_keyword = IntentionKeywordService.get_intention_keyword(intention_keyword_id)
  211. if not intention_keyword:
  212. return {}, 200
  213. return marshal(intention_keyword, intention_keyword_detail_fields), 200
  214. @setup_required
  215. @login_required
  216. @account_initialization_required
  217. @marshal_with(intention_keyword_detail_fields)
  218. def patch(self, intention_keyword_id):
  219. parser = reqparse.RequestParser()
  220. parser.add_argument(
  221. "name",
  222. nullable=False,
  223. required=True,
  224. help="type is required. Name must be between 1 to 40 characters.",
  225. )
  226. parser.add_argument(
  227. "intention_id",
  228. nullable=False,
  229. required=True,
  230. help="type is required.",
  231. )
  232. args = parser.parse_args()
  233. intention_keyword = IntentionKeywordService.update_intention_keyword(intention_keyword_id, args)
  234. return intention_keyword, 200
  235. @setup_required
  236. @login_required
  237. @account_initialization_required
  238. def delete(self, intention_keyword_id):
  239. IntentionKeywordService.delete_intention_keyword(intention_keyword_id)
  240. return 200
  241. class IntentionKeywordBatchApi(Resource):
  242. @setup_required
  243. @login_required
  244. @account_initialization_required
  245. def post(self):
  246. parser = reqparse.RequestParser()
  247. parser.add_argument(
  248. "method",
  249. nullable=False,
  250. required=True,
  251. help="method is required.",
  252. choices=["create", "update", "delete"],
  253. type=str,
  254. location="json",
  255. )
  256. parser.add_argument(
  257. "delete_data",
  258. nullable=False,
  259. required=True,
  260. help="delete_data is required.",
  261. type=list,
  262. location="json",
  263. )
  264. args = parser.parse_args()
  265. logging.info(args)
  266. method = args["method"]
  267. if method == "delete":
  268. intention_keyword_ids = args["delete_data"]
  269. IntentionKeywordService.delete_intention_keywords(intention_keyword_ids)
  270. return 200
  271. else:
  272. raise NotFound(f"method with name {method} not found")
  273. class IntentionCorpusListApi(Resource):
  274. @setup_required
  275. @login_required
  276. @account_initialization_required
  277. def get(self):
  278. page = request.args.get("page", default=1, type=int)
  279. limit = request.args.get("limit", default=20, type=int)
  280. question_search = request.args.get("question_search", default=None, type=str)
  281. intention_id = request.args.get("intention_id", default=None, type=str)
  282. intention_corpus, total = IntentionCorpusService.get_page_intention_corpus(
  283. page, limit, question_search, intention_id)
  284. data = marshal(intention_corpus, intention_corpus_detail_fields)
  285. response = {"data": data, "has_more": len(intention_corpus) == limit, "limit": limit,
  286. "total": total, "page": page}
  287. return response, 200
  288. @setup_required
  289. @login_required
  290. @account_initialization_required
  291. def post(self):
  292. parser = reqparse.RequestParser()
  293. parser.add_argument(
  294. "question",
  295. nullable=False,
  296. required=True,
  297. help="type is required. Question must be between 1 to 40 characters.",
  298. )
  299. parser.add_argument(
  300. "question_config",
  301. nullable=True,
  302. required=False,
  303. location="json",
  304. )
  305. parser.add_argument(
  306. "intention_id",
  307. nullable=False,
  308. required=True,
  309. help="type is required.",
  310. )
  311. args = parser.parse_args()
  312. intention_corpus = IntentionCorpusService.save_intention_corpus(args)
  313. return marshal(intention_corpus, intention_corpus_detail_fields), 200
  314. class IntentionCorpusApi(Resource):
  315. @setup_required
  316. @login_required
  317. @account_initialization_required
  318. def get(self, corpus_id):
  319. intention_corpus = IntentionCorpusService.get_intention_corpus(corpus_id)
  320. if not intention_corpus:
  321. raise NotFound(f"IntentionCorpus with id {corpus_id} not found")
  322. return marshal(intention_corpus, intention_corpus_detail_fields), 200
  323. @setup_required
  324. @login_required
  325. @account_initialization_required
  326. def patch(self, corpus_id):
  327. parser = reqparse.RequestParser()
  328. parser.add_argument(
  329. "question",
  330. nullable=True,
  331. required=False,
  332. type=str,
  333. location="json",
  334. )
  335. parser.add_argument(
  336. "question_config",
  337. nullable=True,
  338. required=False,
  339. location="json",
  340. )
  341. parser.add_argument(
  342. "intention_id",
  343. nullable=True,
  344. required=False,
  345. type=str,
  346. location="json",
  347. )
  348. args = parser.parse_args()
  349. intention_corpus = IntentionCorpusService.update_intention_corpus(corpus_id, args)
  350. return marshal(intention_corpus, intention_corpus_detail_fields), 200
  351. @setup_required
  352. @login_required
  353. @account_initialization_required
  354. def delete(self, corpus_id):
  355. intention_corpus = IntentionCorpusService.get_intention_corpus(corpus_id)
  356. if not intention_corpus:
  357. raise NotFound(f"未发现Id未{corpus_id}的训练语料")
  358. similarity_questions = intention_corpus.similarity_questions
  359. if similarity_questions:
  360. raise Forbidden(f"存在与其关联的相似问题,无法删除Id为{corpus_id}训练语料")
  361. IntentionCorpusService.delete_intention_corpus(intention_corpus)
  362. return 200
  363. class IntentionCorpusSimilarityQuestionApi(Resource):
  364. @setup_required
  365. @login_required
  366. @account_initialization_required
  367. def get(self, corpus_id):
  368. search = request.args.get("search", default=None, type=str)
  369. similarity_questions = (
  370. IntentionCorpusSimilarityQuestionService
  371. .get_similarity_questions_by_corpus_id_like_question(corpus_id, search)
  372. )
  373. return marshal(similarity_questions, intention_corpus_similarity_question_fields), 200
  374. @setup_required
  375. @login_required
  376. @account_initialization_required
  377. def post(self, corpus_id):
  378. parser = reqparse.RequestParser()
  379. parser.add_argument(
  380. "question",
  381. nullable=False,
  382. required=True,
  383. help="type is required. Question must be between 1 to 40 characters.",
  384. location="json",
  385. )
  386. parser.add_argument(
  387. "question_config",
  388. nullable=True,
  389. required=False,
  390. location="json",
  391. )
  392. args = parser.parse_args()
  393. intention_corpus_similarity_question = (
  394. IntentionCorpusSimilarityQuestionService.save_similarity_question(corpus_id, args)
  395. )
  396. return marshal(intention_corpus_similarity_question, intention_corpus_similarity_question_fields), 200
  397. @setup_required
  398. @login_required
  399. @account_initialization_required
  400. def delete(self, corpus_id):
  401. IntentionCorpusSimilarityQuestionService.delete_similarity_question_by_corpus_id(corpus_id)
  402. return 200
  403. class IntentionCorpusSimilarityQuestionUpdateAndDeleteApi(Resource):
  404. @setup_required
  405. @login_required
  406. @account_initialization_required
  407. def patch(self, similarity_question_id):
  408. parser = reqparse.RequestParser()
  409. parser.add_argument(
  410. "question",
  411. nullable=True,
  412. required=False,
  413. help="type is required. Question must be between 1 to 40 characters.",
  414. location="json",
  415. )
  416. parser.add_argument(
  417. "question_config",
  418. nullable=True,
  419. required=False,
  420. location="json",
  421. )
  422. parser.add_argument(
  423. "corpus_id",
  424. nullable=True,
  425. required=False,
  426. location="json",
  427. )
  428. args = parser.parse_args()
  429. similarity_question = (
  430. IntentionCorpusSimilarityQuestionService.update_similarity_question(similarity_question_id, args)
  431. )
  432. return marshal(similarity_question, intention_corpus_similarity_question_fields), 200
  433. @setup_required
  434. @login_required
  435. @account_initialization_required
  436. def delete(self, similarity_question_id):
  437. IntentionCorpusSimilarityQuestionService.delete_similarity_question_by_id(similarity_question_id)
  438. return 200
  439. class IntentionCorpusSimilarityQuestionBatchApi(Resource):
  440. @setup_required
  441. @login_required
  442. @account_initialization_required
  443. def post(self):
  444. parser = reqparse.RequestParser()
  445. parser.add_argument(
  446. "method",
  447. nullable=False,
  448. required=True,
  449. help="method is required.",
  450. choices=["create", "update", "delete"],
  451. type=str,
  452. location="json",
  453. )
  454. parser.add_argument(
  455. "data",
  456. nullable=False,
  457. required=True,
  458. help="data is required.",
  459. type=list,
  460. location="json",
  461. )
  462. args = parser.parse_args()
  463. logging.info(args)
  464. method = args["method"]
  465. if method == "delete":
  466. similarity_question_ids = args["data"]
  467. IntentionCorpusSimilarityQuestionService.delete_similarity_questions_by_ids(similarity_question_ids)
  468. return 200
  469. else:
  470. raise NotFound(f"method with name {method} not found")
  471. class IntentionTrainTaskListApi(Resource):
  472. @setup_required
  473. @login_required
  474. @account_initialization_required
  475. def get(self):
  476. page = request.args.get("page", default=1, type=int)
  477. limit = request.args.get("limit", default=20, type=int)
  478. search = request.args.get("search", default=None, type=str)
  479. intention_train_tasks, total = IntentionTrainTaskService.get_page_intention_train_tasks(
  480. page, limit, search)
  481. data = marshal(intention_train_tasks, intention_train_task_fields)
  482. response = {"data": data, "has_more": len(intention_train_tasks) == limit, "limit": limit,
  483. "total": total, "page": page}
  484. return response, 200
  485. @setup_required
  486. @login_required
  487. @account_initialization_required
  488. def post(self):
  489. parser = reqparse.RequestParser()
  490. parser.add_argument(
  491. "name",
  492. nullable=False,
  493. required=True,
  494. help="name is required.",
  495. location="json",
  496. )
  497. parser.add_argument(
  498. "status",
  499. nullable=False,
  500. required=True,
  501. help="status is required.",
  502. choices=IntentionTrainTask.STATUS_LIST,
  503. location="json",
  504. )
  505. args = parser.parse_args()
  506. train_task = IntentionTrainTaskService.save_train_task(args)
  507. return marshal(train_task, intention_train_task_fields), 200
  508. class IntentionTrainTaskApi(Resource):
  509. @setup_required
  510. @login_required
  511. @account_initialization_required
  512. def patch(self, task_id):
  513. parser = reqparse.RequestParser()
  514. parser.add_argument(
  515. "name",
  516. nullable=False,
  517. required=True,
  518. help="name is required.",
  519. location="json",
  520. )
  521. parser.add_argument(
  522. "status",
  523. nullable=False,
  524. required=True,
  525. help="status is required.",
  526. choices=IntentionTrainTask.STATUS_LIST,
  527. location="json",
  528. )
  529. args = parser.parse_args()
  530. train_task = IntentionTrainTaskService.update_train_task(task_id, args)
  531. return marshal(train_task, intention_train_task_fields), 200
  532. class IntentionTrainTaskDownloadApi(Resource):
  533. @setup_required
  534. @login_required
  535. @account_initialization_required
  536. def get(self, task_id):
  537. train_task = IntentionTrainTaskService.get_train_task(task_id)
  538. if train_task.status != "COMPLETED":
  539. raise Forbidden(f"Task with id {task_id} not completed")
  540. dataset_info: IntentionTrainFile = train_task.dataset_info
  541. model_info: IntentionTrainFile = train_task.model_info
  542. # 生成待下载的zip包
  543. zip_filename = f"{train_task.name}.zip"
  544. with zipfile.ZipFile(zip_filename, "w", compression=zipfile.ZIP_DEFLATED) as zip_file:
  545. for train_file in [dataset_info, model_info]:
  546. source_info = json.loads(train_file.data_source_info)
  547. upload_file_id = source_info["upload_file_id"]
  548. upload_file: UploadFile = UploadFileService.get_upload_file(upload_file_id)
  549. storage.download(upload_file.key, upload_file.name)
  550. zip_file.write(upload_file.name, upload_file.name)
  551. os.remove(upload_file.name)
  552. # 下载zip包
  553. response = send_file(zip_filename, as_attachment=True, download_name=zip_filename)
  554. # 清除临时文件
  555. os.remove(zip_filename)
  556. return response
  557. class IntentionTrainFileApi(Resource):
  558. @setup_required
  559. @login_required
  560. @account_initialization_required
  561. def get(self):
  562. name = request.args.get("name", default=None, type=str)
  563. version = request.args.get("version", default=None, type=str)
  564. type = request.args.get("type", default=None, type=str)
  565. train_files = IntentionTrainFileService.get_train_files(name, version, type)
  566. return marshal(train_files, intention_train_file_fields), 200
  567. @setup_required
  568. @login_required
  569. @account_initialization_required
  570. def post(self):
  571. name = request.form.get("name")
  572. version = request.form.get("version")
  573. type = request.form.get("type")
  574. train_file = IntentionTrainFileService.get_train_file(name, version, type)
  575. if train_file:
  576. raise IntentionTrainFileDuplicateError(f"IntentionTrainFile with name-version-type "
  577. f"{name}-{version}-{type} already exists.")
  578. data_source_type = request.form.get("data_source_type")
  579. # get file from request
  580. file = request.files["file"]
  581. filename = file.filename
  582. mimetype = file.mimetype
  583. if not filename or not mimetype:
  584. raise Forbidden("Invalid request.")
  585. try:
  586. upload_file = FileService.upload_file(
  587. filename=filename,
  588. content=file.read(),
  589. mimetype=mimetype,
  590. user=current_user,
  591. source=None,
  592. )
  593. args = {
  594. "name": name,
  595. "version": version,
  596. "type": type,
  597. "data_source_type": data_source_type,
  598. "data_source_info": {
  599. "upload_file_id": upload_file.id
  600. }
  601. }
  602. intention_train_file = IntentionTrainFileService.save_train_file(args)
  603. return marshal(intention_train_file, intention_train_file_fields), 200
  604. except services.errors.file.FileTooLargeError as file_too_large_error:
  605. raise FileTooLargeError(file_too_large_error.description)
  606. except services.errors.file.UnsupportedFileTypeError:
  607. raise UnsupportedFileTypeError()
  608. class IntentionTrainFileBindingApi(Resource):
  609. @setup_required
  610. @login_required
  611. @account_initialization_required
  612. def post(self):
  613. parser = reqparse.RequestParser()
  614. parser.add_argument(
  615. "file_id",
  616. nullable=False,
  617. required=True,
  618. help="file_id is required.",
  619. location="json",
  620. )
  621. parser.add_argument(
  622. "task_id",
  623. nullable=False,
  624. required=True,
  625. help="task_id is required.",
  626. location="json",
  627. )
  628. args = parser.parse_args()
  629. train_file_binding = IntentionTrainFileBindingService.save_train_file_binding(args)
  630. return marshal(train_file_binding, intention_train_file_binding_fields), 200
  631. api.add_resource(IntentionListApi, "/intentions")
  632. api.add_resource(IntentionApi, "/intentions/<uuid:intention_id>")
  633. api.add_resource(IntentionTypeListApi, "/intentions/types")
  634. api.add_resource(IntentionTypeApi, "/intentions/types/<uuid:intention_type_id>")
  635. api.add_resource(IntentionKeywordListApi, "/intentions/<uuid:intention_id>/keywords")
  636. api.add_resource(IntentionKeywordApi, "/intentions/keywords/<uuid:intention_keyword_id>")
  637. api.add_resource(IntentionKeywordBatchApi, "/intentions/keywords/batch")
  638. api.add_resource(IntentionCorpusListApi, "/intentions/corpus")
  639. api.add_resource(IntentionCorpusApi, "/intentions/corpus/<uuid:corpus_id>")
  640. api.add_resource(IntentionCorpusSimilarityQuestionApi, "/intentions/corpus/<uuid:corpus_id>/similarity_questions")
  641. api.add_resource(IntentionCorpusSimilarityQuestionUpdateAndDeleteApi,
  642. "/intentions/similarity_questions/<uuid:similarity_question_id>")
  643. api.add_resource(IntentionCorpusSimilarityQuestionBatchApi, "/intentions/similarity_questions/batch")
  644. api.add_resource(IntentionTrainTaskListApi, "/intentions/train_tasks")
  645. api.add_resource(IntentionTrainTaskApi, "/intentions/train_tasks/<uuid:task_id>")
  646. api.add_resource(IntentionTrainTaskDownloadApi, "/intentions/train_tasks/download/<uuid:task_id>")
  647. api.add_resource(IntentionTrainFileApi, "/intentions/train_files")
  648. api.add_resource(IntentionTrainFileBindingApi, "/intentions/train_file_bindings")