import logging

from flask_login import current_user
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

import services
from controllers.console import api
from controllers.console.app.error import (
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
    ProviderModelCurrentlyNotSupportError,
)
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import (
    ProviderTokenNotInitError,
    QuotaExceededError,
    ModelCurrentlyNotSupportError,
    LLMBadRequestError,
)
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService

# Longest query (in characters) accepted by the hit-testing endpoint.
_MAX_QUERY_LENGTH = 250
# Number of records requested from HitTestingService.retrieve per hit-testing call.
_HIT_TESTING_RECORD_LIMIT = 10


class HitTestingApi(Resource):
    """Console endpoint that runs a hit-testing (retrieval) query against one dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id):
        """Run a hit-testing query against the dataset identified by ``dataset_id``.

        The request body is JSON with a ``query`` string.

        Args:
            dataset_id: Dataset identifier from the URL. The route declares a
                ``uuid`` converter, so it is converted to ``str`` before lookup.

        Returns:
            A dict ``{"query": ..., "records": [...]}``. ``records`` is
            marshalled with ``hit_testing_record_fields``.

        Raises:
            NotFound: the dataset does not exist.
            Forbidden: the current user fails the dataset permission check.
            HighQualityDatasetOnlyError: the dataset's indexing technique is
                not ``'high_quality'``.
            ValueError: the query is missing/empty or longer than
                ``_MAX_QUERY_LENGTH`` characters, or the retrieval raised a
                ValueError. NOTE(review): presumably mapped to an HTTP 400 by
                the app's API error handler — confirm.
            DatasetNotInitializedError, ProviderNotInitializeError,
            ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError:
                translated from the corresponding service/provider errors.
            InternalServerError: any other unexpected failure (logged first).
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # only high quality dataset can be used for hit testing
        if dataset.indexing_technique != 'high_quality':
            raise HighQualityDatasetOnlyError()

        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, location='json')
        args = parser.parse_args()

        query = args['query']
        if not query or len(query) > _MAX_QUERY_LENGTH:
            raise ValueError(f'Query is required and cannot exceed {_MAX_QUERY_LENGTH} characters')

        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=query,
                account=current_user,
                limit=_HIT_TESTING_RECORD_LIMIT,
            )
            return {
                "query": response['query'],
                'records': marshal(response['records'], hit_testing_record_fields),
            }
        except services.errors.index.IndexNotInitializedError:
            raise DatasetNotInitializedError()
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.")
        except ValueError:
            # Must precede the generic handler so validation-style errors are
            # not turned into a 500. Bare `raise` keeps the original exception
            # type and traceback instead of rebuilding a plain ValueError.
            raise
        except Exception as e:
            logging.exception("Hit testing failed.")
            raise InternalServerError(str(e))


api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')