# datasets_segments.py
  1. # -*- coding:utf-8 -*-
  2. import uuid
  3. from datetime import datetime
  4. import pandas as pd
  5. import services
  6. from controllers.console import api
  7. from controllers.console.app.error import ProviderNotInitializeError
  8. from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
  9. from controllers.console.setup import setup_required
  10. from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
  11. from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
  12. from core.model_manager import ModelManager
  13. from core.model_runtime.entities.model_entities import ModelType
  14. from extensions.ext_database import db
  15. from extensions.ext_redis import redis_client
  16. from fields.segment_fields import segment_fields
  17. from flask import request
  18. from flask_login import current_user
  19. from flask_restful import Resource, marshal, reqparse
  20. from libs.login import login_required
  21. from models.dataset import DocumentSegment
  22. from services.dataset_service import DatasetService, DocumentService, SegmentService
  23. from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
  24. from tasks.disable_segment_from_index_task import disable_segment_from_index_task
  25. from tasks.enable_segment_to_index_task import enable_segment_to_index_task
  26. from werkzeug.exceptions import Forbidden, NotFound
  27. class DatasetDocumentSegmentListApi(Resource):
  28. @setup_required
  29. @login_required
  30. @account_initialization_required
  31. def get(self, dataset_id, document_id):
  32. dataset_id = str(dataset_id)
  33. document_id = str(document_id)
  34. dataset = DatasetService.get_dataset(dataset_id)
  35. if not dataset:
  36. raise NotFound('Dataset not found.')
  37. try:
  38. DatasetService.check_dataset_permission(dataset, current_user)
  39. except services.errors.account.NoPermissionError as e:
  40. raise Forbidden(str(e))
  41. document = DocumentService.get_document(dataset_id, document_id)
  42. if not document:
  43. raise NotFound('Document not found.')
  44. parser = reqparse.RequestParser()
  45. parser.add_argument('last_id', type=str, default=None, location='args')
  46. parser.add_argument('limit', type=int, default=20, location='args')
  47. parser.add_argument('status', type=str,
  48. action='append', default=[], location='args')
  49. parser.add_argument('hit_count_gte', type=int,
  50. default=None, location='args')
  51. parser.add_argument('enabled', type=str, default='all', location='args')
  52. parser.add_argument('keyword', type=str, default=None, location='args')
  53. args = parser.parse_args()
  54. last_id = args['last_id']
  55. limit = min(args['limit'], 100)
  56. status_list = args['status']
  57. hit_count_gte = args['hit_count_gte']
  58. keyword = args['keyword']
  59. query = DocumentSegment.query.filter(
  60. DocumentSegment.document_id == str(document_id),
  61. DocumentSegment.tenant_id == current_user.current_tenant_id
  62. )
  63. if last_id is not None:
  64. last_segment = DocumentSegment.query.get(str(last_id))
  65. if last_segment:
  66. query = query.filter(
  67. DocumentSegment.position > last_segment.position)
  68. else:
  69. return {'data': [], 'has_more': False, 'limit': limit}, 200
  70. if status_list:
  71. query = query.filter(DocumentSegment.status.in_(status_list))
  72. if hit_count_gte is not None:
  73. query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
  74. if keyword:
  75. query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
  76. if args['enabled'].lower() != 'all':
  77. if args['enabled'].lower() == 'true':
  78. query = query.filter(DocumentSegment.enabled == True)
  79. elif args['enabled'].lower() == 'false':
  80. query = query.filter(DocumentSegment.enabled == False)
  81. total = query.count()
  82. segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
  83. has_more = False
  84. if len(segments) > limit:
  85. has_more = True
  86. segments = segments[:-1]
  87. return {
  88. 'data': marshal(segments, segment_fields),
  89. 'doc_form': document.doc_form,
  90. 'has_more': has_more,
  91. 'limit': limit,
  92. 'total': total
  93. }, 200
  94. class DatasetDocumentSegmentApi(Resource):
  95. @setup_required
  96. @login_required
  97. @account_initialization_required
  98. @cloud_edition_billing_resource_check('vector_space')
  99. def patch(self, dataset_id, segment_id, action):
  100. dataset_id = str(dataset_id)
  101. dataset = DatasetService.get_dataset(dataset_id)
  102. if not dataset:
  103. raise NotFound('Dataset not found.')
  104. # check user's model setting
  105. DatasetService.check_dataset_model_setting(dataset)
  106. # The role of the current user in the ta table must be admin or owner
  107. if not current_user.is_admin_or_owner:
  108. raise Forbidden()
  109. try:
  110. DatasetService.check_dataset_permission(dataset, current_user)
  111. except services.errors.account.NoPermissionError as e:
  112. raise Forbidden(str(e))
  113. if dataset.indexing_technique == 'high_quality':
  114. # check embedding model setting
  115. try:
  116. model_manager = ModelManager()
  117. model_manager.get_model_instance(
  118. tenant_id=current_user.current_tenant_id,
  119. provider=dataset.embedding_model_provider,
  120. model_type=ModelType.TEXT_EMBEDDING,
  121. model=dataset.embedding_model
  122. )
  123. except LLMBadRequestError:
  124. raise ProviderNotInitializeError(
  125. f"No Embedding Model available. Please configure a valid provider "
  126. f"in the Settings -> Model Provider.")
  127. except ProviderTokenNotInitError as ex:
  128. raise ProviderNotInitializeError(ex.description)
  129. segment = DocumentSegment.query.filter(
  130. DocumentSegment.id == str(segment_id),
  131. DocumentSegment.tenant_id == current_user.current_tenant_id
  132. ).first()
  133. if not segment:
  134. raise NotFound('Segment not found.')
  135. if segment.status != 'completed':
  136. raise NotFound('Segment is not completed, enable or disable function is not allowed')
  137. document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
  138. cache_result = redis_client.get(document_indexing_cache_key)
  139. if cache_result is not None:
  140. raise InvalidActionError("Document is being indexed, please try again later")
  141. indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
  142. cache_result = redis_client.get(indexing_cache_key)
  143. if cache_result is not None:
  144. raise InvalidActionError("Segment is being indexed, please try again later")
  145. if action == "enable":
  146. if segment.enabled:
  147. raise InvalidActionError("Segment is already enabled.")
  148. segment.enabled = True
  149. segment.disabled_at = None
  150. segment.disabled_by = None
  151. db.session.commit()
  152. # Set cache to prevent indexing the same segment multiple times
  153. redis_client.setex(indexing_cache_key, 600, 1)
  154. enable_segment_to_index_task.delay(segment.id)
  155. return {'result': 'success'}, 200
  156. elif action == "disable":
  157. if not segment.enabled:
  158. raise InvalidActionError("Segment is already disabled.")
  159. segment.enabled = False
  160. segment.disabled_at = datetime.utcnow()
  161. segment.disabled_by = current_user.id
  162. db.session.commit()
  163. # Set cache to prevent indexing the same segment multiple times
  164. redis_client.setex(indexing_cache_key, 600, 1)
  165. disable_segment_from_index_task.delay(segment.id)
  166. return {'result': 'success'}, 200
  167. else:
  168. raise InvalidActionError()
  169. class DatasetDocumentSegmentAddApi(Resource):
  170. @setup_required
  171. @login_required
  172. @account_initialization_required
  173. @cloud_edition_billing_resource_check('vector_space')
  174. def post(self, dataset_id, document_id):
  175. # check dataset
  176. dataset_id = str(dataset_id)
  177. dataset = DatasetService.get_dataset(dataset_id)
  178. if not dataset:
  179. raise NotFound('Dataset not found.')
  180. # check document
  181. document_id = str(document_id)
  182. document = DocumentService.get_document(dataset_id, document_id)
  183. if not document:
  184. raise NotFound('Document not found.')
  185. # The role of the current user in the ta table must be admin or owner
  186. if not current_user.is_admin_or_owner:
  187. raise Forbidden()
  188. # check embedding model setting
  189. if dataset.indexing_technique == 'high_quality':
  190. try:
  191. model_manager = ModelManager()
  192. model_manager.get_model_instance(
  193. tenant_id=current_user.current_tenant_id,
  194. provider=dataset.embedding_model_provider,
  195. model_type=ModelType.TEXT_EMBEDDING,
  196. model=dataset.embedding_model
  197. )
  198. except LLMBadRequestError:
  199. raise ProviderNotInitializeError(
  200. f"No Embedding Model available. Please configure a valid provider "
  201. f"in the Settings -> Model Provider.")
  202. except ProviderTokenNotInitError as ex:
  203. raise ProviderNotInitializeError(ex.description)
  204. try:
  205. DatasetService.check_dataset_permission(dataset, current_user)
  206. except services.errors.account.NoPermissionError as e:
  207. raise Forbidden(str(e))
  208. # validate args
  209. parser = reqparse.RequestParser()
  210. parser.add_argument('content', type=str, required=True, nullable=False, location='json')
  211. parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
  212. parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
  213. args = parser.parse_args()
  214. SegmentService.segment_create_args_validate(args, document)
  215. segment = SegmentService.create_segment(args, document, dataset)
  216. return {
  217. 'data': marshal(segment, segment_fields),
  218. 'doc_form': document.doc_form
  219. }, 200
  220. class DatasetDocumentSegmentUpdateApi(Resource):
  221. @setup_required
  222. @login_required
  223. @account_initialization_required
  224. @cloud_edition_billing_resource_check('vector_space')
  225. def patch(self, dataset_id, document_id, segment_id):
  226. # check dataset
  227. dataset_id = str(dataset_id)
  228. dataset = DatasetService.get_dataset(dataset_id)
  229. if not dataset:
  230. raise NotFound('Dataset not found.')
  231. # check user's model setting
  232. DatasetService.check_dataset_model_setting(dataset)
  233. # check document
  234. document_id = str(document_id)
  235. document = DocumentService.get_document(dataset_id, document_id)
  236. if not document:
  237. raise NotFound('Document not found.')
  238. if dataset.indexing_technique == 'high_quality':
  239. # check embedding model setting
  240. try:
  241. model_manager = ModelManager()
  242. model_manager.get_model_instance(
  243. tenant_id=current_user.current_tenant_id,
  244. provider=dataset.embedding_model_provider,
  245. model_type=ModelType.TEXT_EMBEDDING,
  246. model=dataset.embedding_model
  247. )
  248. except LLMBadRequestError:
  249. raise ProviderNotInitializeError(
  250. f"No Embedding Model available. Please configure a valid provider "
  251. f"in the Settings -> Model Provider.")
  252. except ProviderTokenNotInitError as ex:
  253. raise ProviderNotInitializeError(ex.description)
  254. # check segment
  255. segment_id = str(segment_id)
  256. segment = DocumentSegment.query.filter(
  257. DocumentSegment.id == str(segment_id),
  258. DocumentSegment.tenant_id == current_user.current_tenant_id
  259. ).first()
  260. if not segment:
  261. raise NotFound('Segment not found.')
  262. # The role of the current user in the ta table must be admin or owner
  263. if not current_user.is_admin_or_owner:
  264. raise Forbidden()
  265. try:
  266. DatasetService.check_dataset_permission(dataset, current_user)
  267. except services.errors.account.NoPermissionError as e:
  268. raise Forbidden(str(e))
  269. # validate args
  270. parser = reqparse.RequestParser()
  271. parser.add_argument('content', type=str, required=True, nullable=False, location='json')
  272. parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
  273. parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
  274. args = parser.parse_args()
  275. SegmentService.segment_create_args_validate(args, document)
  276. segment = SegmentService.update_segment(args, segment, document, dataset)
  277. return {
  278. 'data': marshal(segment, segment_fields),
  279. 'doc_form': document.doc_form
  280. }, 200
  281. @setup_required
  282. @login_required
  283. @account_initialization_required
  284. def delete(self, dataset_id, document_id, segment_id):
  285. # check dataset
  286. dataset_id = str(dataset_id)
  287. dataset = DatasetService.get_dataset(dataset_id)
  288. if not dataset:
  289. raise NotFound('Dataset not found.')
  290. # check user's model setting
  291. DatasetService.check_dataset_model_setting(dataset)
  292. # check document
  293. document_id = str(document_id)
  294. document = DocumentService.get_document(dataset_id, document_id)
  295. if not document:
  296. raise NotFound('Document not found.')
  297. # check segment
  298. segment_id = str(segment_id)
  299. segment = DocumentSegment.query.filter(
  300. DocumentSegment.id == str(segment_id),
  301. DocumentSegment.tenant_id == current_user.current_tenant_id
  302. ).first()
  303. if not segment:
  304. raise NotFound('Segment not found.')
  305. # The role of the current user in the ta table must be admin or owner
  306. if not current_user.is_admin_or_owner:
  307. raise Forbidden()
  308. try:
  309. DatasetService.check_dataset_permission(dataset, current_user)
  310. except services.errors.account.NoPermissionError as e:
  311. raise Forbidden(str(e))
  312. SegmentService.delete_segment(segment, document, dataset)
  313. return {'result': 'success'}, 200
  314. class DatasetDocumentSegmentBatchImportApi(Resource):
  315. @setup_required
  316. @login_required
  317. @account_initialization_required
  318. @cloud_edition_billing_resource_check('vector_space')
  319. def post(self, dataset_id, document_id):
  320. # check dataset
  321. dataset_id = str(dataset_id)
  322. dataset = DatasetService.get_dataset(dataset_id)
  323. if not dataset:
  324. raise NotFound('Dataset not found.')
  325. # check document
  326. document_id = str(document_id)
  327. document = DocumentService.get_document(dataset_id, document_id)
  328. if not document:
  329. raise NotFound('Document not found.')
  330. # get file from request
  331. file = request.files['file']
  332. # check file
  333. if 'file' not in request.files:
  334. raise NoFileUploadedError()
  335. if len(request.files) > 1:
  336. raise TooManyFilesError()
  337. # check file type
  338. if not file.filename.endswith('.csv'):
  339. raise ValueError("Invalid file type. Only CSV files are allowed")
  340. try:
  341. # Skip the first row
  342. df = pd.read_csv(file)
  343. result = []
  344. for index, row in df.iterrows():
  345. if document.doc_form == 'qa_model':
  346. data = {'content': row[0], 'answer': row[1]}
  347. else:
  348. data = {'content': row[0]}
  349. result.append(data)
  350. if len(result) == 0:
  351. raise ValueError("The CSV file is empty.")
  352. # async job
  353. job_id = str(uuid.uuid4())
  354. indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
  355. # send batch add segments task
  356. redis_client.setnx(indexing_cache_key, 'waiting')
  357. batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
  358. current_user.current_tenant_id, current_user.id)
  359. except Exception as e:
  360. return {'error': str(e)}, 500
  361. return {
  362. 'job_id': job_id,
  363. 'job_status': 'waiting'
  364. }, 200
  365. @setup_required
  366. @login_required
  367. @account_initialization_required
  368. def get(self, job_id):
  369. job_id = str(job_id)
  370. indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
  371. cache_result = redis_client.get(indexing_cache_key)
  372. if cache_result is None:
  373. raise ValueError("The job is not exist.")
  374. return {
  375. 'job_id': job_id,
  376. 'job_status': cache_result.decode()
  377. }, 200
# Route registrations: bind each segment resource to its console API path.
api.add_resource(DatasetDocumentSegmentListApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentApi,
                 '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
api.add_resource(DatasetDocumentSegmentAddApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
api.add_resource(DatasetDocumentSegmentUpdateApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
# One resource, two routes: POST uploads the CSV, GET polls job status.
api.add_resource(DatasetDocumentSegmentBatchImportApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
                 '/datasets/batch_import_status/<uuid:job_id>')