vector_service.py

from typing import Optional

from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode


class VectorService:
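    """Builds and maintains the vector, keyword, and child-chunk indexes for dataset document segments."""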
    @classmethod
    def create_segments_vector(
        cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
    ):
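        """Index newly created segments.

        For parent-child datasets, each segment is split into child chunks (this requires
        high-quality indexing with an embedding model); for all other index types the
        segments are loaded as whole documents, optionally with per-segment keywords.
        """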
        documents = []
        for segment in segments:
            if doc_form == IndexType.PARENT_CHILD_INDEX:
                document = DatasetDocument.query.filter_by(id=segment.document_id).first()
                # guard against a missing parent document before dereferencing it
                if not document:
                    raise ValueError("Dataset document not found.")
                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                    .first()
                )
                if not processing_rule:
                    raise ValueError("No processing rule found.")
                # get embedding model instance
                if dataset.indexing_technique == "high_quality":
                    # check embedding model setting
                    model_manager = ModelManager()

                    if dataset.embedding_model_provider:
                        embedding_model_instance = model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=dataset.embedding_model_provider,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=dataset.embedding_model,
                        )
                    else:
                        embedding_model_instance = model_manager.get_default_model_instance(
                            tenant_id=dataset.tenant_id,
                            model_type=ModelType.TEXT_EMBEDDING,
                        )
                else:
                    raise ValueError("The knowledge base index technique is not high quality!")
                cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
            else:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    },
                )
                documents.append(document)
        if len(documents) > 0:
            index_processor = IndexProcessorFactory(doc_form).init_index_processor()
            index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

    @classmethod
    def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
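        """Re-index a single segment: refresh its vector entry (high-quality datasets only)
        and rebuild its keyword entry, using the caller-supplied keywords when given."""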
        # update segment index task

        # format new index
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            vector.delete_by_ids([segment.index_node_id])
            vector.add_texts([document], duplicate_check=True)

        # update keyword index
        keyword = Keyword(dataset)
        keyword.delete_by_ids([segment.index_node_id])

        # save keyword index
        if keywords and len(keywords) > 0:
            keyword.add_texts([document], keywords_list=[keywords])
        else:
            keyword.add_texts([document])

    @classmethod
    def generate_child_chunks(
        cls,
        segment: DocumentSegment,
        dataset_document: DatasetDocument,
        dataset: Dataset,
        embedding_model_instance: ModelInstance,
        processing_rule: DatasetProcessRule,
        regenerate: bool = False,
    ):
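        """Split a segment into child chunks using full-doc parent mode, load them into the
        index, and persist the corresponding ChildChunk rows. When ``regenerate`` is True,
        the segment's existing child chunks are cleaned up first."""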
        index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
        if regenerate:
            # delete child chunks
            index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)

        # generate child chunks
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        # use full doc mode to generate segment's child chunk
        processing_rule_dict = processing_rule.to_dict()
        processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
        documents = index_processor.transform(
            [document],
            embedding_model_instance=embedding_model_instance,
            process_rule=processing_rule_dict,
            tenant_id=dataset.tenant_id,
            doc_language=dataset_document.doc_language,
        )
        # save child chunks
        if documents and documents[0].children:
            index_processor.load(dataset, documents)
            for position, child_chunk in enumerate(documents[0].children, start=1):
                child_segment = ChildChunk(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset.id,
                    document_id=dataset_document.id,
                    segment_id=segment.id,
                    position=position,
                    index_node_id=child_chunk.metadata["doc_id"],
                    index_node_hash=child_chunk.metadata["doc_hash"],
                    content=child_chunk.page_content,
                    word_count=len(child_chunk.page_content),
                    type="automatic",
                    created_by=dataset_document.created_by,
                )
                db.session.add(child_segment)
        db.session.commit()

    @classmethod
    def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
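        """Add a single child chunk to the vector index (high-quality datasets only)."""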
        child_document = Document(
            page_content=child_segment.content,
            metadata={
                "doc_id": child_segment.index_node_id,
                "doc_hash": child_segment.index_node_hash,
                "document_id": child_segment.document_id,
                "dataset_id": child_segment.dataset_id,
            },
        )
        if dataset.indexing_technique == "high_quality":
            # save vector index
            vector = Vector(dataset=dataset)
            vector.add_texts([child_document], duplicate_check=True)

    @classmethod
    def update_child_chunk_vector(
        cls,
        new_child_chunks: list[ChildChunk],
        update_child_chunks: list[ChildChunk],
        delete_child_chunks: list[ChildChunk],
        dataset: Dataset,
    ):
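        """Apply a batch of child-chunk edits to the vector index: embed new chunks,
        re-embed updated chunks, and delete removed chunks (high-quality datasets only)."""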
        documents = []
        delete_node_ids = []
        for new_child_chunk in new_child_chunks:
            new_child_document = Document(
                page_content=new_child_chunk.content,
                metadata={
                    "doc_id": new_child_chunk.index_node_id,
                    "doc_hash": new_child_chunk.index_node_hash,
                    "document_id": new_child_chunk.document_id,
                    "dataset_id": new_child_chunk.dataset_id,
                },
            )
            documents.append(new_child_document)
        for update_child_chunk in update_child_chunks:
            child_document = Document(
                page_content=update_child_chunk.content,
                metadata={
                    "doc_id": update_child_chunk.index_node_id,
                    "doc_hash": update_child_chunk.index_node_hash,
                    "document_id": update_child_chunk.document_id,
                    "dataset_id": update_child_chunk.dataset_id,
                },
            )
            documents.append(child_document)
            delete_node_ids.append(update_child_chunk.index_node_id)
        for delete_child_chunk in delete_child_chunks:
            delete_node_ids.append(delete_child_chunk.index_node_id)
        if dataset.indexing_technique == "high_quality":
            # update vector index
            vector = Vector(dataset=dataset)
            if delete_node_ids:
                vector.delete_by_ids(delete_node_ids)
            if documents:
                vector.add_texts(documents, duplicate_check=True)

    @classmethod
    def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
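        """Remove a single child chunk from the vector index."""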
        vector = Vector(dataset=dataset)
        vector.delete_by_ids([child_chunk.index_node_id])