import logging

from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, marshal, fields
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.app.error import (
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
    ProviderModelCurrentlyNotSupportError,
)
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import TimestampField
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService

# Maximum accepted length (in characters) of a hit-testing query.
HIT_TESTING_QUERY_MAX_LENGTH = 250
# Number of records requested from HitTestingService.retrieve for one query.
HIT_TESTING_RECORD_LIMIT = 10

# Serialization schema for the document that a matched segment belongs to.
document_fields = {
    'id': fields.String,
    'data_source_type': fields.String,
    'name': fields.String,
    'doc_type': fields.String,
}

# Serialization schema for a single matched segment, with its parent document nested.
segment_fields = {
    'id': fields.String,
    'position': fields.Integer,
    'document_id': fields.String,
    'content': fields.String,
    'answer': fields.String,
    'word_count': fields.Integer,
    'tokens': fields.Integer,
    'keywords': fields.List(fields.String),
    'index_node_id': fields.String,
    'index_node_hash': fields.String,
    'hit_count': fields.Integer,
    'enabled': fields.Boolean,
    'disabled_at': TimestampField,
    'disabled_by': fields.String,
    'status': fields.String,
    'created_by': fields.String,
    'created_at': TimestampField,
    'indexing_at': TimestampField,
    'completed_at': TimestampField,
    'error': fields.String,
    'stopped_at': TimestampField,
    'document': fields.Nested(document_fields),
}

# Serialization schema for one hit-testing record: the matched segment and its score.
hit_testing_record_fields = {
    'segment': fields.Nested(segment_fields),
    'score': fields.Float,
    # fields.Raw passes the value through unchanged; its shape is whatever
    # HitTestingService.retrieve puts in the record.
    'tsne_position': fields.Raw,
}


class HitTestingApi(Resource):
    """Console endpoint that runs a retrieval ("hit testing") query against a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id):
        """Run a hit-testing query against ``dataset_id`` and return the matched records.

        The request body is JSON with a ``query`` string.

        Args:
            dataset_id: Dataset identifier taken from the URL (``<uuid:dataset_id>``).

        Returns:
            A dict with the echoed ``query`` and ``records``. The records are
            marshalled with ``hit_testing_record_fields``.

        Raises:
            NotFound: the dataset does not exist.
            Forbidden: the current user lacks permission on the dataset.
            HighQualityDatasetOnlyError: the dataset is not indexed with 'high_quality'.
            ValueError: the query is missing or longer than the allowed length.
                NOTE(review): presumably mapped to a client error by the API's
                error handler, which is not visible here; confirm.
            DatasetNotInitializedError: the dataset index is not initialized.
            ProviderNotInitializeError, ProviderQuotaExceededError,
            ProviderModelCurrentlyNotSupportError: the model-provider error is
                translated into its console-facing equivalent.
            InternalServerError: any other failure during retrieval; it is logged first.
        """
        dataset_id_str = str(dataset_id)

        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        # only high quality dataset can be used for hit testing
        if dataset.indexing_technique != 'high_quality':
            raise HighQualityDatasetOnlyError()

        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, location='json')
        args = parser.parse_args()

        query = args['query']

        if not query or len(query) > HIT_TESTING_QUERY_MAX_LENGTH:
            raise ValueError(f'Query is required and cannot exceed {HIT_TESTING_QUERY_MAX_LENGTH} characters')

        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=query,
                account=current_user,
                limit=HIT_TESTING_RECORD_LIMIT,
            )

            return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError as e:
            raise DatasetNotInitializedError() from e
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as e:
            raise ProviderQuotaExceededError() from e
        except ModelCurrentlyNotSupportError as e:
            raise ProviderModelCurrentlyNotSupportError() from e
        except ValueError:
            # Propagate unchanged (same type and traceback) so the generic
            # handler below does not turn it into an InternalServerError.
            raise
        except Exception as e:
            logging.exception("Hit testing failed.")
            raise InternalServerError(str(e)) from e


api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')