import logging

from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services
from controllers.console import api
from controllers.console.app.error import (
    CompletionRequestError,
    ProviderModelCurrentlyNotSupportError,
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.errors.error import (
    LLMBadRequestError,
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import login_required
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService


class HitTestingApi(Resource):
    """Console resource that runs a hit-testing (retrieval) query on a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, dataset_id):
        """Run a retrieval query against ``dataset_id`` and return the hits.

        The JSON body may carry ``query`` (str), ``retrieval_model`` (dict,
        optional) and ``external_retrieval_model`` (dict, optional). The
        parsed arguments are handed to
        ``HitTestingService.hit_testing_args_check`` before retrieval.

        Returns:
            A dict with the ``"query"`` from the service response and the
            ``"records"`` marshalled with ``hit_testing_record_fields``.

        Raises:
            NotFound: no dataset exists for ``dataset_id``.
            Forbidden: the current user fails the dataset permission check.
            DatasetNotInitializedError: the index is not initialized.
            ProviderNotInitializeError: the provider token is missing, or the
                service raised ``LLMBadRequestError``.
            ProviderQuotaExceededError: the provider quota is exhausted.
            ProviderModelCurrentlyNotSupportError: the model is unsupported.
            CompletionRequestError: the model invocation failed.
            ValueError: propagated unchanged from the service layer.
            InternalServerError: any other unexpected failure.
        """
        dataset_id_str = str(dataset_id)

        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        parser = reqparse.RequestParser()
        parser.add_argument("query", type=str, location="json")
        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
        args = parser.parse_args()

        # Runs outside the try block below, so whatever it raises reaches the
        # caller without being translated.
        HitTestingService.hit_testing_args_check(args)

        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=args["query"],
                account=current_user,
                retrieval_model=args["retrieval_model"],
                external_retrieval_model=args["external_retrieval_model"],
                limit=10,
            )

            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError as e:
            raise DatasetNotInitializedError() from e
        except ProviderTokenNotInitError as e:
            raise ProviderNotInitializeError(e.description) from e
        except QuotaExceededError as e:
            raise ProviderQuotaExceededError() from e
        except ModelCurrentlyNotSupportError as e:
            raise ProviderModelCurrentlyNotSupportError() from e
        except LLMBadRequestError as e:
            raise ProviderNotInitializeError(
                "No Embedding Model or Reranking Model available. Please configure a valid provider "
                "in the Settings -> Model Provider."
            ) from e
        except InvokeError as e:
            raise CompletionRequestError(e.description) from e
        except ValueError:
            # Re-raise the original exception untouched. This clause exists
            # only to keep ValueError away from the generic handler below,
            # which would turn it into an InternalServerError.
            raise
        except Exception as e:
            logging.exception("Hit testing failed.")
            raise InternalServerError(str(e)) from e


api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")