dataset_retriever_tool.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import json
  2. import threading
  3. from typing import Type, Optional, List
  4. from flask import current_app
  5. from langchain.tools import BaseTool
  6. from pydantic import Field, BaseModel
  7. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  8. from core.conversation_message_task import ConversationMessageTask
  9. from core.embedding.cached_embedding import CacheEmbedding
  10. from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
  11. from core.index.vector_index.vector_index import VectorIndex
  12. from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
  13. from core.model_providers.model_factory import ModelFactory
  14. from extensions.ext_database import db
  15. from models.dataset import Dataset, DocumentSegment, Document
  16. from services.retrieval_service import RetrievalService
# Fallback retrieval configuration applied when a dataset has no
# retrieval_model of its own (see DatasetRetrieverTool._run).
# NOTE: 'score_threshold' is intentionally absent here — it is only read
# when 'score_threshold_enable' is True, which this default keeps False.
default_retrieval_model = {
    'search_method': 'semantic_search',
    'reranking_enable': False,
    'reranking_model': {
        'reranking_provider_name': '',
        'reranking_model_name': ''
    },
    'top_k': 2,
    'score_threshold_enable': False
}
class DatasetRetrieverToolInput(BaseModel):
    """Argument schema for DatasetRetrieverTool: a single free-text query string."""

    # The Field description is surfaced to the agent/LLM as the tool-argument
    # description, so its wording is part of runtime behavior.
    query: str = Field(..., description="Query for the dataset to be used to retrieve the dataset.")
