base.py

import json
import logging
from abc import abstractmethod
from typing import List, Any, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from weaviate import UnexpectedStatusCodeException

from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument


class BaseVectorIndex(BaseIndex):

    def __init__(self, dataset: Dataset, embeddings: Embeddings):
        super().__init__(dataset)
        self._embeddings = embeddings
        self._vector_store = None

    def get_type(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def get_index_name(self, dataset: Dataset) -> str:
        raise NotImplementedError

    @abstractmethod
    def to_index_struct(self) -> dict:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store(self) -> VectorStore:
        raise NotImplementedError

    @abstractmethod
    def _get_vector_store_class(self) -> type:
        raise NotImplementedError

    @abstractmethod
    def search_by_full_text_index(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        raise NotImplementedError

    def search(
            self, query: str,
            **kwargs: Any
    ) -> List[Document]:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        search_type = kwargs.get('search_type') if kwargs.get('search_type') else 'similarity'
        search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}

        if search_type == 'similarity_score_threshold':
            score_threshold = search_kwargs.get("score_threshold")
            if (score_threshold is None) or (not isinstance(score_threshold, float)):
                search_kwargs['score_threshold'] = .0

            docs_with_similarity = vector_store.similarity_search_with_relevance_scores(
                query, **search_kwargs
            )

            docs = []
            for doc, similarity in docs_with_similarity:
                doc.metadata['score'] = similarity
                docs.append(doc)

            return docs

        # supported search_kwargs per search_type:
        #   similarity: k
        #   mmr: k, fetch_k, lambda_mult
        #   similarity_score_threshold: k
        return vector_store.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        ).get_relevant_documents(query)
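
    # Usage sketch (illustrative only; the subclass name and kwargs below are
    # hypothetical and depend on which vector store the dataset is configured
    # with):
    #
    #   index = SomeConcreteVectorIndex(dataset, embeddings)
    #   docs = index.search(
    #       "how do I reset my password?",
    #       search_type='similarity_score_threshold',
    #       search_kwargs={'k': 4, 'score_threshold': 0.5},
    #   )
    #   # on the score-threshold path each Document comes back with
    #   # doc.metadata['score'] set to its relevance score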

    def get_retriever(self, **kwargs: Any) -> BaseRetriever:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.as_retriever(**kwargs)

    def add_texts(self, texts: list[Document], **kwargs):
        if self._is_origin():
            self.recreate_dataset(self.dataset)

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if kwargs.get('duplicate_check', False):
            texts = self._filter_duplicate_texts(texts)

        uuids = self._get_uuids(texts)
        vector_store.add_documents(texts, uuids=uuids)

    def text_exists(self, id: str) -> bool:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        return vector_store.text_exists(id)

    def delete_by_ids(self, ids: list[str]) -> None:
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        for node_id in ids:
            vector_store.del_text(node_id)

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if self.dataset.collection_binding_id:
            vector_store.delete_by_group_id(group_id)
        else:
            vector_store.delete()

    def delete(self) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.delete()

    def _is_origin(self):
        return False
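
    # _is_origin() is an overridable hook: the base implementation always
    # reports False, and a concrete index can return True to signal that the
    # dataset's vectors still live in an older ("origin") collection layout.
    # When it does, add_texts() and delete_by_ids() above fall back to
    # recreate_dataset() instead of writing to the existing collection. The
    # helpers below rebuild or migrate the index from the documents and
    # segments stored in the relational database.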

    def recreate_dataset(self, dataset: Dataset):
        logging.info(f"Recreating dataset {dataset.id}")

        try:
            self.delete()
        except UnexpectedStatusCodeException as e:
            if e.status_code != 400:
                # 400 means index not exists
                raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        origin_index_struct = self.dataset.index_struct[:]
        self.dataset.index_struct = None

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                # restore the original index_struct so the dataset is not left
                # pointing at a half-built index
                self.dataset.index_struct = origin_index_struct
                raise e

            dataset.index_struct = json.dumps(self.to_index_struct())

        db.session.commit()

        self.dataset = dataset

        logging.info(f"Dataset {dataset.id} recreated successfully.")
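
    # The *_qdrant_* and collection-binding helpers below appear to support
    # migrating a dataset into a Qdrant-backed index and, later, into a shared
    # collection identified by a DatasetCollectionBinding; the exact calling
    # sequence is driven by code outside this file.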

    def create_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"create_qdrant_dataset {dataset.id}")

        try:
            self.delete()
        except UnexpectedStatusCodeException as e:
            if e.status_code != 400:
                # 400 means index not exists
                raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            try:
                self.create(documents)
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} recreated successfully.")

    def update_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"update_qdrant_dataset {dataset.id}")

        segment = db.session.query(DocumentSegment).filter(
            DocumentSegment.dataset_id == dataset.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).first()

        if segment:
            try:
                exist = self.text_exists(segment.index_node_id)
                if exist:
                    index_struct = {
                        "type": 'qdrant',
                        "vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
                    }
                    dataset.index_struct = json.dumps(index_struct)
                    db.session.commit()
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} updated successfully.")

    def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"restore dataset in one, dataset {dataset.id}")

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            try:
                self.add_texts(documents)
            except Exception as e:
                raise e

        logging.info(f"Dataset {dataset.id} restored successfully.")

    def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
        logging.info(f"delete original collection: {dataset.id}")

        self.delete()

        dataset.collection_binding_id = dataset_collection_binding.id
        db.session.add(dataset)
        db.session.commit()

        logging.info(f"Dataset {dataset.id} rebound to collection binding successfully.")
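

# A minimal sketch of what a concrete subclass has to provide, based on the
# abstract methods declared above. The class name, collection naming scheme,
# and bodies are hypothetical; a real implementation wraps an actual
# langchain VectorStore client.
#
#   class MyVectorIndex(BaseVectorIndex):
#       def get_type(self) -> str:
#           return 'my_store'
#
#       def get_index_name(self, dataset: Dataset) -> str:
#           return f"Vector_index_{dataset.id.replace('-', '_')}"
#
#       def to_index_struct(self) -> dict:
#           return {"type": self.get_type(),
#                   "vector_store": {"class_prefix": self.get_index_name(self.dataset)}}
#
#       def _get_vector_store(self) -> VectorStore:
#           ...  # construct or reuse the underlying VectorStore instance
#
#       def _get_vector_store_class(self) -> type:
#           ...  # return that VectorStore subclass, used by cast() above
#
#       def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
#           ...  # delegate to the store's keyword / BM25 search if available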