import threading
from typing import Optional, cast

from flask import Flask, current_app
from langchain.tools import BaseTool

from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
from core.rerank.rerank import RerankRunner
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument

# Retrieval settings used for any dataset whose own `retrieval_model` is not set.
# Read-only: callers must not mutate this shared dict.
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enabled': False
}


class DatasetRetrieval:
    """Retrieve context from an app's configured datasets (knowledge bases)."""

    def retrieve(self, app_id: str, user_id: str, tenant_id: str,
                 model_config: ModelConfigWithCredentialsEntity,
                 config: DatasetEntity,
                 query: str,
                 invoke_from: InvokeFrom,
                 show_retrieve_source: bool,
                 hit_callback: DatasetIndexToolCallbackHandler,
                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
        """
        Retrieve dataset context for a query.

        Depending on ``config.retrieve_config.retrieve_strategy`` either lets the
        LLM route the query to a single dataset (``single_retrieve``) or searches
        all available datasets and reranks the union (``multiple_retrieve``).
        The matching completed, enabled segments are then loaded in retrieval
        order and their contents joined into one context string.

        :param app_id: app_id
        :param user_id: user_id
        :param tenant_id: tenant id
        :param model_config: model config of the app's LLM (used for routing)
        :param config: dataset config
        :param query: query
        :param invoke_from: invoke from
        :param show_retrieve_source: if true, report per-segment source info to ``hit_callback``
        :param hit_callback: hit callback
        :param memory: memory (not used by this method)
        :return: None when no dataset is configured or the model schema cannot be
            loaded; '' when nothing was retrieved; otherwise the segment contents
            joined with newlines.
        """
        dataset_ids = config.dataset_ids
        if not dataset_ids:
            return None
        retrieve_config = config.retrieve_config

        # check model is support tool calling
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.provider,
            model=model_config.model
        )

        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model_config.model,
            credentials=model_config.credentials
        )

        if not model_schema:
            return None

        # Models with native tool calling use the function-call router;
        # everything else falls back to the ReAct-prompt router.
        planning_strategy = PlanningStrategy.REACT_ROUTER
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features \
                    or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.ROUTER

        available_datasets = self._get_available_datasets(tenant_id, dataset_ids)

        all_documents = []
        user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            all_documents = self.single_retrieve(app_id, tenant_id, user_id, user_from, available_datasets, query,
                                                 model_instance,
                                                 model_config, planning_strategy)
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            all_documents = self.multiple_retrieve(app_id, tenant_id, user_id, user_from,
                                                   available_datasets, query, retrieve_config.top_k,
                                                   retrieve_config.score_threshold,
                                                   retrieve_config.reranking_model.get('reranking_provider_name'),
                                                   retrieve_config.reranking_model.get('reranking_model_name'))

        if not all_documents:
            # Nothing retrieved: no segments can match, so skip the lookup
            # (which would otherwise run with an empty IN list).
            return ''

        # index_node_id -> retrieval score, kept only for truthy scores
        document_score_list = {}
        for item in all_documents:
            if item.metadata.get('score'):
                document_score_list[item.metadata['doc_id']] = item.metadata['score']

        document_context_list = []
        index_node_ids = [document.metadata['doc_id'] for document in all_documents]
        segments = DocumentSegment.query.filter(
            DocumentSegment.dataset_id.in_(dataset_ids),
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True,  # noqa: E712 - SQLAlchemy column comparison
            DocumentSegment.index_node_id.in_(index_node_ids)
        ).all()

        if not segments:
            return ''

        # The SQL query does not preserve retrieval order; restore it so the
        # highest-ranked segments come first. Unknown ids sort last.
        index_node_id_to_position = {
            index_node_id: position for position, index_node_id in enumerate(index_node_ids)
        }
        sorted_segments = sorted(
            segments,
            key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float('inf'))
        )
        for segment in sorted_segments:
            if segment.answer:
                document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
            else:
                document_context_list.append(segment.content)

        if show_retrieve_source:
            context_list = []
            resource_number = 1
            for segment in sorted_segments:
                dataset = Dataset.query.filter_by(
                    id=segment.dataset_id
                ).first()
                document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
                                                        DatasetDocument.enabled == True,  # noqa: E712
                                                        DatasetDocument.archived == False,  # noqa: E712
                                                        ).first()
                if dataset and document:
                    source = {
                        'position': resource_number,
                        'dataset_id': dataset.id,
                        'dataset_name': dataset.name,
                        'document_id': document.id,
                        'document_name': document.name,
                        'data_source_type': document.data_source_type,
                        'segment_id': segment.id,
                        'retriever_from': invoke_from.to_source(),
                        'score': document_score_list.get(segment.index_node_id, None)
                    }

                    # extra diagnostics are exposed only when invoked from 'dev'
                    if invoke_from.to_source() == 'dev':
                        source['hit_count'] = segment.hit_count
                        source['word_count'] = segment.word_count
                        source['segment_position'] = segment.position
                        source['index_node_hash'] = segment.index_node_hash
                    if segment.answer:
                        source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                    else:
                        source['content'] = segment.content
                    context_list.append(source)
                resource_number += 1
            if hit_callback:
                hit_callback.return_retriever_resource_info(context_list)

        return "\n".join(document_context_list)

    def single_retrieve(self, app_id: str,
                        tenant_id: str,
                        user_id: str,
                        user_from: str,
                        available_datasets: list,
                        query: str,
                        model_instance: ModelInstance,
                        model_config: ModelConfigWithCredentialsEntity,
                        planning_strategy: PlanningStrategy,
                        ):
        """
        Let the LLM pick one dataset, then retrieve from it.

        Every available dataset is offered to a router as a parameterless tool
        named by the dataset id. The chosen dataset is searched with its own
        ``retrieval_model`` (or ``default_retrieval_model``).

        :return: the documents returned by ``RetrievalService.retrieve``, or []
            when the router picked no dataset or an id that was not offered.
        """
        tools = []
        for dataset in available_datasets:
            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' + dataset.name

            description = description.replace('\n', '').replace('\r', '')
            message_tool = PromptMessageTool(
                name=dataset.id,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": [],
                }
            )
            tools.append(message_tool)

        dataset_id = None
        if planning_strategy == PlanningStrategy.REACT_ROUTER:
            react_multi_dataset_router = ReactMultiDatasetRouter()
            dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
                                                           user_id, tenant_id)

        elif planning_strategy == PlanningStrategy.ROUTER:
            function_call_router = FunctionCallMultiDatasetRouter()
            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)

        if not dataset_id:
            return []

        # The router's answer is derived from LLM output. Only accept an id that
        # was actually offered as a tool, instead of querying the database
        # directly (which had no tenant filter).
        dataset = next((d for d in available_datasets if str(d.id) == str(dataset_id)), None)
        if not dataset:
            return []

        # get retrieval model config
        retrieval_model_config = dataset.retrieval_model \
            if dataset.retrieval_model else default_retrieval_model

        # get top k
        top_k = retrieval_model_config['top_k']
        # get retrieval method; "economy" datasets only support keyword search
        if dataset.indexing_technique == "economy":
            retrieval_method = 'keyword_search'
        else:
            retrieval_method = retrieval_model_config['search_method']
        # get reranking model
        reranking_model = retrieval_model_config['reranking_model'] \
            if retrieval_model_config['reranking_enable'] else None
        # get score threshold (0.0 when thresholding is disabled)
        score_threshold = .0
        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
        if score_threshold_enabled:
            score_threshold = retrieval_model_config.get("score_threshold")

        results = RetrievalService.retrieve(retrival_method=retrieval_method, dataset_id=dataset.id,
                                            query=query,
                                            top_k=top_k, score_threshold=score_threshold,
                                            reranking_model=reranking_model)
        self._on_query(query, [dataset.id], app_id, user_from, user_id)
        if results:
            self._on_retrival_end(results)
        return results

    def multiple_retrieve(self,
                          app_id: str,
                          tenant_id: str,
                          user_id: str,
                          user_from: str,
                          available_datasets: list,
                          query: str,
                          top_k: int,
                          score_threshold: float,
                          reranking_provider_name: str,
                          reranking_model_name: str):
        """
        Search every available dataset in parallel, then rerank the union.

        One thread per dataset runs ``_retriever``; the combined documents are
        reranked with the given rerank model, filtered by ``score_threshold``,
        and cut to ``top_k``.

        :return: the reranked documents ([] when nothing was retrieved).
        """
        threads = []
        all_documents = []
        dataset_ids = [dataset.id for dataset in available_datasets]
        # Resolve the real app object once; the proxy itself cannot be used from
        # worker threads.
        flask_app = current_app._get_current_object()
        for dataset in available_datasets:
            retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                'flask_app': flask_app,
                'dataset_id': dataset.id,
                'query': query,
                'top_k': top_k,
                'all_documents': all_documents,
            })
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()

        if all_documents:
            # do rerank for searched documents (skipped when there is nothing to
            # rerank, which avoids a pointless rerank-model call)
            model_manager = ModelManager()
            rerank_model_instance = model_manager.get_model_instance(
                tenant_id=tenant_id,
                provider=reranking_provider_name,
                model_type=ModelType.RERANK,
                model=reranking_model_name
            )

            rerank_runner = RerankRunner(rerank_model_instance)
            all_documents = rerank_runner.run(query, all_documents,
                                              score_threshold,
                                              top_k)

        self._on_query(query, dataset_ids, app_id, user_from, user_id)
        if all_documents:
            self._on_retrival_end(all_documents)
        return all_documents

    def _on_retrival_end(self, documents: list[Document]) -> None:
        """Increment ``hit_count`` on the segment behind each retrieved document."""
        for document in documents:
            query = db.session.query(DocumentSegment).filter(
                DocumentSegment.index_node_id == document.metadata['doc_id']
            )

            # narrow to the source dataset when the document carries one
            if 'dataset_id' in document.metadata:
                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])

            # add hit count to document segment; done as an SQL-side increment
            query.update(
                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
                synchronize_session=False
            )

            db.session.commit()

    def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
        """
        Record one ``DatasetQuery`` row per queried dataset (no-op for an empty query).
        """
        if not query:
            return
        for dataset_id in dataset_ids:
            dataset_query = DatasetQuery(
                dataset_id=dataset_id,
                content=query,
                source='app',
                source_app_id=app_id,
                created_by_role=user_from,
                created_by=user_id
            )
            db.session.add(dataset_query)
        db.session.commit()

    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
        """
        Thread target: search one dataset and append its hits to ``all_documents``.

        Runs inside ``flask_app.app_context()`` because worker threads do not
        inherit the caller's Flask context (needed for ``db.session``).
        """
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(
                Dataset.id == dataset_id
            ).first()

            if not dataset:
                return []

            # get retrieval model , if the model is not setting , using default
            retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(retrival_method='keyword_search',
                                                      dataset_id=dataset.id,
                                                      query=query,
                                                      top_k=top_k
                                                      )
                if documents:
                    all_documents.extend(documents)
            else:
                if top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
                                                          dataset_id=dataset.id,
                                                          query=query,
                                                          top_k=top_k,
                                                          score_threshold=retrieval_model['score_threshold']
                                                          if retrieval_model['score_threshold_enabled'] else None,
                                                          reranking_model=retrieval_model['reranking_model']
                                                          if retrieval_model['reranking_enable'] else None
                                                          )
                    # guard like the economy branch: extend(None) would raise
                    if documents:
                        all_documents.extend(documents)

    def to_dataset_retriever_tool(self, tenant_id: str,
                                  dataset_ids: list[str],
                                  retrieve_config: DatasetRetrieveConfigEntity,
                                  return_resource: bool,
                                  invoke_from: InvokeFrom,
                                  hit_callback: DatasetIndexToolCallbackHandler) \
            -> Optional[list[BaseTool]]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset

        SINGLE strategy yields one ``DatasetRetrieverTool`` per available dataset;
        MULTIPLE yields one ``DatasetMultiRetrieverTool`` over all of them.

        :param tenant_id: tenant id
        :param dataset_ids: dataset ids
        :param retrieve_config: retrieve config
        :param return_resource: return resource
        :param invoke_from: invoke from
        :param hit_callback: hit callback
        :return: the created tools (empty for an unrecognized strategy)
        """
        tools = []
        available_datasets = self._get_available_datasets(tenant_id, dataset_ids)

        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            for dataset in available_datasets:
                # get retrieval model config (module-level default when unset)
                retrieval_model_config = dataset.retrieval_model \
                    if dataset.retrieval_model else default_retrieval_model

                # get top k
                top_k = retrieval_model_config['top_k']

                # get score threshold (None when thresholding is disabled)
                score_threshold = None
                score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                if score_threshold_enabled:
                    score_threshold = retrieval_model_config.get("score_threshold")

                tool = DatasetRetrieverTool.from_dataset(
                    dataset=dataset,
                    top_k=top_k,
                    score_threshold=score_threshold,
                    hit_callbacks=[hit_callback],
                    return_resource=return_resource,
                    retriever_from=invoke_from.to_source()
                )

                tools.append(tool)
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            tool = DatasetMultiRetrieverTool.from_dataset(
                dataset_ids=[dataset.id for dataset in available_datasets],
                tenant_id=tenant_id,
                top_k=retrieve_config.top_k or 2,
                score_threshold=retrieve_config.score_threshold,
                hit_callbacks=[hit_callback],
                return_resource=return_resource,
                retriever_from=invoke_from.to_source(),
                reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
                reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
            )

            tools.append(tool)

        return tools

    @staticmethod
    def _get_available_datasets(tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
        """
        Load the tenant's datasets for ``dataset_ids`` and keep only usable ones.

        Ids that do not resolve to a dataset of ``tenant_id`` are skipped, as are
        datasets whose ``available_document_count`` is 0. Input order is kept.
        """
        available_datasets = []
        for dataset_id in dataset_ids:
            # get dataset from dataset id, scoped to the tenant
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == tenant_id,
                Dataset.id == dataset_id
            ).first()

            # skip ids that do not exist for this tenant
            if not dataset:
                continue

            # skip datasets that have no documents to retrieve from
            if dataset.available_document_count == 0:
                continue

            available_datasets.append(dataset)
        return available_datasets