Преглед на файлове

Feat/delete single dataset retrival (#6570)

Jyong преди 9 месеца
родител
ревизия
e4bb943fe5

+ 9 - 3
api/core/app/app_config/easy_ui_based_app/dataset/manager.py

@@ -62,7 +62,12 @@ class DatasetConfigManager:
             return None
 
         # dataset configs
-        dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
+        if 'dataset_configs' in config and config.get('dataset_configs'):
+            dataset_configs = config.get('dataset_configs')
+        else:
+            dataset_configs = {
+                'retrieval_model': 'multiple'
+            }
         query_variable = config.get('dataset_query_variable')
 
         if dataset_configs['retrieval_model'] == 'single':
@@ -83,9 +88,10 @@ class DatasetConfigManager:
                     retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                         dataset_configs['retrieval_model']
                     ),
-                    top_k=dataset_configs.get('top_k'),
+                    top_k=dataset_configs.get('top_k', 4),
                     score_threshold=dataset_configs.get('score_threshold'),
-                    reranking_model=dataset_configs.get('reranking_model')
+                    reranking_model=dataset_configs.get('reranking_model'),
+                    weights=dataset_configs.get('weights')
                 )
             )
 
 

+ 4 - 0
api/core/app/app_config/entities.py

@@ -159,7 +159,11 @@ class DatasetRetrieveConfigEntity(BaseModel):
     retrieve_strategy: RetrieveStrategy
     top_k: Optional[int] = None
     score_threshold: Optional[float] = None
+    rerank_mode: Optional[str] = 'reranking_model'
     reranking_model: Optional[dict] = None
+    weights: Optional[dict] = None
+
+
 
 
 class DatasetEntity(BaseModel):

+ 38 - 15
api/core/rag/data_post_processor/data_post_processor.py

@@ -5,15 +5,20 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.data_post_processor.reorder import ReorderRunner
 from core.rag.models.document import Document
-from core.rag.rerank.rerank import RerankRunner
+from core.rag.rerank.constants.rerank_mode import RerankMode
+from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
+from core.rag.rerank.rerank_model import RerankModelRunner
+from core.rag.rerank.weight_rerank import WeightRerankRunner
 
 
 class DataPostProcessor:
     """Interface for data post-processing document.
     """
 
