# datasets_segments.py
  1. import uuid
  2. import pandas as pd
  3. from flask import request
  4. from flask_login import current_user # type: ignore
  5. from flask_restful import Resource, marshal, reqparse # type: ignore
  6. from werkzeug.exceptions import Forbidden, NotFound
  7. import services
  8. from controllers.console import api
  9. from controllers.console.app.error import ProviderNotInitializeError
  10. from controllers.console.datasets.error import (
  11. ChildChunkDeleteIndexError,
  12. ChildChunkIndexingError,
  13. InvalidActionError,
  14. NoFileUploadedError,
  15. TooManyFilesError,
  16. )
  17. from controllers.console.wraps import (
  18. account_initialization_required,
  19. cloud_edition_billing_knowledge_limit_check,
  20. cloud_edition_billing_rate_limit_check,
  21. cloud_edition_billing_resource_check,
  22. setup_required,
  23. )
  24. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  25. from core.model_manager import ModelManager
  26. from core.model_runtime.entities.model_entities import ModelType
  27. from extensions.ext_redis import redis_client
  28. from fields.segment_fields import child_chunk_fields, segment_fields
  29. from libs.login import login_required
  30. from models.dataset import ChildChunk, DocumentSegment
  31. from services.dataset_service import DatasetService, DocumentService, SegmentService
  32. from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
  33. from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
  34. from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
  35. from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
  36. class DatasetDocumentSegmentListApi(Resource):
  37. @setup_required
  38. @login_required
  39. @account_initialization_required
  40. def get(self, dataset_id, document_id):
  41. dataset_id = str(dataset_id)
  42. document_id = str(document_id)
  43. dataset = DatasetService.get_dataset(dataset_id)
  44. if not dataset:
  45. raise NotFound("Dataset not found.")
  46. try:
  47. DatasetService.check_dataset_permission(dataset, current_user)
  48. except services.errors.account.NoPermissionError as e:
  49. raise Forbidden(str(e))
  50. document = DocumentService.get_document(dataset_id, document_id)
  51. if not document:
  52. raise NotFound("Document not found.")
  53. parser = reqparse.RequestParser()
  54. parser.add_argument("limit", type=int, default=20, location="args")
  55. parser.add_argument("status", type=str, action="append", default=[], location="args")
  56. parser.add_argument("hit_count_gte", type=int, default=None, location="args")
  57. parser.add_argument("enabled", type=str, default="all", location="args")
  58. parser.add_argument("keyword", type=str, default=None, location="args")
  59. parser.add_argument("page", type=int, default=1, location="args")
  60. args = parser.parse_args()
  61. page = args["page"]
  62. limit = min(args["limit"], 100)
  63. status_list = args["status"]
  64. hit_count_gte = args["hit_count_gte"]
  65. keyword = args["keyword"]
  66. query = DocumentSegment.query.filter(
  67. DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  68. ).order_by(DocumentSegment.position.asc())
  69. if status_list:
  70. query = query.filter(DocumentSegment.status.in_(status_list))
  71. if hit_count_gte is not None:
  72. query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
  73. if keyword:
  74. query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
  75. if args["enabled"].lower() != "all":
  76. if args["enabled"].lower() == "true":
  77. query = query.filter(DocumentSegment.enabled == True)
  78. elif args["enabled"].lower() == "false":
  79. query = query.filter(DocumentSegment.enabled == False)
  80. segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
  81. response = {
  82. "data": marshal(segments.items, segment_fields),
  83. "limit": limit,
  84. "total": segments.total,
  85. "total_pages": segments.pages,
  86. "page": page,
  87. }
  88. return response, 200
  89. @setup_required
  90. @login_required
  91. @account_initialization_required
  92. @cloud_edition_billing_rate_limit_check("knowledge")
  93. def delete(self, dataset_id, document_id):
  94. # check dataset
  95. dataset_id = str(dataset_id)
  96. dataset = DatasetService.get_dataset(dataset_id)
  97. if not dataset:
  98. raise NotFound("Dataset not found.")
  99. # check user's model setting
  100. DatasetService.check_dataset_model_setting(dataset)
  101. # check document
  102. document_id = str(document_id)
  103. document = DocumentService.get_document(dataset_id, document_id)
  104. if not document:
  105. raise NotFound("Document not found.")
  106. segment_ids = request.args.getlist("segment_id")
  107. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  108. if not current_user.is_dataset_editor:
  109. raise Forbidden()
  110. try:
  111. DatasetService.check_dataset_permission(dataset, current_user)
  112. except services.errors.account.NoPermissionError as e:
  113. raise Forbidden(str(e))
  114. SegmentService.delete_segments(segment_ids, document, dataset)
  115. return {"result": "success"}, 200
  116. class DatasetDocumentSegmentApi(Resource):
  117. @setup_required
  118. @login_required
  119. @account_initialization_required
  120. @cloud_edition_billing_resource_check("vector_space")
  121. @cloud_edition_billing_rate_limit_check("knowledge")
  122. def patch(self, dataset_id, document_id, action):
  123. dataset_id = str(dataset_id)
  124. dataset = DatasetService.get_dataset(dataset_id)
  125. if not dataset:
  126. raise NotFound("Dataset not found.")
  127. document_id = str(document_id)
  128. document = DocumentService.get_document(dataset_id, document_id)
  129. if not document:
  130. raise NotFound("Document not found.")
  131. # check user's model setting
  132. DatasetService.check_dataset_model_setting(dataset)
  133. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  134. if not current_user.is_dataset_editor:
  135. raise Forbidden()
  136. try:
  137. DatasetService.check_dataset_permission(dataset, current_user)
  138. except services.errors.account.NoPermissionError as e:
  139. raise Forbidden(str(e))
  140. if dataset.indexing_technique == "high_quality":
  141. # check embedding model setting
  142. try:
  143. model_manager = ModelManager()
  144. model_manager.get_model_instance(
  145. tenant_id=current_user.current_tenant_id,
  146. provider=dataset.embedding_model_provider,
  147. model_type=ModelType.TEXT_EMBEDDING,
  148. model=dataset.embedding_model,
  149. )
  150. except LLMBadRequestError:
  151. raise ProviderNotInitializeError(
  152. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  153. )
  154. except ProviderTokenNotInitError as ex:
  155. raise ProviderNotInitializeError(ex.description)
  156. segment_ids = request.args.getlist("segment_id")
  157. document_indexing_cache_key = "document_{}_indexing".format(document.id)
  158. cache_result = redis_client.get(document_indexing_cache_key)
  159. if cache_result is not None:
  160. raise InvalidActionError("Document is being indexed, please try again later")
  161. try:
  162. SegmentService.update_segments_status(segment_ids, action, dataset, document)
  163. except Exception as e:
  164. raise InvalidActionError(str(e))
  165. return {"result": "success"}, 200
  166. class DatasetDocumentSegmentAddApi(Resource):
  167. @setup_required
  168. @login_required
  169. @account_initialization_required
  170. @cloud_edition_billing_resource_check("vector_space")
  171. @cloud_edition_billing_knowledge_limit_check("add_segment")
  172. @cloud_edition_billing_rate_limit_check("knowledge")
  173. def post(self, dataset_id, document_id):
  174. # check dataset
  175. dataset_id = str(dataset_id)
  176. dataset = DatasetService.get_dataset(dataset_id)
  177. if not dataset:
  178. raise NotFound("Dataset not found.")
  179. # check document
  180. document_id = str(document_id)
  181. document = DocumentService.get_document(dataset_id, document_id)
  182. if not document:
  183. raise NotFound("Document not found.")
  184. if not current_user.is_dataset_editor:
  185. raise Forbidden()
  186. # check embedding model setting
  187. if dataset.indexing_technique == "high_quality":
  188. try:
  189. model_manager = ModelManager()
  190. model_manager.get_model_instance(
  191. tenant_id=current_user.current_tenant_id,
  192. provider=dataset.embedding_model_provider,
  193. model_type=ModelType.TEXT_EMBEDDING,
  194. model=dataset.embedding_model,
  195. )
  196. except LLMBadRequestError:
  197. raise ProviderNotInitializeError(
  198. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  199. )
  200. except ProviderTokenNotInitError as ex:
  201. raise ProviderNotInitializeError(ex.description)
  202. try:
  203. DatasetService.check_dataset_permission(dataset, current_user)
  204. except services.errors.account.NoPermissionError as e:
  205. raise Forbidden(str(e))
  206. # validate args
  207. parser = reqparse.RequestParser()
  208. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  209. parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
  210. parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
  211. args = parser.parse_args()
  212. SegmentService.segment_create_args_validate(args, document)
  213. segment = SegmentService.create_segment(args, document, dataset)
  214. return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
  215. class DatasetDocumentSegmentUpdateApi(Resource):
  216. @setup_required
  217. @login_required
  218. @account_initialization_required
  219. @cloud_edition_billing_resource_check("vector_space")
  220. @cloud_edition_billing_rate_limit_check("knowledge")
  221. def patch(self, dataset_id, document_id, segment_id):
  222. # check dataset
  223. dataset_id = str(dataset_id)
  224. dataset = DatasetService.get_dataset(dataset_id)
  225. if not dataset:
  226. raise NotFound("Dataset not found.")
  227. # check user's model setting
  228. DatasetService.check_dataset_model_setting(dataset)
  229. # check document
  230. document_id = str(document_id)
  231. document = DocumentService.get_document(dataset_id, document_id)
  232. if not document:
  233. raise NotFound("Document not found.")
  234. if dataset.indexing_technique == "high_quality":
  235. # check embedding model setting
  236. try:
  237. model_manager = ModelManager()
  238. model_manager.get_model_instance(
  239. tenant_id=current_user.current_tenant_id,
  240. provider=dataset.embedding_model_provider,
  241. model_type=ModelType.TEXT_EMBEDDING,
  242. model=dataset.embedding_model,
  243. )
  244. except LLMBadRequestError:
  245. raise ProviderNotInitializeError(
  246. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  247. )
  248. except ProviderTokenNotInitError as ex:
  249. raise ProviderNotInitializeError(ex.description)
  250. # check segment
  251. segment_id = str(segment_id)
  252. segment = DocumentSegment.query.filter(
  253. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  254. ).first()
  255. if not segment:
  256. raise NotFound("Segment not found.")
  257. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  258. if not current_user.is_dataset_editor:
  259. raise Forbidden()
  260. try:
  261. DatasetService.check_dataset_permission(dataset, current_user)
  262. except services.errors.account.NoPermissionError as e:
  263. raise Forbidden(str(e))
  264. # validate args
  265. parser = reqparse.RequestParser()
  266. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  267. parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
  268. parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
  269. parser.add_argument(
  270. "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
  271. )
  272. args = parser.parse_args()
  273. SegmentService.segment_create_args_validate(args, document)
  274. segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
  275. return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
  276. @setup_required
  277. @login_required
  278. @account_initialization_required
  279. @cloud_edition_billing_rate_limit_check("knowledge")
  280. def delete(self, dataset_id, document_id, segment_id):
  281. # check dataset
  282. dataset_id = str(dataset_id)
  283. dataset = DatasetService.get_dataset(dataset_id)
  284. if not dataset:
  285. raise NotFound("Dataset not found.")
  286. # check user's model setting
  287. DatasetService.check_dataset_model_setting(dataset)
  288. # check document
  289. document_id = str(document_id)
  290. document = DocumentService.get_document(dataset_id, document_id)
  291. if not document:
  292. raise NotFound("Document not found.")
  293. # check segment
  294. segment_id = str(segment_id)
  295. segment = DocumentSegment.query.filter(
  296. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  297. ).first()
  298. if not segment:
  299. raise NotFound("Segment not found.")
  300. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  301. if not current_user.is_dataset_editor:
  302. raise Forbidden()
  303. try:
  304. DatasetService.check_dataset_permission(dataset, current_user)
  305. except services.errors.account.NoPermissionError as e:
  306. raise Forbidden(str(e))
  307. SegmentService.delete_segment(segment, document, dataset)
  308. return {"result": "success"}, 200
  309. class DatasetDocumentSegmentBatchImportApi(Resource):
  310. @setup_required
  311. @login_required
  312. @account_initialization_required
  313. @cloud_edition_billing_resource_check("vector_space")
  314. @cloud_edition_billing_knowledge_limit_check("add_segment")
  315. @cloud_edition_billing_rate_limit_check("knowledge")
  316. def post(self, dataset_id, document_id):
  317. # check dataset
  318. dataset_id = str(dataset_id)
  319. dataset = DatasetService.get_dataset(dataset_id)
  320. if not dataset:
  321. raise NotFound("Dataset not found.")
  322. # check document
  323. document_id = str(document_id)
  324. document = DocumentService.get_document(dataset_id, document_id)
  325. if not document:
  326. raise NotFound("Document not found.")
  327. # get file from request
  328. file = request.files["file"]
  329. # check file
  330. if "file" not in request.files:
  331. raise NoFileUploadedError()
  332. if len(request.files) > 1:
  333. raise TooManyFilesError()
  334. # check file type
  335. if not file.filename.endswith(".csv"):
  336. raise ValueError("Invalid file type. Only CSV files are allowed")
  337. try:
  338. # Skip the first row
  339. df = pd.read_csv(file)
  340. result = []
  341. for index, row in df.iterrows():
  342. if document.doc_form == "qa_model":
  343. data = {"content": row.iloc[0], "answer": row.iloc[1]}
  344. else:
  345. data = {"content": row.iloc[0]}
  346. result.append(data)
  347. if len(result) == 0:
  348. raise ValueError("The CSV file is empty.")
  349. # async job
  350. job_id = str(uuid.uuid4())
  351. indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
  352. # send batch add segments task
  353. redis_client.setnx(indexing_cache_key, "waiting")
  354. batch_create_segment_to_index_task.delay(
  355. str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
  356. )
  357. except Exception as e:
  358. return {"error": str(e)}, 500
  359. return {"job_id": job_id, "job_status": "waiting"}, 200
  360. @setup_required
  361. @login_required
  362. @account_initialization_required
  363. def get(self, job_id):
  364. job_id = str(job_id)
  365. indexing_cache_key = "segment_batch_import_{}".format(job_id)
  366. cache_result = redis_client.get(indexing_cache_key)
  367. if cache_result is None:
  368. raise ValueError("The job is not exist.")
  369. return {"job_id": job_id, "job_status": cache_result.decode()}, 200
  370. class ChildChunkAddApi(Resource):
  371. @setup_required
  372. @login_required
  373. @account_initialization_required
  374. @cloud_edition_billing_resource_check("vector_space")
  375. @cloud_edition_billing_knowledge_limit_check("add_segment")
  376. @cloud_edition_billing_rate_limit_check("knowledge")
  377. def post(self, dataset_id, document_id, segment_id):
  378. # check dataset
  379. dataset_id = str(dataset_id)
  380. dataset = DatasetService.get_dataset(dataset_id)
  381. if not dataset:
  382. raise NotFound("Dataset not found.")
  383. # check document
  384. document_id = str(document_id)
  385. document = DocumentService.get_document(dataset_id, document_id)
  386. if not document:
  387. raise NotFound("Document not found.")
  388. # check segment
  389. segment_id = str(segment_id)
  390. segment = DocumentSegment.query.filter(
  391. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  392. ).first()
  393. if not segment:
  394. raise NotFound("Segment not found.")
  395. if not current_user.is_dataset_editor:
  396. raise Forbidden()
  397. # check embedding model setting
  398. if dataset.indexing_technique == "high_quality":
  399. try:
  400. model_manager = ModelManager()
  401. model_manager.get_model_instance(
  402. tenant_id=current_user.current_tenant_id,
  403. provider=dataset.embedding_model_provider,
  404. model_type=ModelType.TEXT_EMBEDDING,
  405. model=dataset.embedding_model,
  406. )
  407. except LLMBadRequestError:
  408. raise ProviderNotInitializeError(
  409. "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
  410. )
  411. except ProviderTokenNotInitError as ex:
  412. raise ProviderNotInitializeError(ex.description)
  413. try:
  414. DatasetService.check_dataset_permission(dataset, current_user)
  415. except services.errors.account.NoPermissionError as e:
  416. raise Forbidden(str(e))
  417. # validate args
  418. parser = reqparse.RequestParser()
  419. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  420. args = parser.parse_args()
  421. try:
  422. child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
  423. except ChildChunkIndexingServiceError as e:
  424. raise ChildChunkIndexingError(str(e))
  425. return {"data": marshal(child_chunk, child_chunk_fields)}, 200
  426. @setup_required
  427. @login_required
  428. @account_initialization_required
  429. def get(self, dataset_id, document_id, segment_id):
  430. # check dataset
  431. dataset_id = str(dataset_id)
  432. dataset = DatasetService.get_dataset(dataset_id)
  433. if not dataset:
  434. raise NotFound("Dataset not found.")
  435. # check user's model setting
  436. DatasetService.check_dataset_model_setting(dataset)
  437. # check document
  438. document_id = str(document_id)
  439. document = DocumentService.get_document(dataset_id, document_id)
  440. if not document:
  441. raise NotFound("Document not found.")
  442. # check segment
  443. segment_id = str(segment_id)
  444. segment = DocumentSegment.query.filter(
  445. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  446. ).first()
  447. if not segment:
  448. raise NotFound("Segment not found.")
  449. parser = reqparse.RequestParser()
  450. parser.add_argument("limit", type=int, default=20, location="args")
  451. parser.add_argument("keyword", type=str, default=None, location="args")
  452. parser.add_argument("page", type=int, default=1, location="args")
  453. args = parser.parse_args()
  454. page = args["page"]
  455. limit = min(args["limit"], 100)
  456. keyword = args["keyword"]
  457. child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
  458. return {
  459. "data": marshal(child_chunks.items, child_chunk_fields),
  460. "total": child_chunks.total,
  461. "total_pages": child_chunks.pages,
  462. "page": page,
  463. "limit": limit,
  464. }, 200
  465. @setup_required
  466. @login_required
  467. @account_initialization_required
  468. @cloud_edition_billing_resource_check("vector_space")
  469. @cloud_edition_billing_rate_limit_check("knowledge")
  470. def patch(self, dataset_id, document_id, segment_id):
  471. # check dataset
  472. dataset_id = str(dataset_id)
  473. dataset = DatasetService.get_dataset(dataset_id)
  474. if not dataset:
  475. raise NotFound("Dataset not found.")
  476. # check user's model setting
  477. DatasetService.check_dataset_model_setting(dataset)
  478. # check document
  479. document_id = str(document_id)
  480. document = DocumentService.get_document(dataset_id, document_id)
  481. if not document:
  482. raise NotFound("Document not found.")
  483. # check segment
  484. segment_id = str(segment_id)
  485. segment = DocumentSegment.query.filter(
  486. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  487. ).first()
  488. if not segment:
  489. raise NotFound("Segment not found.")
  490. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  491. if not current_user.is_dataset_editor:
  492. raise Forbidden()
  493. try:
  494. DatasetService.check_dataset_permission(dataset, current_user)
  495. except services.errors.account.NoPermissionError as e:
  496. raise Forbidden(str(e))
  497. # validate args
  498. parser = reqparse.RequestParser()
  499. parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
  500. args = parser.parse_args()
  501. try:
  502. chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
  503. child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
  504. except ChildChunkIndexingServiceError as e:
  505. raise ChildChunkIndexingError(str(e))
  506. return {"data": marshal(child_chunks, child_chunk_fields)}, 200
  507. class ChildChunkUpdateApi(Resource):
  508. @setup_required
  509. @login_required
  510. @account_initialization_required
  511. @cloud_edition_billing_rate_limit_check("knowledge")
  512. def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
  513. # check dataset
  514. dataset_id = str(dataset_id)
  515. dataset = DatasetService.get_dataset(dataset_id)
  516. if not dataset:
  517. raise NotFound("Dataset not found.")
  518. # check user's model setting
  519. DatasetService.check_dataset_model_setting(dataset)
  520. # check document
  521. document_id = str(document_id)
  522. document = DocumentService.get_document(dataset_id, document_id)
  523. if not document:
  524. raise NotFound("Document not found.")
  525. # check segment
  526. segment_id = str(segment_id)
  527. segment = DocumentSegment.query.filter(
  528. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  529. ).first()
  530. if not segment:
  531. raise NotFound("Segment not found.")
  532. # check child chunk
  533. child_chunk_id = str(child_chunk_id)
  534. child_chunk = ChildChunk.query.filter(
  535. ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
  536. ).first()
  537. if not child_chunk:
  538. raise NotFound("Child chunk not found.")
  539. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  540. if not current_user.is_dataset_editor:
  541. raise Forbidden()
  542. try:
  543. DatasetService.check_dataset_permission(dataset, current_user)
  544. except services.errors.account.NoPermissionError as e:
  545. raise Forbidden(str(e))
  546. try:
  547. SegmentService.delete_child_chunk(child_chunk, dataset)
  548. except ChildChunkDeleteIndexServiceError as e:
  549. raise ChildChunkDeleteIndexError(str(e))
  550. return {"result": "success"}, 200
  551. @setup_required
  552. @login_required
  553. @account_initialization_required
  554. @cloud_edition_billing_resource_check("vector_space")
  555. @cloud_edition_billing_rate_limit_check("knowledge")
  556. def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
  557. # check dataset
  558. dataset_id = str(dataset_id)
  559. dataset = DatasetService.get_dataset(dataset_id)
  560. if not dataset:
  561. raise NotFound("Dataset not found.")
  562. # check user's model setting
  563. DatasetService.check_dataset_model_setting(dataset)
  564. # check document
  565. document_id = str(document_id)
  566. document = DocumentService.get_document(dataset_id, document_id)
  567. if not document:
  568. raise NotFound("Document not found.")
  569. # check segment
  570. segment_id = str(segment_id)
  571. segment = DocumentSegment.query.filter(
  572. DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
  573. ).first()
  574. if not segment:
  575. raise NotFound("Segment not found.")
  576. # check child chunk
  577. child_chunk_id = str(child_chunk_id)
  578. child_chunk = ChildChunk.query.filter(
  579. ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
  580. ).first()
  581. if not child_chunk:
  582. raise NotFound("Child chunk not found.")
  583. # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
  584. if not current_user.is_dataset_editor:
  585. raise Forbidden()
  586. try:
  587. DatasetService.check_dataset_permission(dataset, current_user)
  588. except services.errors.account.NoPermissionError as e:
  589. raise Forbidden(str(e))
  590. # validate args
  591. parser = reqparse.RequestParser()
  592. parser.add_argument("content", type=str, required=True, nullable=False, location="json")
  593. args = parser.parse_args()
  594. try:
  595. child_chunk = SegmentService.update_child_chunk(
  596. args.get("content"), child_chunk, segment, document, dataset
  597. )
  598. except ChildChunkIndexingServiceError as e:
  599. raise ChildChunkIndexingError(str(e))
  600. return {"data": marshal(child_chunk, child_chunk_fields)}, 200
# Route registrations for the segment / child-chunk APIs.

# List (GET) and bulk-delete (DELETE) segments of a document.
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
# Batch status action (PATCH) on segments; `action` is a path parameter.
api.add_resource(
    DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
# Create (POST) a single segment.
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
# Update (PATCH) / delete (DELETE) one segment.
api.add_resource(
    DatasetDocumentSegmentUpdateApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
)
# CSV batch import (POST) and job-status polling (GET) — two URLs, one resource.
api.add_resource(
    DatasetDocumentSegmentBatchImportApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
    "/datasets/batch_import_status/<uuid:job_id>",
)
# Create (POST), list (GET), batch-update (PATCH) child chunks of a segment.
api.add_resource(
    ChildChunkAddApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
# Update (PATCH) / delete (DELETE) one child chunk.
api.add_resource(
    ChildChunkUpdateApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)