import logging
import threading
import time
from typing import List

import numpy as np
from flask import current_app
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from sklearn.manifold import TSNE

from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rerank.rerank import RerankRunner
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DocumentSegment, DatasetQuery
from services.retrieval_service import RetrievalService

# Retrieval settings used when neither the caller nor the dataset provides any.
# No 'score_threshold' key is needed here because thresholds are disabled.
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class HitTestingService:
    """Run ad-hoc ("hit testing") retrieval against a dataset and shape the results for display."""

    @classmethod
    def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_model: dict, limit: int = 10) -> dict:
        """Retrieve segments of ``dataset`` that match ``query``.

        Args:
            dataset: Dataset to search.
            query: Free-text query.
            account: Account running the test; recorded as the query's creator and
                passed to the reranker as ``account-<id>``.
            retrieval_model: Retrieval settings dict (``search_method``, ``top_k``,
                ``score_threshold_enabled``/``score_threshold``, ``reranking_enable``/
                ``reranking_model``). When falsy, the dataset's own retrieval model is
                used, falling back to ``default_retrieval_model``.
            limit: Not used by this implementation; kept for caller compatibility.

        Returns:
            ``{"query": {"content", "tsne_position"}, "records": [...]}`` as built by
            :meth:`compact_retrieve_response`, or an empty-records response when the
            dataset has no available documents or segments.

        Side effects:
            Adds a ``DatasetQuery`` row (source ``'hit_testing'``) and commits the session.
        """
        if dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            return {
                "query": {
                    "content": query,
                    "tsne_position": {'x': 0, 'y': 0},
                },
                "records": []
            }

        start = time.perf_counter()

        # Caller settings win; otherwise use the dataset's configuration, then the module default.
        if not retrieval_model:
            retrieval_model = dataset.retrieval_model or default_retrieval_model

        model_manager = ModelManager()
        embedding_model = model_manager.get_model_instance(
            tenant_id=dataset.tenant_id,
            model_type=ModelType.TEXT_EMBEDDING,
            provider=dataset.embedding_model_provider,
            model=dataset.embedding_model
        )
        embeddings = CacheEmbedding(embedding_model)

        search_method = retrieval_model['search_method']
        top_k = retrieval_model['top_k']
        # 'score_threshold' / 'reranking_model' are only read when their switches are on.
        score_threshold = retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None
        reranking_model = retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None

        # Both search threads receive this list and contribute results into it.
        all_documents = []
        threads = []

        # NOTE(review): Thread.join() does not re-raise exceptions from the target, so a
        # failing search is not surfaced here — confirm RetrievalService handles its own errors.
        if search_method in ('semantic_search', 'hybrid_search'):
            embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                # Presumably passed so the worker thread can push an app context.
                'flask_app': current_app._get_current_object(),
                'dataset_id': str(dataset.id),
                'query': query,
                'top_k': top_k,
                'score_threshold': score_threshold,
                'reranking_model': reranking_model,
                'all_documents': all_documents,
                'search_method': search_method,
                'embeddings': embeddings
            })
            threads.append(embedding_thread)
            embedding_thread.start()

        if search_method in ('full_text_search', 'hybrid_search'):
            full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                'flask_app': current_app._get_current_object(),
                'dataset_id': str(dataset.id),
                'query': query,
                'search_method': search_method,
                'embeddings': embeddings,
                'score_threshold': score_threshold,
                'top_k': top_k,
                'reranking_model': reranking_model,
                'all_documents': all_documents
            })
            threads.append(full_text_index_thread)
            full_text_index_thread.start()

        for thread in threads:
            thread.join()

        if search_method == 'hybrid_search':
            # Hybrid search always reranks the merged candidates, regardless of 'reranking_enable'.
            rerank_model_instance = model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=retrieval_model['reranking_model']['reranking_provider_name'],
                model_type=ModelType.RERANK,
                model=retrieval_model['reranking_model']['reranking_model_name']
            )
            rerank_runner = RerankRunner(rerank_model_instance)
            all_documents = rerank_runner.run(
                query=query,
                documents=all_documents,
                score_threshold=score_threshold,
                top_n=top_k,
                user=f"account-{account.id}"
            )

        end = time.perf_counter()
        logging.debug("Hit testing retrieve in %0.4f seconds", end - start)

        dataset_query = DatasetQuery(
            dataset_id=dataset.id,
            content=query,
            source='hit_testing',
            created_by_role='account',
            created_by=account.id
        )
        db.session.add(dataset_query)
        db.session.commit()

        return cls.compact_retrieve_response(dataset, embeddings, query, all_documents)

    @classmethod
    def compact_retrieve_response(cls, dataset: Dataset, embeddings: Embeddings, query: str, documents: List[Document]):
        """Attach segments, scores and 2-D t-SNE positions to retrieved documents.

        The query and every document are embedded together and projected with
        :meth:`get_tsne_positions_from_embeddings`. Each document is matched to an
        enabled, ``'completed'`` ``DocumentSegment`` of ``dataset`` via
        ``metadata['doc_id']``; documents without such a segment are dropped.

        Returns:
            ``{"query": {"content", "tsne_position"}, "records": [{"segment", "score", "tsne_position"}, ...]}``.
        """
        text_embeddings = [embeddings.embed_query(query)]
        text_embeddings.extend(embeddings.embed_documents([document.page_content for document in documents]))

        tsne_position_data = cls.get_tsne_positions_from_embeddings(text_embeddings)
        # Position 0 is the query; the remaining positions follow `documents` in order.
        query_position = tsne_position_data.pop(0)

        # Fetch all candidate segments in one round trip instead of one query per document.
        index_node_ids = [document.metadata['doc_id'] for document in documents]
        segments_by_node_id = {}
        if index_node_ids:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.enabled == True,  # noqa: E712 - SQLAlchemy column expression
                DocumentSegment.status == 'completed',
                DocumentSegment.index_node_id.in_(index_node_ids)
            ).all()
            segments_by_node_id = {segment.index_node_id: segment for segment in segments}

        records = []
        for i, index_node_id in enumerate(index_node_ids):
            segment = segments_by_node_id.get(index_node_id)
            if segment is None:
                continue
            records.append({
                "segment": segment,
                "score": documents[i].metadata.get('score', None),
                "tsne_position": tsne_position_data[i]
            })

        return {
            "query": {
                "content": query,
                "tsne_position": query_position,
            },
            "records": records
        }

    @classmethod
    def get_tsne_positions_from_embeddings(cls, embeddings: list):
        """Project embedding vectors to 2-D points with t-SNE.

        Args:
            embeddings: Sequence of equal-length embedding vectors.

        Returns:
            One ``{'x': float, 'y': float}`` dict per input vector, or a single
            origin point when fewer than two vectors are given.
        """
        embedding_length = len(embeddings)
        if embedding_length <= 1:
            return [{'x': 0, 'y': 0}]

        concatenate_data = np.array(embeddings).reshape(embedding_length, -1)

        # Keep perplexity strictly below the sample count (presumably required by TSNE).
        perplexity = embedding_length / 2 + 1
        if perplexity >= embedding_length:
            perplexity = max(embedding_length - 1, 1)

        tsne = TSNE(n_components=2, perplexity=perplexity, early_exaggeration=12.0)
        data_tsne = tsne.fit_transform(concatenate_data)

        return [{'x': float(x), 'y': float(y)} for x, y in data_tsne]

    @classmethod
    def hit_testing_args_check(cls, args):
        """Validate hit-testing request arguments.

        Raises:
            KeyError: if ``args`` has no ``'query'`` key.
            ValueError: if the query is empty or longer than 250 characters.
        """
        query = args['query']

        if not query or len(query) > 250:
            raise ValueError('Query is required and cannot exceed 250 characters')