-    def __init__(self, tenant_id: str, reranking_model: dict, reorder_enabled: bool = False):
-        self.rerank_runner = self._get_rerank_runner(reranking_model, tenant_id)
+    def __init__(self, tenant_id: str, reranking_mode: str,
+                 reranking_model: Optional[dict] = None, weights: Optional[dict] = None,
+                 reorder_enabled: bool = False):
+        self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
         self.reorder_runner = self._get_reorder_runner(reorder_enabled)
 
     def invoke(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
@@ -26,19 +31,37 @@ class DataPostProcessor:
 
         return documents
 
-    def _get_rerank_runner(self, reranking_model: dict, tenant_id: str) -> Optional[RerankRunner]:
-        if reranking_model:
-            try:
-                model_manager = ModelManager()
-                rerank_model_instance = model_manager.get_model_instance(
-                    tenant_id=tenant_id,
-                    provider=reranking_model['reranking_provider_name'],
-                    model_type=ModelType.RERANK,
-                    model=reranking_model['reranking_model_name']
+    def _get_rerank_runner(self, reranking_mode: str, tenant_id: str, reranking_model: Optional[dict] = None,
+                           weights: Optional[dict] = None) -> Optional[RerankModelRunner | WeightRerankRunner]:
+        if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights:
+            return WeightRerankRunner(
+                tenant_id,
+                Weights(
+                    weight_type=weights['weight_type'],
+                    vector_setting=VectorSetting(
+                        vector_weight=weights['vector_setting']['vector_weight'],
+                        embedding_provider_name=weights['vector_setting']['embedding_provider_name'],
+                        embedding_model_name=weights['vector_setting']['embedding_model_name'],
+                    ),
+                    keyword_setting=KeywordSetting(
+                        keyword_weight=weights['keyword_setting']['keyword_weight'],
+                    )
                 )
-            except InvokeAuthorizationError:
-                return None
-            return RerankRunner(rerank_model_instance)
+            )
+        elif reranking_mode == RerankMode.RERANKING_MODEL.value:
+            if reranking_model:
+                try:
+                    model_manager = ModelManager()
+                    rerank_model_instance = model_manager.get_model_instance(
+                        tenant_id=tenant_id,
+                        provider=reranking_model['reranking_provider_name'],
+                        model_type=ModelType.RERANK,
+                        model=reranking_model['reranking_model_name']
+                    )
+                except InvokeAuthorizationError:
+                    return None
+                return RerankModelRunner(rerank_model_instance)
+            return None
         return None
 
     def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]:

+ 2 - 1
api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py

@@ -1,4 +1,5 @@
 import re
+from typing import Optional
 
 import jieba
 from jieba.analyse import default_tfidf
@@ -11,7 +12,7 @@ class JiebaKeywordTableHandler:
     def __init__(self):
         default_tfidf.stop_words = STOPWORDS
 
-    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
+    def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
         keywords = jieba.analyse.extract_tags(
             sentence=text,

+ 16 - 4
api/core/rag/datasource/retrieval_service.py

@@ -6,6 +6,7 @@ from flask import Flask, current_app
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
+from core.rag.rerank.constants.rerank_mode import RerankMode
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from extensions.ext_database import db
 from models.dataset import Dataset
@@ -26,13 +27,19 @@ class RetrievalService:
 
     @classmethod
     def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
-                 top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
+                 top_k: int, score_threshold: Optional[float] = .0,
+                 reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None,
+                 weights: Optional[dict] = None):
         dataset = db.session.query(Dataset).filter(
             Dataset.id == dataset_id
         ).first()
         if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
             return []
         all_documents = []
+        keyword_search_documents = []
+        embedding_search_documents = []
+        full_text_search_documents = []
+        hybrid_search_documents = []
         threads = []
         exceptions = []
         # retrieval_model source with keyword
@@ -87,7 +94,8 @@ class RetrievalService:
             raise Exception(exception_message)
 
         if retrival_method == RetrievalMethod.HYBRID_SEARCH.value:
-            data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+            data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_mode,
+                                                    reranking_model, weights, False)
             all_documents = data_post_processor.invoke(
                 query=query,
                 documents=all_documents,
@@ -143,7 +151,9 @@ class RetrievalService:
 
                 if documents:
                     if reranking_model and retrival_method == RetrievalMethod.SEMANTIC_SEARCH.value:
-                        data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                        data_post_processor = DataPostProcessor(str(dataset.tenant_id),
+                                                                RerankMode.RERANKING_MODEL.value,
+                                                                reranking_model, None, False)
                         all_documents.extend(data_post_processor.invoke(
                             query=query,
                             documents=documents,
@@ -175,7 +185,9 @@ class RetrievalService:
                 )
                 if documents:
                     if reranking_model and retrival_method == RetrievalMethod.FULL_TEXT_SEARCH.value:
-                        data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
+                        data_post_processor = DataPostProcessor(str(dataset.tenant_id),
+                                                                RerankMode.RERANKING_MODEL.value,
+                                                                reranking_model, None, False)
                         all_documents.extend(data_post_processor.invoke(
                             query=query,
                             documents=documents,

+ 4 - 2
api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -396,9 +396,11 @@ class QdrantVector(BaseVector):
         documents = []
         for result in results:
             if result:
-                documents.append(self._document_from_scored_point(
+                document = self._document_from_scored_point(
                     result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
-                ))
+                )
+                document.metadata['vector'] = result.vector
+                documents.append(document)
 
         return documents
 
 

+ 0 - 0
api/core/rag/docstore/__init__.py


+ 8 - 0
api/core/rag/rerank/constants/rerank_mode.py

@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class RerankMode(Enum):
+
+    RERANKING_MODEL = 'reranking_model'
+    WEIGHTED_SCORE = 'weighted_score'
+

+ 23 - 0
api/core/rag/rerank/entity/weight.py

@@ -0,0 +1,23 @@
+from pydantic import BaseModel
+
+
+class VectorSetting(BaseModel):
+    vector_weight: float
+
+    embedding_provider_name: str
+
+    embedding_model_name: str
+
+
+class KeywordSetting(BaseModel):
+    keyword_weight: float
+
+
+class Weights(BaseModel):
+    """Model for weighted rerank."""
+
+    weight_type: str
+
+    vector_setting: VectorSetting
+
+    keyword_setting: KeywordSetting

+ 1 - 1
api/core/rag/rerank/rerank.py

@@ -4,7 +4,7 @@ from core.model_manager import ModelInstance
 from core.rag.models.document import Document
 
 
-class RerankRunner:
+class RerankModelRunner:
     def __init__(self, rerank_model_instance: ModelInstance) -> None:
         self.rerank_model_instance = rerank_model_instance
 
 

+ 178 - 0
api/core/rag/rerank/weight_rerank.py

@@ -0,0 +1,178 @@
+import math
+from collections import Counter
+from typing import Optional
+
+import numpy as np
+
+from core.embedding.cached_embedding import CacheEmbedding
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
+from core.rag.models.document import Document
+from core.rag.rerank.entity.weight import VectorSetting, Weights
+
+
+class WeightRerankRunner:
+
+    def __init__(self, tenant_id: str, weights: Weights) -> None:
+        self.tenant_id = tenant_id
+        self.weights = weights
+
+    def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
+            top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
+        """
+        Run rerank model
+        :param query: search query
+        :param documents: documents for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id if needed
+
+        :return:
+        """
+        docs = []
+        doc_id = []
+        unique_documents = []
+        for document in documents:
+            if document.metadata['doc_id'] not in doc_id:
+                doc_id.append(document.metadata['doc_id'])
+                docs.append(document.page_content)
+                unique_documents.append(document)
+
+        documents = unique_documents
+
+        rerank_documents = []
+        query_scores = self._calculate_keyword_score(query, documents)
+
+        query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
+        for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
+            # format document
+            score = self.weights.vector_setting.vector_weight * query_vector_score + \
+                    self.weights.keyword_setting.keyword_weight * query_score
+            if score_threshold and score < score_threshold:
+                continue
+            document.metadata['score'] = score
+            rerank_documents.append(document)
+        rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
+        return rerank_documents[:top_n] if top_n else rerank_documents
+
+    def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
+        """
+        Calculate BM25 scores
+        :param query: search query
+        :param documents: documents for reranking
+
+        :return:
+        """
+        keyword_table_handler = JiebaKeywordTableHandler()
+        query_keywords = keyword_table_handler.extract_keywords(query, None)
+        documents_keywords = []
+        for document in documents:
+            # get the document keywords
+            document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
+            document.metadata['keywords'] = document_keywords
+            documents_keywords.append(document_keywords)
+
+        # Counter query keywords(TF)
+        query_keyword_counts = Counter(query_keywords)
+
+        # total documents
+        total_documents = len(documents)
+
+        # calculate all documents' keywords IDF
+        all_keywords = set()
+        for document_keywords in documents_keywords:
+            all_keywords.update(document_keywords)
+
+        keyword_idf = {}
+        for keyword in all_keywords:
+            # calculate include query keywords' documents
+            doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
+            # IDF
+            keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
+
+        query_tfidf = {}
+
+        for keyword, count in query_keyword_counts.items():
+            tf = count
+            idf = keyword_idf.get(keyword, 0)
+            query_tfidf[keyword] = tf * idf
+
+        # calculate all documents' TF-IDF
+        documents_tfidf = []
+        for document_keywords in documents_keywords:
+            document_keyword_counts = Counter(document_keywords)
+            document_tfidf = {}
+            for keyword, count in document_keyword_counts.items():
+                tf = count
+                idf = keyword_idf.get(keyword, 0)
+                document_tfidf[keyword] = tf * idf
+            documents_tfidf.append(document_tfidf)
+
+        def cosine_similarity(vec1, vec2):
+            intersection = set(vec1.keys()) & set(vec2.keys())
+            numerator = sum(vec1[x] * vec2[x] for x in intersection)
+
+            sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
+            sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
+            denominator = math.sqrt(sum1) * math.sqrt(sum2)
+
+            if not denominator:
+                return 0.0
+            else:
+                return float(numerator) / denominator
+
+        similarities = []
+        for document_tfidf in documents_tfidf:
+            similarity = cosine_similarity(query_tfidf, document_tfidf)
+            similarities.append(similarity)
+
+        # for idx, similarity in enumerate(similarities):
+        #     print(f"Document {idx + 1} similarity: {similarity}")
+
+        return similarities
+
+    def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
+                          vector_setting: VectorSetting) -> list[float]:
+        """
+        Calculate Cosine scores
+        :param query: search query
+        :param documents: documents for reranking
+
+        :return:
+        """
+        query_vector_scores = []
+
+        model_manager = ModelManager()
+
+        embedding_model = model_manager.get_model_instance(
+            tenant_id=tenant_id,
+            provider=vector_setting.embedding_provider_name,
+            model_type=ModelType.TEXT_EMBEDDING,
+            model=vector_setting.embedding_model_name
+
+        )
+        cache_embedding = CacheEmbedding(embedding_model)
+        query_vector = cache_embedding.embed_query(query)
+        for document in documents:
+            # calculate cosine similarity
+            if 'score' in document.metadata:
+                query_vector_scores.append(document.metadata['score'])
+            else:
+                content_vector = document.metadata['vector']
+                # transform to NumPy
+                vec1 = np.array(query_vector)
+                vec2 = np.array(document.metadata['vector'])
+
+                # calculate dot product
+                dot_product = np.dot(vec1, vec2)
+
+                # calculate norm
+                norm_vec1 = np.linalg.norm(vec1)
+                norm_vec2 = np.linalg.norm(vec2)
+
+                # calculate cosine similarity
+                cosine_sim = dot_product / (norm_vec1 * norm_vec2)
+                query_vector_scores.append(cosine_sim)
+
+        return query_vector_scores

+ 124 - 22
api/core/rag/retrieval/dataset_retrieval.py

@@ -1,4 +1,6 @@
+import math
 import threading
+from collections import Counter
 from typing import Optional, cast
 
 from flask import Flask, current_app
@@ -14,9 +16,10 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.ops.utils import measure_time
+from core.rag.data_post_processor.data_post_processor import DataPostProcessor
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document
-from core.rag.rerank.rerank import RerankRunner
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
 from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
@@ -132,8 +135,9 @@ class DatasetRetrieval:
                 app_id, tenant_id, user_id, user_from,
                 available_datasets, query, retrieve_config.top_k,
                 retrieve_config.score_threshold,
-                retrieve_config.reranking_model.get('reranking_provider_name'),
-                retrieve_config.reranking_model.get('reranking_model_name'),
+                retrieve_config.rerank_mode,
+                retrieve_config.reranking_model,
+                retrieve_config.weights,
                 message_id,
             )
 
@@ -272,7 +276,8 @@ class DatasetRetrieval:
                         retrival_method=retrival_method, dataset_id=dataset.id,
                         query=query,
                         top_k=top_k, score_threshold=score_threshold,
-                        reranking_model=reranking_model
+                        reranking_model=reranking_model,
+                        weights=retrieval_model_config.get('weights', None),
                     )
                 self._on_query(query, [dataset_id], app_id, user_from, user_id)
 
@@ -292,14 +297,18 @@ class DatasetRetrieval:
             query: str,
             top_k: int,
             score_threshold: float,
-            reranking_provider_name: str,
-            reranking_model_name: str,
+            reranking_mode: str,
+            reranking_model: Optional[dict] = None,
+            weights: Optional[dict] = None,
+            reranking_enable: bool = True,
             message_id: Optional[str] = None,
     ):
         threads = []
         all_documents = []
         dataset_ids = [dataset.id for dataset in available_datasets]
