|
@@ -4,6 +4,7 @@ from flask import current_app, Flask
|
|
|
from langchain.embeddings.base import Embeddings
|
|
|
from core.index.vector_index.vector_index import VectorIndex
|
|
|
from core.model_providers.model_factory import ModelFactory
|
|
|
+from extensions.ext_database import db
|
|
|
from models.dataset import Dataset
|
|
|
|
|
|
default_retrieval_model = {
|
|
@@ -21,10 +22,13 @@ default_retrieval_model = {
|
|
|
class RetrievalService:
|
|
|
|
|
|
@classmethod
|
|
|
- def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
|
|
+ def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
|
|
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
|
|
all_documents: list, search_method: str, embeddings: Embeddings):
|
|
|
with flask_app.app_context():
|
|
|
+ dataset = db.session.query(Dataset).filter(
|
|
|
+ Dataset.id == dataset_id
|
|
|
+ ).first()
|
|
|
|
|
|
vector_index = VectorIndex(
|
|
|
dataset=dataset,
|
|
@@ -56,10 +60,13 @@ class RetrievalService:
|
|
|
all_documents.extend(documents)
|
|
|
|
|
|
@classmethod
|
|
|
- def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
|
|
+ def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
|
|
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
|
|
all_documents: list, search_method: str, embeddings: Embeddings):
|
|
|
with flask_app.app_context():
|
|
|
+ dataset = db.session.query(Dataset).filter(
|
|
|
+ Dataset.id == dataset_id
|
|
|
+ ).first()
|
|
|
|
|
|
vector_index = VectorIndex(
|
|
|
dataset=dataset,
|