segment.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from flask_login import current_user
  2. from flask_restful import reqparse, marshal
  3. from werkzeug.exceptions import NotFound
  4. from controllers.service_api import api
  5. from controllers.service_api.app.error import ProviderNotInitializeError
  6. from controllers.service_api.wraps import DatasetApiResource
  7. from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
  8. from core.model_providers.model_factory import ModelFactory
  9. from extensions.ext_database import db
  10. from fields.segment_fields import segment_fields
  11. from models.dataset import Dataset, DocumentSegment
  12. from services.dataset_service import DatasetService, DocumentService, SegmentService
  13. class SegmentApi(DatasetApiResource):
  14. """Resource for segments."""
  15. def post(self, tenant_id, dataset_id, document_id):
  16. """Create single segment."""
  17. # check dataset
  18. dataset_id = str(dataset_id)
  19. tenant_id = str(tenant_id)
  20. dataset = db.session.query(Dataset).filter(
  21. Dataset.tenant_id == tenant_id,
  22. Dataset.id == dataset_id
  23. ).first()
  24. if not dataset:
  25. raise NotFound('Dataset not found.')
  26. # check document
  27. document_id = str(document_id)
  28. document = DocumentService.get_document(dataset.id, document_id)
  29. if not document:
  30. raise NotFound('Document not found.')
  31. # check embedding model setting
  32. if dataset.indexing_technique == 'high_quality':
  33. try:
  34. ModelFactory.get_embedding_model(
  35. tenant_id=current_user.current_tenant_id,
  36. model_provider_name=dataset.embedding_model_provider,
  37. model_name=dataset.embedding_model
  38. )
  39. except LLMBadRequestError:
  40. raise ProviderNotInitializeError(
  41. f"No Embedding Model available. Please configure a valid provider "
  42. f"in the Settings -> Model Provider.")
  43. except ProviderTokenNotInitError as ex:
  44. raise ProviderNotInitializeError(ex.description)
  45. # validate args
  46. parser = reqparse.RequestParser()
  47. parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
  48. args = parser.parse_args()
  49. for args_item in args['segments']:
  50. SegmentService.segment_create_args_validate(args_item, document)
  51. segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
  52. return {
  53. 'data': marshal(segments, segment_fields),
  54. 'doc_form': document.doc_form
  55. }, 200
  56. def get(self, tenant_id, dataset_id, document_id):
  57. """Create single segment."""
  58. # check dataset
  59. dataset_id = str(dataset_id)
  60. tenant_id = str(tenant_id)
  61. dataset = db.session.query(Dataset).filter(
  62. Dataset.tenant_id == tenant_id,
  63. Dataset.id == dataset_id
  64. ).first()
  65. if not dataset:
  66. raise NotFound('Dataset not found.')
  67. # check document
  68. document_id = str(document_id)
  69. document = DocumentService.get_document(dataset.id, document_id)
  70. if not document:
  71. raise NotFound('Document not found.')
  72. # check embedding model setting
  73. if dataset.indexing_technique == 'high_quality':
  74. try:
  75. ModelFactory.get_embedding_model(
  76. tenant_id=current_user.current_tenant_id,
  77. model_provider_name=dataset.embedding_model_provider,
  78. model_name=dataset.embedding_model
  79. )
  80. except LLMBadRequestError:
  81. raise ProviderNotInitializeError(
  82. f"No Embedding Model available. Please configure a valid provider "
  83. f"in the Settings -> Model Provider.")
  84. except ProviderTokenNotInitError as ex:
  85. raise ProviderNotInitializeError(ex.description)
  86. parser = reqparse.RequestParser()
  87. parser.add_argument('status', type=str,
  88. action='append', default=[], location='args')
  89. parser.add_argument('keyword', type=str, default=None, location='args')
  90. args = parser.parse_args()
  91. status_list = args['status']
  92. keyword = args['keyword']
  93. query = DocumentSegment.query.filter(
  94. DocumentSegment.document_id == str(document_id),
  95. DocumentSegment.tenant_id == current_user.current_tenant_id
  96. )
  97. if status_list:
  98. query = query.filter(DocumentSegment.status.in_(status_list))
  99. if keyword:
  100. query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
  101. total = query.count()
  102. segments = query.order_by(DocumentSegment.position).all()
  103. return {
  104. 'data': marshal(segments, segment_fields),
  105. 'doc_form': document.doc_form,
  106. 'total': total
  107. }, 200
  108. class DatasetSegmentApi(DatasetApiResource):
  109. def delete(self, tenant_id, dataset_id, document_id, segment_id):
  110. # check dataset
  111. dataset_id = str(dataset_id)
  112. tenant_id = str(tenant_id)
  113. dataset = db.session.query(Dataset).filter(
  114. Dataset.tenant_id == tenant_id,
  115. Dataset.id == dataset_id
  116. ).first()
  117. if not dataset:
  118. raise NotFound('Dataset not found.')
  119. # check user's model setting
  120. DatasetService.check_dataset_model_setting(dataset)
  121. # check document
  122. document_id = str(document_id)
  123. document = DocumentService.get_document(dataset_id, document_id)
  124. if not document:
  125. raise NotFound('Document not found.')
  126. # check segment
  127. segment = DocumentSegment.query.filter(
  128. DocumentSegment.id == str(segment_id),
  129. DocumentSegment.tenant_id == current_user.current_tenant_id
  130. ).first()
  131. if not segment:
  132. raise NotFound('Segment not found.')
  133. SegmentService.delete_segment(segment, document, dataset)
  134. return {'result': 'success'}, 200
  135. def post(self, tenant_id, dataset_id, document_id, segment_id):
  136. # check dataset
  137. dataset_id = str(dataset_id)
  138. tenant_id = str(tenant_id)
  139. dataset = db.session.query(Dataset).filter(
  140. Dataset.tenant_id == tenant_id,
  141. Dataset.id == dataset_id
  142. ).first()
  143. if not dataset:
  144. raise NotFound('Dataset not found.')
  145. # check user's model setting
  146. DatasetService.check_dataset_model_setting(dataset)
  147. # check document
  148. document_id = str(document_id)
  149. document = DocumentService.get_document(dataset_id, document_id)
  150. if not document:
  151. raise NotFound('Document not found.')
  152. if dataset.indexing_technique == 'high_quality':
  153. # check embedding model setting
  154. try:
  155. ModelFactory.get_embedding_model(
  156. tenant_id=current_user.current_tenant_id,
  157. model_provider_name=dataset.embedding_model_provider,
  158. model_name=dataset.embedding_model
  159. )
  160. except LLMBadRequestError:
  161. raise ProviderNotInitializeError(
  162. f"No Embedding Model available. Please configure a valid provider "
  163. f"in the Settings -> Model Provider.")
  164. except ProviderTokenNotInitError as ex:
  165. raise ProviderNotInitializeError(ex.description)
  166. # check segment
  167. segment_id = str(segment_id)
  168. segment = DocumentSegment.query.filter(
  169. DocumentSegment.id == str(segment_id),
  170. DocumentSegment.tenant_id == current_user.current_tenant_id
  171. ).first()
  172. if not segment:
  173. raise NotFound('Segment not found.')
  174. # validate args
  175. parser = reqparse.RequestParser()
  176. parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
  177. args = parser.parse_args()
  178. SegmentService.segment_create_args_validate(args['segments'], document)
  179. segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
  180. return {
  181. 'data': marshal(segment, segment_fields),
  182. 'doc_form': document.doc_form
  183. }, 200
  184. api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
  185. api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')