+        index_type = None
         for dataset in available_datasets:
+            index_type = dataset.indexing_technique
             retrieval_thread = threading.Thread(target=self._retriever, kwargs={
                 'flask_app': current_app._get_current_object(),
                 'dataset_id': dataset.id,
@@ -311,23 +320,24 @@ class DatasetRetrieval:
             retrieval_thread.start()
         for thread in threads:
             thread.join()
-        # do rerank for searched documents
-        model_manager = ModelManager()
-        rerank_model_instance = model_manager.get_model_instance(
-            tenant_id=tenant_id,
-            provider=reranking_provider_name,
-            model_type=ModelType.RERANK,
-            model=reranking_model_name
-        )
 
-        rerank_runner = RerankRunner(rerank_model_instance)
+        if reranking_enable:
+            # do rerank for searched documents
+            data_post_processor = DataPostProcessor(tenant_id, reranking_mode,
+                                                    reranking_model, weights, False)
 
-        with measure_time() as timer:
-            all_documents = rerank_runner.run(
-                query, all_documents,
-                score_threshold,
-                top_k
-            )
+            with measure_time() as timer:
+                all_documents = data_post_processor.invoke(
+                    query=query,
+                    documents=all_documents,
+                    score_threshold=score_threshold,
+                    top_n=top_k
+                )
+        else:
+            if index_type == "economy":
+                all_documents = self.calculate_keyword_score(query, all_documents, top_k)
+            elif index_type == "high_quality":
+                all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
         self._on_query(query, dataset_ids, app_id, user_from, user_id)
 
         if all_documents:
@@ -420,7 +430,8 @@ class DatasetRetrieval:
                                                           score_threshold=retrieval_model['score_threshold']
                                                           if retrieval_model['score_threshold_enabled'] else None,
                                                           reranking_model=retrieval_model['reranking_model']
-                                                          if retrieval_model['reranking_enable'] else None
+                                                          if retrieval_model['reranking_enable'] else None,
+                                                          weights=retrieval_model.get('weights', None),
                                                           )
 
                     all_documents.extend(documents)
@@ -513,3 +524,94 @@ class DatasetRetrieval:
             tools.append(tool)
 
         return tools
+
def calculate_keyword_score(self, query: str, documents: 'list[Document]', top_k: int) -> 'list[Document]':
    """
    Rank documents against the query with a TF-IDF / cosine-similarity keyword score.

    Keywords are extracted with Jieba from the query and from every document.
    As a side effect each document gets its extracted keywords stored in
    ``metadata['keywords']`` and its similarity stored in ``metadata['score']``.

    :param query: search query
    :param documents: documents to score and rank
    :param top_k: number of documents to keep; a falsy value keeps all
    :return: documents sorted by score (descending), truncated to ``top_k``
    """
    if not documents:
        return []

    keyword_table_handler = JiebaKeywordTableHandler()
    query_keywords = keyword_table_handler.extract_keywords(query, None)

    documents_keywords = []
    for document in documents:
        # get the document keywords
        document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
        document.metadata['keywords'] = document_keywords
        documents_keywords.append(document_keywords)

    # Document frequency: number of documents containing each keyword.
    # Single pass over all keyword sets, instead of rescanning every
    # document once per vocabulary entry (O(total) vs O(|vocab| * docs)).
    doc_frequency = Counter(
        keyword
        for document_keywords in documents_keywords
        for keyword in set(document_keywords)
    )

    # Smoothed IDF (the +1 terms avoid division by zero and zero weights).
    total_documents = len(documents)
    keyword_idf = {
        keyword: math.log((1 + total_documents) / (1 + df)) + 1
        for keyword, df in doc_frequency.items()
    }

    def tfidf_vector(keywords) -> dict:
        # Term frequency weighted by corpus IDF; keywords absent from the
        # corpus fall back to IDF 0 and contribute nothing.
        return {
            keyword: count * keyword_idf.get(keyword, 0)
            for keyword, count in Counter(keywords).items()
        }

    query_tfidf = tfidf_vector(query_keywords)
    documents_tfidf = [tfidf_vector(document_keywords) for document_keywords in documents_keywords]

    def cosine_similarity(vec1: dict, vec2: dict) -> float:
        intersection = set(vec1.keys()) & set(vec2.keys())
        numerator = sum(vec1[x] * vec2[x] for x in intersection)
        denominator = (math.sqrt(sum(v ** 2 for v in vec1.values()))
                       * math.sqrt(sum(v ** 2 for v in vec2.values())))
        if not denominator:
            # one side has no weighted keywords -> no measurable similarity
            return 0.0
        return float(numerator) / denominator

    for document, document_tfidf in zip(documents, documents_tfidf):
        document.metadata['score'] = cosine_similarity(query_tfidf, document_tfidf)

    documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
    return documents[:top_k] if top_k else documents
+
def calculate_vector_score(self, all_documents: 'list[Document]',
                           top_k: int, score_threshold: float) -> 'list[Document]':
    """
    Filter documents by similarity score and keep the best ``top_k``.

    :param all_documents: documents carrying ``metadata['score']``
    :param top_k: number of documents to keep; a falsy value keeps all that pass
    :param score_threshold: minimum score; ``None`` disables filtering
    :return: passing documents sorted by score (descending)
    """
    # Callers pass score_threshold=None when score_threshold_enabled is off;
    # comparing against None would raise TypeError, so treat it as 0.0.
    score_threshold = score_threshold or 0.0
    filter_documents = [
        document for document in all_documents
        if document.metadata['score'] >= score_threshold
    ]
    if not filter_documents:
        return []
    filter_documents.sort(key=lambda x: x.metadata['score'], reverse=True)
    return filter_documents[:top_k] if top_k else filter_documents
+
+
+

+ 0 - 0
api/core/rag/splitter/__init__.py


+ 4 - 3
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py

@@ -7,7 +7,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
 from core.model_manager import ModelManager
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.datasource.retrieval_service import RetrievalService
-from core.rag.rerank.rerank import RerankRunner
+from core.rag.rerank.rerank_model import RerankModelRunner
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from core.rag.retrieval.retrival_methods import RetrievalMethod
 from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from extensions.ext_database import db
@@ -72,7 +72,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             model=self.reranking_model_name
             model=self.reranking_model_name
         )
         )
 
 
