|
@@ -0,0 +1,114 @@
|
|
|
+from typing import Optional, cast
|
|
|
+
|
|
|
+from langchain.embeddings.base import Embeddings
|
|
|
+from langchain.schema import Document, BaseRetriever
|
|
|
+from langchain.vectorstores import VectorStore, milvus
|
|
|
+from pydantic import BaseModel, root_validator
|
|
|
+
|
|
|
+from core.index.base import BaseIndex
|
|
|
+from core.index.vector_index.base import BaseVectorIndex
|
|
|
+from core.vector_store.milvus_vector_store import MilvusVectorStore
|
|
|
+from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
|
|
+from models.dataset import Dataset
|
|
|
+
|
|
|
+
|
|
|
class MilvusConfig(BaseModel):
    """Connection settings for the Milvus vector index.

    ``endpoint``, ``user`` and ``password`` must all be present and
    non-empty; ``batch_size`` defaults to 100.
    """

    endpoint: str
    user: str
    password: str
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Reject configurations whose required settings are missing or empty.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The unchanged ``values`` mapping when every required setting is set.

        Raises:
            ValueError: If ``endpoint``, ``user`` or ``password`` is absent or
                falsy; the message names the matching ``MILVUS_*`` setting.
        """
        # Checked in this order so the first missing setting is the one reported.
        required_settings = (
            ('endpoint', "config MILVUS_ENDPOINT is required"),
            ('user', "config MILVUS_USER is required"),
            ('password', "config MILVUS_PASSWORD is required"),
        )
        for field_name, message in required_settings:
            # A field that failed its own validation is not present in
            # ``values``; .get() avoids leaking a bare KeyError in that case.
            if not values.get(field_name):
                raise ValueError(message)
        return values
|
|
|
+
|
|
|
+
|
|
|
class MilvusVectorIndex(BaseVectorIndex):
    """Vector index implementation that reports its type as ``'milvus'``.

    NOTE(review): ``create`` and ``_get_vector_store`` instantiate
    ``WeaviateVectorStore`` while ``_get_vector_store_class`` returns
    ``MilvusVectorStore`` -- confirm which backend is intended here.
    NOTE(review): ``_init_client`` is not defined in this class; presumably
    it comes from ``BaseVectorIndex`` -- verify.
    """

    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
        """Initialise the base index and build the client from ``config``."""
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

    def get_type(self) -> str:
        """Return the identifier of this index type."""
        return 'milvus'

    def get_index_name(self, dataset: Dataset) -> str:
        """Return the index name, always ending in ``'_Node'``.

        When ``self.dataset`` already carries an index struct, its stored
        ``class_prefix`` is used (with ``'_Node'`` appended if it lacks that
        suffix). Otherwise the name is derived from ``dataset.id`` with
        dashes replaced by underscores.
        """
        index_struct = self.dataset.index_struct_dict
        if not index_struct:
            return "Vector_index_" + dataset.id.replace("-", "_") + '_Node'

        prefix: str = index_struct['vector_store']['class_prefix']
        # A prefix without the suffix is an original-style class_prefix.
        return prefix if prefix.endswith('_Node') else prefix + '_Node'

    def to_index_struct(self) -> dict:
        """Return the index description: its type and vector store class prefix."""
        index_type = self.get_type()
        store_info = {"class_prefix": self.get_index_name(self.dataset)}
        return {"type": index_type, "vector_store": store_info}

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        """Build the vector store from ``texts`` and return this index.

        ``kwargs`` is accepted but not used by this implementation.
        """
        document_uuids = self._get_uuids(texts)
        index_name = self.get_index_name(self.dataset)
        self._vector_store = WeaviateVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=index_name,
            uuids=document_uuids,
            by_text=False,
        )
        return self

    def _get_vector_store(self) -> VectorStore:
        """Return the vector store of an already created index.

        Reuses ``self._vector_store`` when it is set; otherwise constructs a
        new store (which is returned without being cached on the instance).
        """
        existing_store = self._vector_store
        if existing_store:
            return existing_store

        # Original-style indexes are queried with 'doc_id' only.
        if self._is_origin():
            attributes = ['doc_id']
        else:
            attributes = ['doc_id', 'dataset_id', 'document_id']

        index_name = self.get_index_name(self.dataset)
        return WeaviateVectorStore(
            client=self._client,
            index_name=index_name,
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False,
        )

    def _get_vector_store_class(self) -> type:
        """Return the vector store class used for casting."""
        return MilvusVectorStore

    def delete_by_document_id(self, document_id: str):
        """Delete the entries whose ``document_id`` equals the given id.

        For an original-style index the dataset is recreated instead of
        issuing a filtered delete.
        """
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        store = self._get_vector_store()
        store_class = self._get_vector_store_class()
        # cast() only informs type checkers; it does not convert at runtime.
        store = cast(store_class, store)

        where_filter = {
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id,
        }
        store.del_texts(where_filter)

    def _is_origin(self):
        """Return True when the stored class_prefix lacks the ``'_Node'`` suffix."""
        index_struct = self.dataset.index_struct_dict
        if not index_struct:
            return False

        stored_prefix: str = index_struct['vector_store']['class_prefix']
        # original class_prefix
        return not stored_prefix.endswith('_Node')
|