import logging

from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services
from controllers.console import api
from controllers.console.app.error import (
    CompletionRequestError,
    ProviderModelCurrentlyNotSupportError,
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.errors.error import (
    LLMBadRequestError,
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService

logger = logging.getLogger(__name__)


class HitTestingApi(Resource):
    """Console endpoint that runs a retrieval ("hit testing") query against one dataset."""

    # Maximum number of records requested from HitTestingService.retrieve.
    RETRIEVAL_LIMIT = 10

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id):
        """Run a hit-testing retrieval for ``dataset_id`` and return the matched records.

        JSON body:
            query: the text to retrieve against (validated by
                ``HitTestingService.hit_testing_args_check``; the exact rules live there).
            retrieval_model: optional dict passed through to the retrieval service.

        Returns:
            ``{"query": ..., "records": [...]}``, where ``records`` is marshalled with
            ``hit_testing_record_fields``.

        Raises:
            NotFound: the dataset does not exist.
            Forbidden: the current user has no permission on the dataset.
            HighQualityDatasetOnlyError: the dataset is not indexed with ``high_quality``.
            DatasetNotInitializedError, ProviderNotInitializeError,
            ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError,
            CompletionRequestError: translated from the corresponding retrieval/model errors.
            ValueError: re-raised unchanged from the retrieval service.
            InternalServerError: any other unexpected failure (logged with traceback).
        """
        dataset_id_str = str(dataset_id)

        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        # only high quality dataset can be used for hit testing
        if dataset.indexing_technique != 'high_quality':
            raise HighQualityDatasetOnlyError()

        parser = reqparse.RequestParser()
        parser.add_argument('query', type=str, location='json')
        parser.add_argument('retrieval_model', type=dict, required=False, location='json')
        args = parser.parse_args()

        HitTestingService.hit_testing_args_check(args)

        # The return stays inside the try so failures while marshalling the response
        # are mapped by the same handlers as retrieval failures.
        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=args['query'],
                account=current_user,
                retrieval_model=args['retrieval_model'],
                limit=self.RETRIEVAL_LIMIT,
            )

            return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError as e:
            raise DatasetNotInitializedError() from e
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as e:
            raise ProviderQuotaExceededError() from e
        except ModelCurrentlyNotSupportError as e:
            raise ProviderModelCurrentlyNotSupportError() from e
        except LLMBadRequestError as e:
            raise ProviderNotInitializeError(
                "No Embedding Model or Reranking Model available. Please configure a valid provider "
                "in the Settings -> Model Provider.") from e
        except InvokeError as e:
            raise CompletionRequestError(e.description) from e
        except ValueError:
            # Propagate unchanged (keeping traceback and exact type) so the catch-all
            # below does not convert it into a 500.
            raise
        except Exception as e:
            logger.exception("Hit testing failed.")
            raise InternalServerError(str(e)) from e


api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')