-        rerank_runner = RerankRunner(rerank_model_instance)
+        rerank_runner = RerankModelRunner(rerank_model_instance)
         all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
         all_documents = rerank_runner.run(query, all_documents, self.score_threshold, self.top_k)
 
 
         for hit_callback in self.hit_callbacks:
         for hit_callback in self.hit_callbacks:
@@ -180,7 +180,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
                                                           score_threshold=retrieval_model['score_threshold']
                                                           score_threshold=retrieval_model['score_threshold']
                                                           if retrieval_model['score_threshold_enabled'] else None,
                                                           if retrieval_model['score_threshold_enabled'] else None,
                                                           reranking_model=retrieval_model['reranking_model']
                                                           reranking_model=retrieval_model['reranking_model']
-                                                          if retrieval_model['reranking_enable'] else None
+                                                          if retrieval_model['reranking_enable'] else None,
+                                                          weights=retrieval_model.get('weights', None),
                                                           )
                                                           )
 
 
                     all_documents.extend(documents)
                     all_documents.extend(documents)

+ 2 - 1
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

@@ -78,7 +78,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                                                       score_threshold=retrieval_model['score_threshold']
                                                       score_threshold=retrieval_model['score_threshold']
                                                       if retrieval_model['score_threshold_enabled'] else None,
                                                       if retrieval_model['score_threshold_enabled'] else None,
                                                       reranking_model=retrieval_model['reranking_model']
                                                       reranking_model=retrieval_model['reranking_model']
