| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 | 
							- import json
 
- import logging
 
- from typing import List, Optional
 
- from llama_index.data_structs import Node
 
- from requests import ReadTimeout
 
- from sqlalchemy.exc import IntegrityError
 
- from tenacity import retry, stop_after_attempt, retry_if_exception_type
 
- from core.index.index_builder import IndexBuilder
 
- from core.vector_store.base import BaseGPTVectorStoreIndex
 
- from extensions.ext_vector_store import vector_store
 
- from extensions.ext_database import db
 
- from models.dataset import Dataset, Embedding
 
class VectorIndex:
    """Manage the vector-store index that backs a single dataset.

    Wraps ``vector_store.get_index`` for one ``Dataset``. It provides:

    * node insertion, with an embedding cache kept in the ``Embedding`` table;
    * deletion by node id or by document id;
    * access to the index for querying.
    """

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def _get_index(self, service_context) -> BaseGPTVectorStoreIndex:
        """Return an index handle for this dataset's stored index struct."""
        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        """Embed ``nodes`` and insert them into the dataset's vector index.

        If the dataset has no index struct yet, one is created and committed.
        Each node's embedding is first looked up in the ``Embedding`` table by
        ``node.doc_hash``. Nodes without a cached embedding are embedded in a
        single batch, and the results are written back to that table on a
        best-effort basis.

        :param nodes: nodes to insert.
        :param duplicate_check: when True, skip nodes whose ``doc_id`` already
            exists in the index.
        """
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
        index = self._get_index(service_context)

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedded_nodes = []
        embedding_queue_nodes = []
        for node in nodes:
            # Reuse a cached embedding when one exists for this content hash.
            cached = db.session.query(Embedding).filter_by(hash=node.doc_hash).first()
            if cached:
                node.embedding = cached.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            # NOTE: `_get_node_embedding_results` is a private index API; it
            # embeds the queued nodes in one call.
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding
                self._cache_embedding(node)
                # The node already carries a valid embedding. A failed cache
                # write must not drop it from the index.
                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @staticmethod
    def _cache_embedding(node: Node) -> None:
        """Best-effort write of ``node.embedding`` to the ``Embedding`` table.

        Failures are handled here and never propagate as ``Exception``. The
        session is rolled back so that later commits remain possible.
        """
        try:
            embedding = Embedding(hash=node.doc_hash)
            embedding.set_embedding(node.embedding)
            db.session.add(embedding)
            db.session.commit()
        except IntegrityError:
            # Presumably a row with this hash already exists (e.g. an earlier
            # node in the same batch, or a concurrent writer). TODO confirm
            # the unique constraint on `hash`.
            db.session.rollback()
        except Exception:
            # Roll back so the session is not left in a failed transaction.
            db.session.rollback()
            logging.exception('Failed to add embedding to db')

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        """Insert ``nodes`` into ``index``.

        Retries up to 3 attempts on ``ReadTimeout``, then re-raises.
        """
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        """Delete the given node ids from the index; no-op if no index exists."""
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
        index = self._get_index(service_context)

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        """Delete one node from ``index``.

        Retries up to 3 attempts on ``ReadTimeout``, then re-raises.
        """
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        """Delete a document from the index; no-op if no index exists."""
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
        index = self._get_index(service_context)

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        """Delete a document from ``index``.

        Retries up to 3 attempts on ``ReadTimeout``, then re-raises.
        """
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        """Return an index for querying, or None if the dataset has no index yet."""
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
        return self._get_index(service_context)

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        """Return the nodes whose ``doc_id`` is not already present in ``index``.

        A new list is built and ``nodes`` is left unmodified. Removing items
        from a list while iterating over it would skip elements.
        """
        return [node for node in nodes if not index.exists_by_node_id(node.doc_id)]
 
 
  |