import json
import logging
from abc import abstractmethod
from typing import List, Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):
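    """Shared behaviour for vector-store backed indexes.

    Concrete implementations (for example the Weaviate and Qdrant indexes)
    supply the store-specific pieces through the abstract methods below.
    """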

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    @abstractmethod
    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    def search(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
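        """Run a retrieval query against the underlying vector store.

        Supported kwargs are `search_type` ('similarity', 'mmr' or
        'similarity_score_threshold', default 'similarity') and
        `search_kwargs`, which is passed through to the store. For
        'similarity_score_threshold', each returned document carries its
        similarity under the 'score' metadata key.
        """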
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') or 'similarity'
        search_kwargs = kwargs.get('search_kwargs') or {}

        if search_type == 'similarity_score_threshold':
            # fall back to a threshold of 0.0 when no valid value is given
            score_threshold = search_kwargs.get("score_threshold")
            if not isinstance(score_threshold, float):
                search_kwargs['score_threshold'] = 0.0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # supported search_kwargs per search_type:
        #   similarity: k
        #   mmr: k, fetch_k, lambda_mult
        #   similarity_score_threshold: k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
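        """Embed and store the given documents.

        If the dataset still uses the original index layout, the index is
        recreated first. With `duplicate_check=True`, texts that already
        exist in the index are filtered out before being added.
        """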
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
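        """Delete the texts with the given node ids from the vector store.

        For datasets that still use the original index layout, the whole
        index is recreated instead.
        """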
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
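        """Whether the dataset still uses the original (legacy) index layout.

        Subclasses override this; when it returns True the index is rebuilt
        before it is modified.
        """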
        return False

    def recreate_dataset(self, dataset: Dataset):
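        """Drop and rebuild the index from the dataset's completed, enabled
        document segments, then persist the new index struct.
        """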
- logging.info(f"Recreating dataset {dataset.id}")
- try:
- self.delete()
- except UnexpectedStatusCodeException as e:
- if e.status_code != 400:
- # 400 means index not exists
- raise e
- dataset_documents = db.session.query(DatasetDocument).filter(
- DatasetDocument.dataset_id == dataset.id,
- DatasetDocument.indexing_status == 'completed',
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- ).all()
- documents = []
- for dataset_document in dataset_documents:
- segments = db.session.query(DocumentSegment).filter(
- DocumentSegment.document_id == dataset_document.id,
- DocumentSegment.status == 'completed',
- DocumentSegment.enabled == True
- ).all()
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- }
- )
- documents.append(document)
- origin_index_struct = self.dataset.index_struct[:]
- self.dataset.index_struct = None
- if documents:
- try:
- self.create(documents)
- except Exception as e:
- self.dataset.index_struct = origin_index_struct
- raise e
- dataset.index_struct = json.dumps(self.to_index_struct())
- db.session.commit()
- self.dataset = dataset
- logging.info(f"Dataset {dataset.id} recreate successfully.")

    def create_qdrant_dataset(self, dataset: Dataset):
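        """Delete the existing index (if any) and recreate it from the
        dataset's completed, enabled document segments, for a Qdrant-backed
        dataset.
        """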
- logging.info(f"create_qdrant_dataset {dataset.id}")
- try:
- self.delete()
- except UnexpectedStatusCodeException as e:
- if e.status_code != 400:
- # 400 means index not exists
- raise e
- dataset_documents = db.session.query(DatasetDocument).filter(
- DatasetDocument.dataset_id == dataset.id,
- DatasetDocument.indexing_status == 'completed',
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- ).all()
- documents = []
- for dataset_document in dataset_documents:
- segments = db.session.query(DocumentSegment).filter(
- DocumentSegment.document_id == dataset_document.id,
- DocumentSegment.status == 'completed',
- DocumentSegment.enabled == True
- ).all()
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- }
- )
- documents.append(document)
- if documents:
- try:
- self.create(documents)
- except Exception as e:
- raise e
- logging.info(f"Dataset {dataset.id} recreate successfully.")

    def update_qdrant_dataset(self, dataset: Dataset):
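        """Point the dataset's index struct at Qdrant if its content is
        already present in the index (checked via one sample segment).
        """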
- logging.info(f"update_qdrant_dataset {dataset.id}")
- segment = db.session.query(DocumentSegment).filter(
- DocumentSegment.dataset_id == dataset.id,
- DocumentSegment.status == 'completed',
- DocumentSegment.enabled == True
- ).first()
- if segment:
- try:
- exist = self.text_exists(segment.index_node_id)
- if exist:
- index_struct = {
- "type": 'qdrant',
- "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
- }
- dataset.index_struct = json.dumps(index_struct)
- db.session.commit()
- except Exception as e:
- raise e
- logging.info(f"Dataset {dataset.id} recreate successfully.")

    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
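        """Re-index the dataset's completed, enabled segments into the
        collection referenced by the given dataset collection binding.
        """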
- logging.info(f"restore dataset in_one,_dataset {dataset.id}")
- dataset_documents = db.session.query(DatasetDocument).filter(
- DatasetDocument.dataset_id == dataset.id,
- DatasetDocument.indexing_status == 'completed',
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- ).all()
- documents = []
- for dataset_document in dataset_documents:
- segments = db.session.query(DocumentSegment).filter(
- DocumentSegment.document_id == dataset_document.id,
- DocumentSegment.status == 'completed',
- DocumentSegment.enabled == True
- ).all()
- for segment in segments:
- document = Document(
- page_content=segment.content,
- metadata={
- "doc_id": segment.index_node_id,
- "doc_hash": segment.index_node_hash,
- "document_id": segment.document_id,
- "dataset_id": segment.dataset_id,
- }
- )
- documents.append(document)
- if documents:
- try:
- self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
- except Exception as e:
- raise e
- logging.info(f"Dataset {dataset.id} recreate successfully.")

    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
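        """Delete the dataset's original collection and attach the dataset to
        the given collection binding.
        """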
- logging.info(f"delete original collection: {dataset.id}")
- self.delete()
- dataset.collection_binding_id = dataset_collection_binding.id
- db.session.add(dataset)
- db.session.commit()
- logging.info(f"Dataset {dataset.id} recreate successfully.")
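
# A concrete index is expected to subclass BaseVectorIndex and fill in the
# abstract methods above. The sketch below is illustrative only (the class
# name, store type and index-struct layout are hypothetical, not part of this
# module) and shows the shape such an implementation might take:
#
#     class MyVectorIndex(BaseVectorIndex):
#         def get_type(self) -> str:
#             return 'my_store'
#
#         def get_index_name(self, dataset: Dataset) -> str:
#             return f"vector_index_{dataset.id.replace('-', '_')}"
#
#         def to_index_struct(self) -> dict:
#             return {
#                 "type": self.get_type(),
#                 "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
#             }
#
#         def _get_vector_store(self) -> VectorStore:
#             ...  # build or reuse the store client for self.dataset
#
#         def _get_vector_store_class(self) -> type:
#             return VectorStore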