-                                                      if retrieval_model['reranking_enable'] else None
+                                                      if retrieval_model['reranking_enable'] else None,
+                                                      weights=retrieval_model.get('weights', None),
                                                       )
                                                       )
             else:
             else:
                 documents = []
                 documents = []

+ 28 - 0
api/core/workflow/nodes/knowledge_retrieval/entities.py

@@ -13,13 +13,41 @@ class RerankingModelConfig(BaseModel):
     model: str
     model: str
 
 
 
 
class VectorSetting(BaseModel):
    """
    Vector Setting.

    Settings for the semantic (vector) component of a weighted retrieval
    score; referenced from WeightedScoreConfig.
    """
    # relative weight of the vector score in the combined score
    vector_weight: float
    # provider of the embedding model (presumably used to score semantic
    # similarity — confirm against the retrieval layer)
    embedding_provider_name: str
    # name of the embedding model at that provider
    embedding_model_name: str
+
class KeywordSetting(BaseModel):
    """
    Keyword Setting.

    Settings for the keyword component of a weighted retrieval score;
    referenced from WeightedScoreConfig.
    """
    # relative weight of the keyword score in the combined score
    keyword_weight: float
+
class WeightedScoreConfig(BaseModel):
    """
    Weighted score Config.

    Combines a vector (semantic) setting and a keyword setting for
    weighted-score reranking.
    """
    # how the weights are applied; exact semantics are defined by the
    # retrieval layer — NOTE(review): confirm the accepted values
    weight_type: str
    vector_setting: VectorSetting
    keyword_setting: KeywordSetting
