|
@@ -1,4 +1,6 @@
|
|
|
|
+import math
|
|
import threading
|
|
import threading
|
|
|
|
+from collections import Counter
|
|
from typing import Optional, cast
|
|
from typing import Optional, cast
|
|
|
|
|
|
from flask import Flask, current_app
|
|
from flask import Flask, current_app
|
|
@@ -14,9 +16,10 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
|
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
|
|
from core.ops.utils import measure_time
|
|
from core.ops.utils import measure_time
|
|
|
|
+from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
|
|
|
+from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
|
from core.rag.datasource.retrieval_service import RetrievalService
|
|
from core.rag.datasource.retrieval_service import RetrievalService
|
|
from core.rag.models.document import Document
|
|
from core.rag.models.document import Document
|
|
-from core.rag.rerank.rerank import RerankRunner
|
|
|
|
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
|
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
|
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
|
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
|
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
|
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
|
@@ -132,8 +135,9 @@ class DatasetRetrieval:
|
|
app_id, tenant_id, user_id, user_from,
|
|
app_id, tenant_id, user_id, user_from,
|
|
available_datasets, query, retrieve_config.top_k,
|
|
available_datasets, query, retrieve_config.top_k,
|
|
retrieve_config.score_threshold,
|
|
retrieve_config.score_threshold,
|
|
- retrieve_config.reranking_model.get('reranking_provider_name'),
|
|
|
|
- retrieve_config.reranking_model.get('reranking_model_name'),
|
|
|
|
|
|
+ retrieve_config.rerank_mode,
|
|
|
|
+ retrieve_config.reranking_model,
|
|
|
|
+ retrieve_config.weights,
|
|
message_id,
|
|
message_id,
|
|
)
|
|
)
|
|
|
|
|
|
@@ -272,7 +276,8 @@ class DatasetRetrieval:
|
|
retrival_method=retrival_method, dataset_id=dataset.id,
|
|
retrival_method=retrival_method, dataset_id=dataset.id,
|
|
query=query,
|
|
query=query,
|
|
top_k=top_k, score_threshold=score_threshold,
|
|
top_k=top_k, score_threshold=score_threshold,
|
|
- reranking_model=reranking_model
|
|
|
|
|
|
+ reranking_model=reranking_model,
|
|
|
|
+ weights=retrieval_model_config.get('weights', None),
|
|
)
|
|
)
|
|
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
|
self._on_query(query, [dataset_id], app_id, user_from, user_id)
|
|
|
|
|
|
@@ -292,14 +297,18 @@ class DatasetRetrieval:
|
|
query: str,
|
|
query: str,
|
|
top_k: int,
|
|
top_k: int,
|
|
score_threshold: float,
|
|
score_threshold: float,
|
|
- reranking_provider_name: str,
|
|
|
|
- reranking_model_name: str,
|
|
|
|
|
|
+ reranking_mode: str,
|
|
|
|
+ reranking_model: Optional[dict] = None,
|
|
|
|
+ weights: Optional[dict] = None,
|
|
|
|
+ reranking_enable: bool = True,
|
|
message_id: Optional[str] = None,
|
|
message_id: Optional[str] = None,
|
|
):
|
|
):
|
|
threads = []
|
|
threads = []
|
|
all_documents = []
|
|
all_documents = []
|
|
dataset_ids = [dataset.id for dataset in available_datasets]
|
|
dataset_ids = [dataset.id for dataset in available_datasets]
|
|
|
|
+ index_type = None
|
|
for dataset in available_datasets:
|
|
for dataset in available_datasets:
|
|
|
|
+ index_type = dataset.indexing_technique
|
|
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
|
|
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
|
|
'flask_app': current_app._get_current_object(),
|
|
'flask_app': current_app._get_current_object(),
|
|
'dataset_id': dataset.id,
|
|
'dataset_id': dataset.id,
|
|
@@ -311,23 +320,24 @@ class DatasetRetrieval:
|
|
retrieval_thread.start()
|
|
retrieval_thread.start()
|
|
for thread in threads:
|
|
for thread in threads:
|
|
thread.join()
|
|
thread.join()
|
|
- # do rerank for searched documents
|
|
|
|
- model_manager = ModelManager()
|
|
|
|
- rerank_model_instance = model_manager.get_model_instance(
|
|
|
|
- tenant_id=tenant_id,
|
|
|
|
- provider=reranking_provider_name,
|
|
|
|
- model_type=ModelType.RERANK,
|
|
|
|
- model=reranking_model_name
|
|
|
|
- )
|
|
|
|
|
|
|
|
- rerank_runner = RerankRunner(rerank_model_instance)
|
|
|
|
|
|
+ if reranking_enable:
|
|
|
|
+ # do rerank for searched documents
|
|
|
|
+ data_post_processor = DataPostProcessor(tenant_id, reranking_mode,
|
|
|
|
+ reranking_model, weights, False)
|
|
|
|
|
|
- with measure_time() as timer:
|
|
|
|
- all_documents = rerank_runner.run(
|
|
|
|
- query, all_documents,
|
|
|
|
- score_threshold,
|
|
|
|
- top_k
|
|
|
|
- )
|
|
|
|
|
|
+ with measure_time() as timer:
|
|
|
|
+ all_documents = data_post_processor.invoke(
|
|
|
|
+ query=query,
|
|
|
|
+ documents=all_documents,
|
|
|
|
+ score_threshold=score_threshold,
|
|
|
|
+ top_n=top_k
|
|
|
|
+ )
|
|
|
|
+ else:
|
|
|
|
+ if index_type == "economy":
|
|
|
|
+ all_documents = self.calculate_keyword_score(query, all_documents, top_k)
|
|
|
|
+ elif index_type == "high_quality":
|
|
|
|
+ all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
|
|
self._on_query(query, dataset_ids, app_id, user_from, user_id)
|
|
self._on_query(query, dataset_ids, app_id, user_from, user_id)
|
|
|
|
|
|
if all_documents:
|
|
if all_documents:
|
|
@@ -420,7 +430,8 @@ class DatasetRetrieval:
|
|
score_threshold=retrieval_model['score_threshold']
|
|
score_threshold=retrieval_model['score_threshold']
|
|
if retrieval_model['score_threshold_enabled'] else None,
|
|
if retrieval_model['score_threshold_enabled'] else None,
|
|
reranking_model=retrieval_model['reranking_model']
|
|
reranking_model=retrieval_model['reranking_model']
|
|
- if retrieval_model['reranking_enable'] else None
|
|
|
|
|
|
+ if retrieval_model['reranking_enable'] else None,
|
|
|
|
+ weights=retrieval_model.get('weights', None),
|
|
)
|
|
)
|
|
|
|
|
|
all_documents.extend(documents)
|
|
all_documents.extend(documents)
|
|
@@ -513,3 +524,94 @@ class DatasetRetrieval:
|
|
tools.append(tool)
|
|
tools.append(tool)
|
|
|
|
|
|
return tools
|
|
return tools
|
|
|
|
+
|
|
|
|
+ def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
|
|
|
|
+ """
|
|
|
|
+ Calculate keywords scores
|
|
|
|
+ :param query: search query
|
|
|
|
+ :param documents: documents for reranking
|
|
|
|
+
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ keyword_table_handler = JiebaKeywordTableHandler()
|
|
|
|
+ query_keywords = keyword_table_handler.extract_keywords(query, None)
|
|
|
|
+ documents_keywords = []
|
|
|
|
+ for document in documents:
|
|
|
|
+ # get the document keywords
|
|
|
|
+ document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
|
|
|
|
+ document.metadata['keywords'] = document_keywords
|
|
|
|
+ documents_keywords.append(document_keywords)
|
|
|
|
+
|
|
|
|
+ # Counter query keywords(TF)
|
|
|
|
+ query_keyword_counts = Counter(query_keywords)
|
|
|
|
+
|
|
|
|
+ # total documents
|
|
|
|
+ total_documents = len(documents)
|
|
|
|
+
|
|
|
|
+ # calculate all documents' keywords IDF
|
|
|
|
+ all_keywords = set()
|
|
|
|
+ for document_keywords in documents_keywords:
|
|
|
|
+ all_keywords.update(document_keywords)
|
|
|
|
+
|
|
|
|
+ keyword_idf = {}
|
|
|
|
+ for keyword in all_keywords:
|
|
|
|
+ # calculate include query keywords' documents
|
|
|
|
+ doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
|
|
|
|
+ # IDF
|
|
|
|
+ keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
|
|
|
|
+
|
|
|
|
+ query_tfidf = {}
|
|
|
|
+
|
|
|
|
+ for keyword, count in query_keyword_counts.items():
|
|
|
|
+ tf = count
|
|
|
|
+ idf = keyword_idf.get(keyword, 0)
|
|
|
|
+ query_tfidf[keyword] = tf * idf
|
|
|
|
+
|
|
|
|
+ # calculate all documents' TF-IDF
|
|
|
|
+ documents_tfidf = []
|
|
|
|
+ for document_keywords in documents_keywords:
|
|
|
|
+ document_keyword_counts = Counter(document_keywords)
|
|
|
|
+ document_tfidf = {}
|
|
|
|
+ for keyword, count in document_keyword_counts.items():
|
|
|
|
+ tf = count
|
|
|
|
+ idf = keyword_idf.get(keyword, 0)
|
|
|
|
+ document_tfidf[keyword] = tf * idf
|
|
|
|
+ documents_tfidf.append(document_tfidf)
|
|
|
|
+
|
|
|
|
+ def cosine_similarity(vec1, vec2):
|
|
|
|
+ intersection = set(vec1.keys()) & set(vec2.keys())
|
|
|
|
+ numerator = sum(vec1[x] * vec2[x] for x in intersection)
|
|
|
|
+
|
|
|
|
+ sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
|
|
|
|
+ sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
|
|
|
|
+ denominator = math.sqrt(sum1) * math.sqrt(sum2)
|
|
|
|
+
|
|
|
|
+ if not denominator:
|
|
|
|
+ return 0.0
|
|
|
|
+ else:
|
|
|
|
+ return float(numerator) / denominator
|
|
|
|
+
|
|
|
|
+ similarities = []
|
|
|
|
+ for document_tfidf in documents_tfidf:
|
|
|
|
+ similarity = cosine_similarity(query_tfidf, document_tfidf)
|
|
|
|
+ similarities.append(similarity)
|
|
|
|
+
|
|
|
|
+ for document, score in zip(documents, similarities):
|
|
|
|
+ # format document
|
|
|
|
+ document.metadata['score'] = score
|
|
|
|
+ documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
|
|
|
|
+ return documents[:top_k] if top_k else documents
|
|
|
|
+
|
|
|
|
+ def calculate_vector_score(self, all_documents: list[Document],
|
|
|
|
+ top_k: int, score_threshold: float) -> list[Document]:
|
|
|
|
+ filter_documents = []
|
|
|
|
+ for document in all_documents:
|
|
|
|
+ if document.metadata['score'] >= score_threshold:
|
|
|
|
+ filter_documents.append(document)
|
|
|
|
+ if not filter_documents:
|
|
|
|
+ return []
|
|
|
|
+ filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
|
|
|
|
+ return filter_documents[:top_k] if top_k else filter_documents
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|