import logging

from flask_login import current_user  # type: ignore
from flask_restful import marshal, reqparse  # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound

import services.dataset_service
import services.errors.account
import services.errors.index
from controllers.console.app.error import (
    CompletionRequestError,
    ProviderModelCurrentlyNotSupportError,
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from core.errors.error import (
    LLMBadRequestError,
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService

logger = logging.getLogger(__name__)


class DatasetsHitTestingBase:
    """Static helpers shared by dataset hit-testing controllers.

    Covers dataset lookup plus permission check, request argument parsing,
    and running ``HitTestingService.retrieve`` with service/model errors
    translated into HTTP-facing controller errors.
    """

    @staticmethod
    def get_and_validate_dataset(dataset_id: str):
        """Return the dataset with ``dataset_id`` if the current user may access it.

        Raises:
            NotFound: if ``DatasetService.get_dataset`` returns ``None``.
            Forbidden: if ``DatasetService.check_dataset_permission`` raises
                ``NoPermissionError`` for ``current_user``.
        """
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e

        return dataset

    @staticmethod
    def hit_testing_args_check(args):
        """Delegate argument validation to ``HitTestingService.hit_testing_args_check``."""
        HitTestingService.hit_testing_args_check(args)

    @staticmethod
    def parse_args():
        """Parse the hit-testing JSON body.

        Reads ``query`` (str), ``retrieval_model`` (dict) and
        ``external_retrieval_model`` (dict). None of them is marked required
        here, so absent fields come back as ``None``; content validation is
        left to ``hit_testing_args_check``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("query", type=str, location="json")
        parser.add_argument("retrieval_model", type=dict, required=False, location="json")
        parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
        return parser.parse_args()

    @staticmethod
    def perform_hit_testing(dataset, args, limit: int = 10) -> dict:
        """Run retrieval against ``dataset`` and return the marshalled result.

        Args:
            dataset: Dataset object, as returned by ``get_and_validate_dataset``.
            args: Parsed request arguments; must provide the keys ``query``,
                ``retrieval_model`` and ``external_retrieval_model``.
            limit: Value forwarded as ``limit`` to ``HitTestingService.retrieve``
                (defaults to 10, the previously hard-coded value).

        Returns:
            ``{"query": ..., "records": ...}`` where ``records`` is marshalled
            with ``hit_testing_record_fields``.

        Raises:
            DatasetNotInitializedError, ProviderNotInitializeError,
            ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError,
            CompletionRequestError: translated from the corresponding
                service/model errors.
            ValueError: propagated unchanged.
            InternalServerError: for any other exception (logged first).
        """
        try:
            response = HitTestingService.retrieve(
                dataset=dataset,
                query=args["query"],
                account=current_user,
                retrieval_model=args["retrieval_model"],
                external_retrieval_model=args["external_retrieval_model"],
                limit=limit,
            )
            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
        except services.errors.index.IndexNotInitializedError as ex:
            raise DatasetNotInitializedError() from ex
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as ex:
            raise ProviderQuotaExceededError() from ex
        except ModelCurrentlyNotSupportError as ex:
            raise ProviderModelCurrentlyNotSupportError() from ex
        except LLMBadRequestError as ex:
            raise ProviderNotInitializeError(
                "No Embedding Model or Reranking Model available. Please configure a valid provider "
                "in the Settings -> Model Provider."
            ) from ex
        except InvokeError as e:
            raise CompletionRequestError(e.description) from e
        except ValueError:
            # Re-raise as-is (keeping the original traceback) so the catch-all
            # below does not convert validation errors into a 500.
            raise
        except Exception as e:
            logger.exception("Hit testing failed.")
            raise InternalServerError(str(e)) from e