class DatasetRetrieverTool(BaseTool):
    """Tool for querying a Dataset."""

    # -- tool metadata consumed by the agent framework --
    name: str = "dataset"  # overridden per instance by from_dataset() as 'dataset-<id>'
    args_schema: Type[BaseModel] = DatasetRetrieverToolInput
    description: str = "use this to retrieve a dataset. "

    # -- tenant / dataset addressing --
    tenant_id: str
    dataset_id: str

    # -- retrieval tuning --
    top_k: int = 2                           # max results requested per search method
    score_threshold: Optional[float] = None  # NOTE(review): unused in _run, which reads the
                                             # threshold from the dataset's retrieval_model instead

    # -- integration hooks --
    conversation_message_task: ConversationMessageTask  # passed to the index callback handler
    return_resource: bool  # when True, citation metadata is emitted via the hit callback
    retriever_from: str    # origin tag (e.g. 'dev') recorded in each citation source

    @classmethod
    def from_dataset(cls, dataset: Dataset, **kwargs):
        """Build a tool instance for *dataset*.

        Uses the dataset's own description (newlines stripped, since it feeds a
        single-line tool description) or a generic fallback naming the dataset.
        Extra kwargs (conversation_message_task, return_resource, retriever_from,
        top_k, ...) are forwarded to the constructor.
        """
        description = dataset.description
        if not description:
            description = 'useful for when you want to answer queries about the ' + dataset.name

        # Tool descriptions must be single-line; drop embedded line breaks.
        description = description.replace('\n', '').replace('\r', '')
        return cls(
            name=f'dataset-{dataset.id}',
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            description=description,
            **kwargs
        )

    def _run(self, query: str) -> str:
        """Retrieve segments relevant to *query* from the configured dataset.

        Returns the matching segment texts joined with newlines. Returns ''
        when the dataset does not exist or the embedding model cannot be
        resolved. As a side effect, reports retrieved documents (and, when
        return_resource is set, citation metadata) through the
        DatasetIndexToolCallbackHandler.
        """
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == self.dataset_id
        ).first()

        if not dataset:
            return ''

        # Fall back to the module-level default when the dataset has no
        # retrieval model configured.
        retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model

        if dataset.indexing_technique == "economy":
            # 'economy' datasets have no embeddings: use the keyword-table index
            # and return segment texts directly.
            kw_table_index = KeywordTableIndex(
                dataset=dataset,
                config=KeywordTableConfig(
                    max_keywords_per_chunk=5
                )
            )

            documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
            return str("\n".join([document.page_content for document in documents]))
        else:
            # Vector-based retrieval requires a resolvable embedding model;
            # degrade to an empty answer (not an error) when it is unavailable.
            try:
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except LLMBadRequestError:
                return ''
            except ProviderTokenNotInitError:
                return ''

            embeddings = CacheEmbedding(embedding_model)

            # Both search threads append into this shared list; it is only read
            # after thread.join() below.
            documents = []
            threads = []
            if self.top_k > 0:
                # Semantic (embedding) retrieval — also runs for hybrid search.
                if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
                    embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
                        # worker threads need the real app object to push a Flask app context
                        'flask_app': current_app._get_current_object(),
                        'dataset_id': str(dataset.id),
                        'query': query,
                        'top_k': self.top_k,
                        # 'score_threshold' is only present/read when the enable flag is set
                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
                            'score_threshold_enable'] else None,
                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
                            'reranking_enable'] else None,
                        'all_documents': documents,
                        'search_method': retrieval_model['search_method'],
                        'embeddings': embeddings
                    })
                    threads.append(embedding_thread)
                    embedding_thread.start()

                # Full-text retrieval — also runs for hybrid search.
                if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
                    full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'dataset_id': str(dataset.id),
                        'query': query,
                        'search_method': retrieval_model['search_method'],
                        'embeddings': embeddings,
                        'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
                            'score_threshold_enable'] else None,
                        'top_k': self.top_k,
                        'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
                            'reranking_enable'] else None,
                        'all_documents': documents
                    })
                    threads.append(full_text_index_thread)
                    full_text_index_thread.start()

                for thread in threads:
                    thread.join()

                # hybrid search: rerank after all documents have been searched
                if retrieval_model['search_method'] == 'hybrid_search':
                    # NOTE(review): the reranking model names are read here even when
                    # 'reranking_enable' is False — confirm config always provides them
                    # for hybrid search.
                    hybrid_rerank = ModelFactory.get_reranking_model(
                        tenant_id=dataset.tenant_id,
                        model_provider_name=retrieval_model['reranking_model']['reranking_provider_name'],
                        model_name=retrieval_model['reranking_model']['reranking_model_name']
                    )

                    documents = hybrid_rerank.rerank(query, documents,
                                                     retrieval_model['score_threshold'] if retrieval_model[
                                                         'score_threshold_enable'] else None,
                                                     self.top_k)
            else:
                documents = []

            # Report what was retrieved to the conversation pipeline.
            hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
            hit_callback.on_tool_end(documents)

            # Map doc_id -> score for citation metadata ('economy' results carry no scores).
            document_score_list = {}
            if dataset.indexing_technique != "economy":
                for item in documents:
                    document_score_list[item.metadata['doc_id']] = item.metadata['score']

            document_context_list = []
            index_node_ids = [document.metadata['doc_id'] for document in documents]
            # Resolve the retrieved index nodes back to enabled, fully-indexed segments.
            segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id,
                                                    DocumentSegment.completed_at.isnot(None),
                                                    DocumentSegment.status == 'completed',
                                                    DocumentSegment.enabled == True,
                                                    DocumentSegment.index_node_id.in_(index_node_ids)
                                                    ).all()

            if segments:
                # Preserve retrieval (relevance) order; unknown node ids sort last.
                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
                sorted_segments = sorted(segments,
                                         key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
                                                                                           float('inf')))
                for segment in sorted_segments:
                    if segment.answer:
                        # Q&A-style segment: include both sides.
                        document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
                    else:
                        document_context_list.append(segment.content)

                if self.return_resource:
                    # Build citation metadata for each segment whose parent document
                    # is still enabled and not archived.
                    context_list = []
                    resource_number = 1
                    for segment in sorted_segments:
                        context = {}
                        document = Document.query.filter(Document.id == segment.document_id,
                                                         Document.enabled == True,
                                                         Document.archived == False,
                                                         ).first()
                        if dataset and document:
                            source = {
                                'position': resource_number,
                                'dataset_id': dataset.id,
                                'dataset_name': dataset.name,
                                'document_id': document.id,
                                'document_name': document.name,
                                'data_source_type': document.data_source_type,
                                'segment_id': segment.id,
                                'retriever_from': self.retriever_from,
                                'score': document_score_list.get(segment.index_node_id, None)
                            }
                            # Dev origin gets extra debugging fields.
                            if self.retriever_from == 'dev':
                                source['hit_count'] = segment.hit_count
                                source['word_count'] = segment.word_count
                                source['segment_position'] = segment.position
                                source['index_node_hash'] = segment.index_node_hash
                            if segment.answer:
                                source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
                            else:
                                source['content'] = segment.content
                            context_list.append(source)
                            resource_number += 1

                    hit_callback.return_retriever_resource_info(context_list)

            return str("\n".join(document_context_list))

    async def _arun(self, tool_input: str) -> str:
        """Async execution is not supported for this tool."""
        raise NotImplementedError()