"""Dataset retrieval service: keyword, semantic, full-text and hybrid search."""

import threading
from typing import Optional

from flask import Flask, current_app

from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
from services.external_knowledge_service import ExternalDatasetService

# Default retrieval settings: semantic search, reranking disabled, two results,
# no score threshold. Not referenced inside this module.
default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}


class RetrievalService:
    """Run one or more search strategies against a dataset and merge the results."""

    @classmethod
    def retrieve(
        cls,
        retrieval_method: str,
        dataset_id: str,
        query: str,
        top_k: int,
        score_threshold: Optional[float] = 0.0,
        reranking_model: Optional[dict] = None,
        reranking_mode: Optional[str] = "reranking_model",
        weights: Optional[dict] = None,
    ) -> list:
        """Search a dataset with the requested retrieval method.

        Each applicable strategy (keyword, semantic, full text) runs in its own
        thread; all threads append into one shared result list. For hybrid
        search the merged list is post-processed once more with
        ``reranking_mode`` / ``weights`` and cut to ``top_k``.

        Args:
            retrieval_method: Search method identifier, e.g. ``"keyword_search"``
                or one of the ``RetrievalMethod`` values.
            dataset_id: Id of the dataset to search.
            query: User query text.
            top_k: Number of results requested from each strategy.
            score_threshold: Threshold forwarded to the vector search and to
                the post-processors.
            reranking_model: Dict carrying ``reranking_provider_name`` and
                ``reranking_model_name``; may be ``None``.
            reranking_mode: Mode handed to the hybrid post-processor.
            weights: Weights handed to the hybrid post-processor.

        Returns:
            The collected documents, or ``[]`` when the dataset does not exist
            or has no available documents or segments.

        Raises:
            Exception: If any worker failed; the message is the worker error
                messages joined with ``";\\n"``.
        """
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            return []

        all_documents: list = []
        exceptions: list = []

        # Worker threads have no application context of their own, so hand them
        # the real app object (not the context-local proxy) to push one.
        flask_app = current_app._get_current_object()
        shared_kwargs = {
            "flask_app": flask_app,
            "dataset_id": dataset_id,
            "query": query,
            "top_k": top_k,
            "all_documents": all_documents,
            "exceptions": exceptions,
        }
        rerank_kwargs = {
            "score_threshold": score_threshold,
            "reranking_model": reranking_model,
            "retrieval_method": retrieval_method,
        }

        workers = []
        # retrieval_model source with keyword
        if retrieval_method == "keyword_search":
            workers.append((cls.keyword_search, shared_kwargs))
        # retrieval_model source with semantic
        if RetrievalMethod.is_support_semantic_search(retrieval_method):
            workers.append((cls.embedding_search, {**shared_kwargs, **rerank_kwargs}))
        # retrieval source with full text
        if RetrievalMethod.is_support_fulltext_search(retrieval_method):
            workers.append((cls.full_text_index_search, {**shared_kwargs, **rerank_kwargs}))

        threads = []
        for target, kwargs in workers:
            thread = threading.Thread(target=target, kwargs=kwargs)
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        if exceptions:
            exception_message = ";\n".join(exceptions)
            raise Exception(exception_message)

        if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
            data_post_processor = DataPostProcessor(
                str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
            )
            all_documents = data_post_processor.invoke(
                query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
            )
        return all_documents

    @classmethod
    def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
        """Fetch results for ``query`` through ``ExternalDatasetService``.

        Returns ``[]`` when the dataset does not exist; otherwise returns
        whatever ``fetch_external_knowledge_retrieval`` returns.
        """
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        if not dataset:
            return []
        all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
            dataset.tenant_id, dataset_id, query, external_retrieval_model
        )
        return all_documents

    @classmethod
    def keyword_search(
        cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
    ) -> None:
        """Thread worker: keyword search, results appended to ``all_documents``.

        Any error is recorded as a string in ``exceptions`` instead of being
        raised, so the spawning thread can report it after ``join()``.
        """
        with flask_app.app_context():
            try:
                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
                if not dataset:
                    # Fail with an explicit message instead of passing None on.
                    raise ValueError("dataset not found")

                keyword = Keyword(dataset=dataset)
                documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
                all_documents.extend(documents)
            except Exception as e:
                exceptions.append(str(e))

    @classmethod
    def embedding_search(
        cls,
        flask_app: Flask,
        dataset_id: str,
        query: str,
        top_k: int,
        score_threshold: Optional[float],
        reranking_model: Optional[dict],
        all_documents: list,
        retrieval_method: str,
        exceptions: list,
    ) -> None:
        """Thread worker: vector search, results appended to ``all_documents``.

        Hits are reranked here only when a reranking model is fully configured
        and ``retrieval_method`` is plain semantic search. Any error is recorded
        as a string in ``exceptions`` instead of being raised.
        """
        with flask_app.app_context():
            try:
                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
                if not dataset:
                    # Without this the failure would surface as an attribute
                    # error on None when building the filter below.
                    raise ValueError("dataset not found")

                vector = Vector(dataset=dataset)
                documents = vector.search_by_vector(
                    cls.escape_query_for_search(query),
                    search_type="similarity_score_threshold",
                    top_k=top_k,
                    score_threshold=score_threshold,
                    filter={"group_id": [dataset.id]},
                )

                if documents:
                    all_documents.extend(
                        cls._rerank_if_configured(
                            dataset=dataset,
                            query=query,
                            documents=documents,
                            score_threshold=score_threshold,
                            reranking_model=reranking_model,
                            retrieval_method=retrieval_method,
                            rerank_method=RetrievalMethod.SEMANTIC_SEARCH.value,
                        )
                    )
            except Exception as e:
                exceptions.append(str(e))

    @classmethod
    def full_text_index_search(
        cls,
        flask_app: Flask,
        dataset_id: str,
        query: str,
        top_k: int,
        score_threshold: Optional[float],
        reranking_model: Optional[dict],
        all_documents: list,
        retrieval_method: str,
        exceptions: list,
    ) -> None:
        """Thread worker: full-text search, results appended to ``all_documents``.

        Hits are reranked here only when a reranking model is fully configured
        and ``retrieval_method`` is plain full-text search. ``score_threshold``
        is used by that rerank step only, not by the search call itself. Any
        error is recorded as a string in ``exceptions`` instead of being raised.
        """
        with flask_app.app_context():
            try:
                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
                if not dataset:
                    # Fail with an explicit message instead of passing None on.
                    raise ValueError("dataset not found")

                vector_processor = Vector(
                    dataset=dataset,
                )
                documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)

                if documents:
                    all_documents.extend(
                        cls._rerank_if_configured(
                            dataset=dataset,
                            query=query,
                            documents=documents,
                            score_threshold=score_threshold,
                            reranking_model=reranking_model,
                            retrieval_method=retrieval_method,
                            rerank_method=RetrievalMethod.FULL_TEXT_SEARCH.value,
                        )
                    )
            except Exception as e:
                exceptions.append(str(e))

    @staticmethod
    def _rerank_if_configured(
        dataset: Dataset,
        query: str,
        documents: list,
        score_threshold: Optional[float],
        reranking_model: Optional[dict],
        retrieval_method: str,
        rerank_method: str,
    ) -> list:
        """Rerank single-strategy results, or return them untouched.

        Reranking runs only when ``reranking_model`` names both a provider and
        a model and ``retrieval_method`` equals ``rerank_method``. For any other
        method (hybrid included) the documents are returned as-is.
        """
        rerank_requested = (
            reranking_model
            and reranking_model.get("reranking_model_name")
            and reranking_model.get("reranking_provider_name")
            and retrieval_method == rerank_method
        )
        if not rerank_requested:
            return documents

        data_post_processor = DataPostProcessor(
            str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
        )
        # top_n is the number of hits passed in, not top_k.
        return data_post_processor.invoke(
            query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
        )

    @staticmethod
    def escape_query_for_search(query: str) -> str:
        """Return ``query`` with every double quote prefixed by a backslash."""
        return query.replace('"', '\\"')