+
 class MultipleRetrievalConfig(BaseModel):
 class MultipleRetrievalConfig(BaseModel):
     """
     """
     Multiple Retrieval Config.
     Multiple Retrieval Config.
     """
     """
     top_k: int
     top_k: int
     score_threshold: Optional[float] = None
     score_threshold: Optional[float] = None
+    reranking_mode: str = 'reranking_model'
+    reranking_enable: bool = True
     reranking_model: RerankingModelConfig
     reranking_model: RerankingModelConfig
+    weights: WeightedScoreConfig
 
 
 
 
 class ModelConfig(BaseModel):
 class ModelConfig(BaseModel):

+ 27 - 2
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -138,13 +138,38 @@ class KnowledgeRetrievalNode(BaseNode):
                     planning_strategy=planning_strategy
                     planning_strategy=planning_strategy
                 )
                 )
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
         elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
+            if node_data.multiple_retrieval_config.reranking_mode == 'reranking_model':
+                reranking_model = {
+                    'reranking_provider_name': node_data.multiple_retrieval_config.reranking_model['provider'],
+                    'reranking_model_name': node_data.multiple_retrieval_config.reranking_model['name']
+                }
+                weights = None
+            elif node_data.multiple_retrieval_config.reranking_mode == 'weighted_score':
+                reranking_model = None
+                weights = {
+                    'weight_type': node_data.multiple_retrieval_config.weights.weight_type,
+                    'vector_setting': {
+                        "vector_weight": node_data.multiple_retrieval_config.weights.vector_setting.vector_weight,
+                        "embedding_provider_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_provider_name,
+                        "embedding_model_name": node_data.multiple_retrieval_config.weights.vector_setting.embedding_model_name,
+                    },
+                    'keyword_setting': {
+                        "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
+                    }
+                }
+            else:
+                reranking_model = None
+                weights = None
             all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
             all_documents = dataset_retrieval.multiple_retrieve(self.app_id, self.tenant_id, self.user_id,
                                                                 self.user_from.value,
                                                                 self.user_from.value,
                                                                 available_datasets, query,
                                                                 available_datasets, query,
                                                                 node_data.multiple_retrieval_config.top_k,
                                                                 node_data.multiple_retrieval_config.top_k,
                                                                 node_data.multiple_retrieval_config.score_threshold,
                                                                 node_data.multiple_retrieval_config.score_threshold,
-                                                                node_data.multiple_retrieval_config.reranking_model.provider,
-                                                                node_data.multiple_retrieval_config.reranking_model.model)
+                                                                node_data.multiple_retrieval_config.reranking_mode,
+                                                                reranking_model,
+                                                                weights,
+                                                                node_data.multiple_retrieval_config.reranking_enable,
+                                                                )
 
 
         context_list = []
         context_list = []
         if all_documents:
         if all_documents:

