Browse Source

Document segment API: add get, update, and delete operations (#1285)

Co-authored-by: luowei <[REDACTED — this line contained what appears to be a leaked GitLab personal access token (glpat-…); the token should be revoked and replaced with an email address]>
Charlie.Wei 1 year ago
parent
commit
9e7efa45d4

+ 145 - 3
api/controllers/service_api/dataset/segment.py

@@ -1,7 +1,6 @@
 from flask_login import current_user
 from flask_restful import reqparse, marshal
 from werkzeug.exceptions import NotFound
-
 from controllers.service_api import api
 from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.wraps import DatasetApiResource
@@ -9,8 +8,8 @@ from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestE
 from core.model_providers.model_factory import ModelFactory
 from extensions.ext_database import db
 from fields.segment_fields import segment_fields
-from models.dataset import Dataset
-from services.dataset_service import DocumentService, SegmentService
+from models.dataset import Dataset, DocumentSegment
+from services.dataset_service import DatasetService, DocumentService, SegmentService
 
 
 class SegmentApi(DatasetApiResource):
@@ -24,6 +23,8 @@ class SegmentApi(DatasetApiResource):
             Dataset.tenant_id == tenant_id,
             Dataset.id == dataset_id
         ).first()
+        if not dataset:
+            raise NotFound('Dataset not found.')
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset.id, document_id)
@@ -55,5 +56,146 @@ class SegmentApi(DatasetApiResource):
             'doc_form': document.doc_form
         }, 200
 
+    def get(self, tenant_id, dataset_id, document_id):
+        """Create single segment."""
+        # check dataset
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+        dataset = db.session.query(Dataset).filter(
+            Dataset.tenant_id == tenant_id,
+            Dataset.id == dataset_id
+        ).first()
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # check embedding model setting
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('status', type=str,
+                            action='append', default=[], location='args')
+        parser.add_argument('keyword', type=str, default=None, location='args')
+        args = parser.parse_args()
+
+        status_list = args['status']
+        keyword = args['keyword']
+
+        query = DocumentSegment.query.filter(
+            DocumentSegment.document_id == str(document_id),
+            DocumentSegment.tenant_id == current_user.current_tenant_id
+        )
+
+        if status_list:
+            query = query.filter(DocumentSegment.status.in_(status_list))
+
+        if keyword:
+            query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
+
+        total = query.count()
+        segments = query.order_by(DocumentSegment.position).all()
+        return {
+            'data': marshal(segments, segment_fields),
+            'doc_form': document.doc_form,
+            'total': total
+        }, 200
+
+
class DatasetSegmentApi(DatasetApiResource):
    """Service API resource for operating on a single document segment."""

    def delete(self, tenant_id, dataset_id, document_id, segment_id):
        """Delete one segment of a document.

        Raises:
            NotFound: If the dataset, document, or segment does not exist.
        """
        # check dataset
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == tenant_id,
            Dataset.id == dataset_id
        ).first()
        if not dataset:
            raise NotFound('Dataset not found.')
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        # check segment
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')
        SegmentService.delete_segment(segment, document, dataset)
        return {'result': 'success'}, 200

    def post(self, tenant_id, dataset_id, document_id, segment_id):
        """Update one segment of a document from the JSON body's 'segments' dict.

        Raises:
            NotFound: If the dataset, document, or segment does not exist.
            ProviderNotInitializeError: If the dataset uses high-quality
                indexing but no valid embedding model is configured.
        """
        # check dataset
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == tenant_id,
            Dataset.id == dataset_id
        ).first()
        if not dataset:
            raise NotFound('Dataset not found.')
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound('Document not found.')
        if dataset.indexing_technique == 'high_quality':
            # check embedding model setting
            try:
                ModelFactory.get_embedding_model(
                    tenant_id=current_user.current_tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except LLMBadRequestError:
                # plain string literals: the originals were f-strings with no
                # placeholders (F541); text is unchanged
                raise ProviderNotInitializeError(
                    "No Embedding Model available. Please configure a valid provider "
                    "in the Settings -> Model Provider.")
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description)
        # check segment (comment was previously mis-indented inside the if block)
        segment_id = str(segment_id)
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == segment_id,  # already a str; drop redundant str()
            DocumentSegment.tenant_id == current_user.current_tenant_id
        ).first()
        if not segment:
            raise NotFound('Segment not found.')

        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
        args = parser.parse_args()

        SegmentService.segment_create_args_validate(args['segments'], document)
        segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
        return {
            'data': marshal(segment, segment_fields),
            'doc_form': document.doc_form
        }, 200
+
 
# Route registration: collection endpoint (list/create segments) and
# item endpoint (update/delete a single segment) under a dataset's document.
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')

+ 2 - 0
api/services/dataset_service.py

@@ -1091,6 +1091,8 @@ class SegmentService:
                     segment.answer = args['answer']
                 if args['keywords']:
                     segment.keywords = args['keywords']
+                if args['enabled'] is not None:
+                    segment.enabled = args['enabled']
                 db.session.add(segment)
                 db.session.commit()
                 # update segment index task

File diff suppressed because it is too large
+ 244 - 38
web/app/(commonLayout)/datasets/template/template.en.mdx


File diff suppressed because it is too large
+ 249 - 42
web/app/(commonLayout)/datasets/template/template.zh.mdx