import logging
import time

from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DatasetQuery, DocumentSegment

logger = logging.getLogger(__name__)

# Retrieval settings used by HitTestingService.retrieve when neither the caller
# nor the dataset supplies a retrieval model.
default_retrieval_model = {
    'search_method': RetrievalMethod.SEMANTIC_SEARCH,
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class HitTestingService:
    """Run ad-hoc retrieval ("hit testing") queries against a dataset."""

    @classmethod
    def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
        """Retrieve the segments of ``dataset`` that match ``query`` and record the query.

        :param dataset: dataset to search.
        :param query: free-text query string.
        :param account: account issuing the query; stored as the query's creator.
        :param retrieval_model: retrieval settings (``search_method``, ``top_k``,
            ``score_threshold_enabled``/``score_threshold``,
            ``reranking_enable``/``reranking_model``). When falsy, the dataset's
            own ``retrieval_model`` is used, then ``default_retrieval_model``.
        :param limit: currently unused; kept for interface compatibility.
        :return: ``{"query": {...}, "records": [...]}`` as built by
            ``compact_retrieve_response``.
        """
        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            # Nothing available to search: return an empty result without running
            # retrieval and without recording a DatasetQuery.
            return {
                "query": {
                    "content": query,
                    "tsne_position": {'x': 0, 'y': 0},
                },
                "records": []
            }

        start = time.perf_counter()

        # Caller-supplied settings win; otherwise use the dataset's, then the default.
        if not retrieval_model:
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

        # The on/off switches are optional in stored settings; a missing switch is
        # treated as disabled instead of raising KeyError. When a switch is on,
        # its companion value is still required.
        score_threshold = (retrieval_model['score_threshold']
                           if retrieval_model.get('score_threshold_enabled') else None)
        reranking_model = (retrieval_model['reranking_model']
                           if retrieval_model.get('reranking_enable') else None)

        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                  dataset_id=dataset.id,
                                                  query=query,
                                                  top_k=retrieval_model['top_k'],
                                                  score_threshold=score_threshold,
                                                  reranking_model=reranking_model)

        end = time.perf_counter()
        logger.debug("Hit testing retrieve in %0.4f seconds", end - start)

        # Persist the query as a DatasetQuery row attributed to the account.
        dataset_query = DatasetQuery(
            dataset_id=dataset.id,
            content=query,
            source='hit_testing',
            created_by_role='account',
            created_by=account.id
        )

        db.session.add(dataset_query)
        db.session.commit()

        return cls.compact_retrieve_response(dataset, query, all_documents)

    @classmethod
    def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]) -> dict:
        """Map retrieved documents back to their stored segments.

        Each document's ``metadata['doc_id']`` is matched against
        ``DocumentSegment.index_node_id`` among the enabled, ``'completed'``
        segments of ``dataset``. Documents without such a segment are dropped.
        The output keeps the order of ``documents``.

        :param dataset: dataset the documents were retrieved from.
        :param query: the original query string, echoed back in the result.
        :param documents: retrieval results; each must carry ``metadata['doc_id']``.
        :return: ``{"query": {"content": query}, "records": [{"segment", "score"}, ...]}``.
        """
        index_node_ids = [document.metadata['doc_id'] for document in documents]

        # One query for all documents instead of one query per document (N+1).
        segments_by_node_id = {}
        if index_node_ids:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.enabled == True,  # noqa: E712 -- SQLAlchemy needs the column comparison expression
                DocumentSegment.status == 'completed',
                DocumentSegment.index_node_id.in_(set(index_node_ids))
            ).all()
            for segment in segments:
                # Keep a single segment per index node, as the former per-document .first() did.
                segments_by_node_id.setdefault(segment.index_node_id, segment)

        records = []
        for document, index_node_id in zip(documents, index_node_ids):
            segment = segments_by_node_id.get(index_node_id)
            if segment is None:
                # No enabled, completed segment in this dataset matches; skip the document.
                continue
            records.append({
                "segment": segment,
                "score": document.metadata.get('score', None),
            })

        return {
            "query": {
                "content": query,
            },
            "records": records
        }

    @classmethod
    def hit_testing_args_check(cls, args):
        """Validate hit-testing request arguments.

        :param args: mapping of request arguments; must contain a non-empty ``query``.
        :raises ValueError: if ``query`` is missing, empty, or longer than 250 characters.
        """
        # .get so a missing key raises the documented ValueError rather than KeyError.
        query = args.get('query')

        if not query or len(query) > 250:
            raise ValueError('Query is required and cannot exceed 250 characters')