+ 18 - 0
api/fields/dataset_fields.py

@@ -18,10 +18,28 @@ reranking_model_fields = {
     'reranking_model_name': fields.String
     'reranking_model_name': fields.String
 }
 }
 
 
+keyword_setting_fields = {
+    'keyword_weight': fields.Float
+}
+
+vector_setting_fields = {
+    'vector_weight': fields.Float,
+    'embedding_model_name': fields.String,
+    'embedding_provider_name': fields.String,
+}
+
+weighted_score_fields = {
+    'weight_type': fields.String,
+    'keyword_setting': fields.Nested(keyword_setting_fields),
+    'vector_setting': fields.Nested(vector_setting_fields),
+}
+
 dataset_retrieval_model_fields = {
 dataset_retrieval_model_fields = {
     'search_method': fields.String,
     'search_method': fields.String,
     'reranking_enable': fields.Boolean,
     'reranking_enable': fields.Boolean,
+    'reranking_mode': fields.String,
     'reranking_model': fields.Nested(reranking_model_fields),
     'reranking_model': fields.Nested(reranking_model_fields),
+    'weights': fields.Nested(weighted_score_fields, allow_null=True),
     'top_k': fields.Integer,
     'top_k': fields.Integer,
     'score_threshold_enabled': fields.Boolean,
     'score_threshold_enabled': fields.Boolean,
     'score_threshold': fields.Float
     'score_threshold': fields.Float

+ 3 - 1
api/models/model.py

@@ -328,7 +328,9 @@ class AppModelConfig(db.Model):
                 return {'retrieval_model': 'single'}
                 return {'retrieval_model': 'single'}
             else:
             else:
                 return dataset_configs
                 return dataset_configs
-        return {'retrieval_model': 'single'}
+        return {
+                'retrieval_model': 'multiple',
+            }
 
 
     @property
     @property
     def file_upload_dict(self) -> dict:
     def file_upload_dict(self) -> dict:

Файловите разлики са ограничени, защото са твърде много
+ 154 - 55
api/poetry.lock


+ 3 - 2
api/pyproject.toml

@@ -163,6 +163,7 @@ redis = { version = "~5.0.3", extras = ["hiredis"] }
 replicate = "~0.22.0"
 replicate = "~0.22.0"
 resend = "~0.7.0"
 resend = "~0.7.0"
 safetensors = "~0.4.3"
 safetensors = "~0.4.3"
+scikit-learn = "^1.5.1"
 sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
 sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
 sqlalchemy = "~2.0.29"
 sqlalchemy = "~2.0.29"
 tencentcloud-sdk-python-hunyuan = "~3.0.1158"
 tencentcloud-sdk-python-hunyuan = "~3.0.1158"
@@ -175,7 +176,7 @@ werkzeug = "~3.0.1"
 xinference-client = "0.9.4"
 xinference-client = "0.9.4"
 yarl = "~1.9.4"
 yarl = "~1.9.4"
 zhipuai = "1.0.7"
 zhipuai = "1.0.7"
-
+rank-bm25 = "~0.2.2"
 ############################################################
 ############################################################
 # Tool dependencies required by tool implementations
 # Tool dependencies required by tool implementations
 ############################################################
 ############################################################
@@ -200,7 +201,7 @@ cloudscraper = "1.2.71"
 ############################################################
 ############################################################
 
 
 [tool.poetry.group.vdb.dependencies]
 [tool.poetry.group.vdb.dependencies]
-chromadb = "~0.5.1"
+chromadb = "0.5.1"
 oracledb = "~2.2.1"
 oracledb = "~2.2.1"
 pgvecto-rs = "0.1.4"
 pgvecto-rs = "0.1.4"
 pgvector = "0.2.5"
 pgvector = "0.2.5"

+ 5 - 3
api/services/hit_testing_service.py

@@ -38,14 +38,16 @@ class HitTestingService:
         if not retrieval_model:
         if not retrieval_model:
             retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
             retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
 
 
-        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'],
+        all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
                                                   dataset_id=dataset.id,
                                                   dataset_id=dataset.id,
                                                   query=cls.escape_query_for_search(query),
                                                   query=cls.escape_query_for_search(query),
-                                                  top_k=retrieval_model['top_k'],
+                                                  top_k=retrieval_model.get('top_k', 2),
                                                   score_threshold=retrieval_model['score_threshold']
                                                   score_threshold=retrieval_model['score_threshold']
                                                   if retrieval_model['score_threshold_enabled'] else None,
                                                   if retrieval_model['score_threshold_enabled'] else None,
                                                   reranking_model=retrieval_model['reranking_model']
                                                   reranking_model=retrieval_model['reranking_model']
-                                                  if retrieval_model['reranking_enable'] else None
+                                                  if retrieval_model['reranking_enable'] else None,
+                                                  reranking_mode=retrieval_model.get('reranking_mode', None),
+                                                  weights=retrieval_model.get('weights', None),
                                                   )
                                                   )
 
 
         end = time.perf_counter()
         end = time.perf_counter()