Jyong 1 год назад
Родитель
Сommit
a5b80c9d1f

+ 2 - 2
api/core/tool/dataset_multi_retriever_tool.py

@@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool):
                         'search_method'] == 'hybrid_search':
                         embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                             'flask_app': current_app._get_current_object(),
-                            'dataset': dataset,
+                            'dataset_id': str(dataset.id),
                             'query': query,
                             'top_k': self.top_k,
                             'score_threshold': self.score_threshold,
@@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool):
                         full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
                                                                   kwargs={
                                                                       'flask_app': current_app._get_current_object(),
-                                                                      'dataset': dataset,
+                                                                      'dataset_id': str(dataset.id),
                                                                       'query': query,
                                                                       'search_method': 'hybrid_search',
                                                                       'embeddings': embeddings,

+ 2 - 2
api/core/tool/dataset_retriever_tool.py

@@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool):
                 if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
                     embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                         'flask_app': current_app._get_current_object(),
-                        'dataset': dataset,
+                        'dataset_id': str(dataset.id),
                         'query': query,
                         'top_k': self.top_k,
                         'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
@@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool):
                 if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
                     full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                         'flask_app': current_app._get_current_object(),
-                        'dataset': dataset,
+                        'dataset_id': str(dataset.id),
                         'query': query,
                         'search_method': retrieval_model['search_method'],
                         'embeddings': embeddings,

+ 2 - 2
api/services/hit_testing_service.py

@@ -61,7 +61,7 @@ class HitTestingService:
         if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
             embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                 'flask_app': current_app._get_current_object(),
-                'dataset': dataset,
+                'dataset_id': str(dataset.id),
                 'query': query,
                 'top_k': retrieval_model['top_k'],
                 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
@@ -77,7 +77,7 @@ class HitTestingService:
         if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
             full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                 'flask_app': current_app._get_current_object(),
-                'dataset': dataset,
+                'dataset_id': str(dataset.id),
                 'query': query,
                 'search_method': retrieval_model['search_method'],
                 'embeddings': embeddings,

+ 9 - 2
api/services/retrieval_service.py

@@ -4,6 +4,7 @@ from flask import current_app, Flask
 from langchain.embeddings.base import Embeddings
 from core.index.vector_index.vector_index import VectorIndex
 from core.model_providers.model_factory import ModelFactory
+from extensions.ext_database import db
 from models.dataset import Dataset
 
 default_retrieval_model = {
@@ -21,10 +22,13 @@ default_retrieval_model = {
 class RetrievalService:
 
     @classmethod
-    def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
+    def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
                          top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                          all_documents: list, search_method: str, embeddings: Embeddings):
         with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
 
             vector_index = VectorIndex(
                 dataset=dataset,
@@ -56,10 +60,13 @@ class RetrievalService:
                     all_documents.extend(documents)
 
     @classmethod
-    def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
+    def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
                                top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
                                all_documents: list, search_method: str, embeddings: Embeddings):
         with flask_app.app_context():
+            dataset = db.session.query(Dataset).filter(
+                Dataset.id == dataset_id
+            ).first()
 
             vector_index = VectorIndex(
                 dataset=dataset,