dataset_retriever_tool.py

import json
from typing import Type

from flask import current_app
from langchain.tools import BaseTool
from pydantic import Field, BaseModel

from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.embedding.cached_embedding import CacheEmbedding
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, Document


class DatasetRetrieverToolInput(BaseModel):
    query: str = Field(..., description="Query to run against the dataset to retrieve relevant content.")


class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""

    name: str = "dataset"
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    tenant_id: str
    dataset_id: str
    k: int = 3
    conversation_message_task: ConversationMessageTask
    return_resource: bool
    retriever_from: str
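
    # from_dataset derives the tool's name and description from the dataset itself,
    # so an agent choosing between several dataset tools can pick by description.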
    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        description = description.replace('\n', '').replace('\r', '')
        return cls(
            name=f'dataset-{dataset.id}',
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, query: str) -> str:
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == self.dataset_id
        ).first()

        if not dataset:
            return f'[{self.name} failed to find dataset with id {self.dataset_id}.]'
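
        # "economy" datasets are indexed with a keyword table; any other indexing
        # technique is served from a vector index backed by embeddings.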
        if dataset.indexing_technique == "economy":
            # use keyword table query
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.k})
            return "\n".join([document.page_content for document in documents])
        else:
            try:
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except (LLMBadRequestError, ProviderTokenNotInitError):
                return ''
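
            # CacheEmbedding wraps the embedding model so texts that were embedded
            # before are served from cache instead of re-querying the provider.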
            embeddings = CacheEmbedding(embedding_model)

            vector_index = VectorIndex(
                dataset=dataset,
                config=current_app.config,
                embeddings=embeddings
            )

            if self.k > 0:
                documents = vector_index.search(
                    query,
                    search_type='similarity_score_threshold',
                    search_kwargs={
                        'k': self.k,
                        'filter': {
                            'group_id': [dataset.id]
                        }
                    }
                )
            else:
                documents = []
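
            # Report the retrieved documents to the callback handler, which Dify
            # uses to record segment hits for the current conversation message.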
            hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task)
            hit_callback.on_tool_end(documents)

            document_score_list = {}
            if dataset.indexing_technique != "economy":
                for item in documents:
                    document_score_list[item.metadata['doc_id']] = item.metadata['score']
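
            # Map retrieved doc_ids back to their database segments, preserving
            # the retrieval order when assembling the context.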
            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            segments = DocumentSegment.query.filter(
                DocumentSegment.dataset_id == self.dataset_id,
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True,
                DocumentSegment.index_node_id.in_(index_node_ids)
            ).all()

            if segments:
                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
                sorted_segments = sorted(segments,
                                         key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
                                                                                           float('inf')))

                for segment in sorted_segments:
                    if segment.answer:
                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                    else:
                        document_context_list.append(segment.content)
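
                # Optionally attach citation metadata for each segment so the
                # caller can render retriever resources alongside the answer.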
                if self.return_resource:
                    context_list = []
                    resource_number = 1
                    for segment in sorted_segments:
                        document = Document.query.filter(
                            Document.id == segment.document_id,
                            Document.enabled == True,
                            Document.archived == False
                        ).first()

                        if dataset and document:
                            source = {
                                'position': resource_number,
                                'dataset_id': dataset.id,
                                'dataset_name': dataset.name,
                                'document_id': document.id,
                                'document_name': document.name,
                                'data_source_type': document.data_source_type,
                                'segment_id': segment.id,
                                'retriever_from': self.retriever_from
                            }

                            if dataset.indexing_technique != "economy":
                                source['score'] = document_score_list.get(segment.index_node_id)

                            if self.retriever_from == 'dev':
                                source['hit_count'] = segment.hit_count
                                source['word_count'] = segment.word_count
                                source['segment_position'] = segment.position
                                source['index_node_hash'] = segment.index_node_hash

                            if segment.answer:
                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                            else:
                                source['content'] = segment.content

                            context_list.append(source)
                            resource_number += 1

                    hit_callback.return_retriever_resource_info(context_list)

            return "\n".join(document_context_list)

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
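

# A minimal usage sketch (hypothetical values; assumes a loaded Dataset row and an
# active ConversationMessageTask from the surrounding Dify request context):
#
#   tool = DatasetRetrieverTool.from_dataset(
#       dataset,
#       k=3,
#       conversation_message_task=task,
#       return_resource=True,
#       retriever_from='dev',
#   )
#   context = tool.run('What does the handbook say about refunds?')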