# dataset_retrieval.py
  1. import math
  2. import threading
  3. from collections import Counter
  4. from typing import Optional, cast
  5. from flask import Flask, current_app
  6. from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
  7. from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
  8. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  9. from core.entities.agent_entities import PlanningStrategy
  10. from core.memory.token_buffer_memory import TokenBufferMemory
  11. from core.model_manager import ModelInstance, ModelManager
  12. from core.model_runtime.entities.message_entities import PromptMessageTool
  13. from core.model_runtime.entities.model_entities import ModelFeature, ModelType
  14. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  15. from core.ops.entities.trace_entity import TraceTaskName
  16. from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
  17. from core.ops.utils import measure_time
  18. from core.rag.data_post_processor.data_post_processor import DataPostProcessor
  19. from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
  20. from core.rag.datasource.retrieval_service import RetrievalService
  21. from core.rag.models.document import Document
  22. from core.rag.retrieval.retrieval_methods import RetrievalMethod
  23. from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
  24. from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
  25. from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
  26. from extensions.ext_database import db
  27. from models.dataset import Dataset, DatasetQuery, DocumentSegment
  28. from models.dataset import Document as DatasetDocument
  29. default_retrieval_model = {
  30. "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
  31. "reranking_enable": False,
  32. "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
  33. "top_k": 2,
  34. "score_threshold_enabled": False,
  35. }
  36. class DatasetRetrieval:
  37. def __init__(self, application_generate_entity=None):
  38. self.application_generate_entity = application_generate_entity
  39. def retrieve(
  40. self,
  41. app_id: str,
  42. user_id: str,
  43. tenant_id: str,
  44. model_config: ModelConfigWithCredentialsEntity,
  45. config: DatasetEntity,
  46. query: str,
  47. invoke_from: InvokeFrom,
  48. show_retrieve_source: bool,
  49. hit_callback: DatasetIndexToolCallbackHandler,
  50. message_id: str,
  51. memory: Optional[TokenBufferMemory] = None,
  52. ) -> Optional[str]:
  53. """
  54. Retrieve dataset.
  55. :param app_id: app_id
  56. :param user_id: user_id
  57. :param tenant_id: tenant id
  58. :param model_config: model config
  59. :param config: dataset config
  60. :param query: query
  61. :param invoke_from: invoke from
  62. :param show_retrieve_source: show retrieve source
  63. :param hit_callback: hit callback
  64. :param message_id: message id
  65. :param memory: memory
  66. :return:
  67. """
  68. dataset_ids = config.dataset_ids
  69. if len(dataset_ids) == 0:
  70. return None
  71. retrieve_config = config.retrieve_config
  72. # check model is support tool calling
  73. model_type_instance = model_config.provider_model_bundle.model_type_instance
  74. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  75. model_manager = ModelManager()
  76. model_instance = model_manager.get_model_instance(
  77. tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
  78. )
  79. # get model schema
  80. model_schema = model_type_instance.get_model_schema(
  81. model=model_config.model, credentials=model_config.credentials
  82. )
  83. if not model_schema:
  84. return None
  85. planning_strategy = PlanningStrategy.REACT_ROUTER
  86. features = model_schema.features
  87. if features:
  88. if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
  89. planning_strategy = PlanningStrategy.ROUTER
  90. available_datasets = []
  91. for dataset_id in dataset_ids:
  92. # get dataset from dataset id
  93. dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  94. # pass if dataset is not available
  95. if not dataset:
  96. continue
  97. # pass if dataset is not available
  98. if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
  99. continue
  100. available_datasets.append(dataset)
  101. all_documents = []
  102. user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
  103. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  104. all_documents = self.single_retrieve(
  105. app_id,
  106. tenant_id,
  107. user_id,
  108. user_from,
  109. available_datasets,
  110. query,
  111. model_instance,
  112. model_config,
  113. planning_strategy,
  114. message_id,
  115. )
  116. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  117. all_documents = self.multiple_retrieve(
  118. app_id,
  119. tenant_id,
  120. user_id,
  121. user_from,
  122. available_datasets,
  123. query,
  124. retrieve_config.top_k,
  125. retrieve_config.score_threshold,
  126. retrieve_config.rerank_mode,
  127. retrieve_config.reranking_model,
  128. retrieve_config.weights,
  129. retrieve_config.reranking_enabled,
  130. message_id,
  131. )
  132. document_score_list = {}
  133. for item in all_documents:
  134. if item.metadata.get("score"):
  135. document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
  136. document_context_list = []
  137. index_node_ids = [document.metadata["doc_id"] for document in all_documents]
  138. segments = DocumentSegment.query.filter(
  139. DocumentSegment.dataset_id.in_(dataset_ids),
  140. DocumentSegment.completed_at.isnot(None),
  141. DocumentSegment.status == "completed",
  142. DocumentSegment.enabled == True,
  143. DocumentSegment.index_node_id.in_(index_node_ids),
  144. ).all()
  145. if segments:
  146. index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
  147. sorted_segments = sorted(
  148. segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
  149. )
  150. for segment in sorted_segments:
  151. if segment.answer:
  152. document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
  153. else:
  154. document_context_list.append(segment.get_sign_content())
  155. if show_retrieve_source:
  156. context_list = []
  157. resource_number = 1
  158. for segment in sorted_segments:
  159. dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
  160. document = DatasetDocument.query.filter(
  161. DatasetDocument.id == segment.document_id,
  162. DatasetDocument.enabled == True,
  163. DatasetDocument.archived == False,
  164. ).first()
  165. if dataset and document:
  166. source = {
  167. "position": resource_number,
  168. "dataset_id": dataset.id,
  169. "dataset_name": dataset.name,
  170. "document_id": document.id,
  171. "document_name": document.name,
  172. "data_source_type": document.data_source_type,
  173. "segment_id": segment.id,
  174. "retriever_from": invoke_from.to_source(),
  175. "score": document_score_list.get(segment.index_node_id, None),
  176. }
  177. if invoke_from.to_source() == "dev":
  178. source["hit_count"] = segment.hit_count
  179. source["word_count"] = segment.word_count
  180. source["segment_position"] = segment.position
  181. source["index_node_hash"] = segment.index_node_hash
  182. if segment.answer:
  183. source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
  184. else:
  185. source["content"] = segment.content
  186. context_list.append(source)
  187. resource_number += 1
  188. if hit_callback:
  189. hit_callback.return_retriever_resource_info(context_list)
  190. return str("\n".join(document_context_list))
  191. return ""
  192. def single_retrieve(
  193. self,
  194. app_id: str,
  195. tenant_id: str,
  196. user_id: str,
  197. user_from: str,
  198. available_datasets: list,
  199. query: str,
  200. model_instance: ModelInstance,
  201. model_config: ModelConfigWithCredentialsEntity,
  202. planning_strategy: PlanningStrategy,
  203. message_id: Optional[str] = None,
  204. ):
  205. tools = []
  206. for dataset in available_datasets:
  207. description = dataset.description
  208. if not description:
  209. description = "useful for when you want to answer queries about the " + dataset.name
  210. description = description.replace("\n", "").replace("\r", "")
  211. message_tool = PromptMessageTool(
  212. name=dataset.id,
  213. description=description,
  214. parameters={
  215. "type": "object",
  216. "properties": {},
  217. "required": [],
  218. },
  219. )
  220. tools.append(message_tool)
  221. dataset_id = None
  222. if planning_strategy == PlanningStrategy.REACT_ROUTER:
  223. react_multi_dataset_router = ReactMultiDatasetRouter()
  224. dataset_id = react_multi_dataset_router.invoke(
  225. query, tools, model_config, model_instance, user_id, tenant_id
  226. )
  227. elif planning_strategy == PlanningStrategy.ROUTER:
  228. function_call_router = FunctionCallMultiDatasetRouter()
  229. dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
  230. if dataset_id:
  231. # get retrieval model config
  232. dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
  233. if dataset:
  234. retrieval_model_config = dataset.retrieval_model or default_retrieval_model
  235. # get top k
  236. top_k = retrieval_model_config["top_k"]
  237. # get retrieval method
  238. if dataset.indexing_technique == "economy":
  239. retrieval_method = "keyword_search"
  240. else:
  241. retrieval_method = retrieval_model_config["search_method"]
  242. # get reranking model
  243. reranking_model = (
  244. retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
  245. )
  246. # get score threshold
  247. score_threshold = 0.0
  248. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  249. if score_threshold_enabled:
  250. score_threshold = retrieval_model_config.get("score_threshold")
  251. with measure_time() as timer:
  252. results = RetrievalService.retrieve(
  253. retrieval_method=retrieval_method,
  254. dataset_id=dataset.id,
  255. query=query,
  256. top_k=top_k,
  257. score_threshold=score_threshold,
  258. reranking_model=reranking_model,
  259. reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
  260. weights=retrieval_model_config.get("weights", None),
  261. )
  262. self._on_query(query, [dataset_id], app_id, user_from, user_id)
  263. if results:
  264. self._on_retrieval_end(results, message_id, timer)
  265. return results
  266. return []
  267. def multiple_retrieve(
  268. self,
  269. app_id: str,
  270. tenant_id: str,
  271. user_id: str,
  272. user_from: str,
  273. available_datasets: list,
  274. query: str,
  275. top_k: int,
  276. score_threshold: float,
  277. reranking_mode: str,
  278. reranking_model: Optional[dict] = None,
  279. weights: Optional[dict] = None,
  280. reranking_enable: bool = True,
  281. message_id: Optional[str] = None,
  282. ):
  283. threads = []
  284. all_documents = []
  285. dataset_ids = [dataset.id for dataset in available_datasets]
  286. index_type = None
  287. for dataset in available_datasets:
  288. index_type = dataset.indexing_technique
  289. retrieval_thread = threading.Thread(
  290. target=self._retriever,
  291. kwargs={
  292. "flask_app": current_app._get_current_object(),
  293. "dataset_id": dataset.id,
  294. "query": query,
  295. "top_k": top_k,
  296. "all_documents": all_documents,
  297. },
  298. )
  299. threads.append(retrieval_thread)
  300. retrieval_thread.start()
  301. for thread in threads:
  302. thread.join()
  303. with measure_time() as timer:
  304. if reranking_enable:
  305. # do rerank for searched documents
  306. data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
  307. all_documents = data_post_processor.invoke(
  308. query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
  309. )
  310. else:
  311. if index_type == "economy":
  312. all_documents = self.calculate_keyword_score(query, all_documents, top_k)
  313. elif index_type == "high_quality":
  314. all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
  315. self._on_query(query, dataset_ids, app_id, user_from, user_id)
  316. if all_documents:
  317. self._on_retrieval_end(all_documents, message_id, timer)
  318. return all_documents
  319. def _on_retrieval_end(
  320. self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
  321. ) -> None:
  322. """Handle retrieval end."""
  323. for document in documents:
  324. query = db.session.query(DocumentSegment).filter(
  325. DocumentSegment.index_node_id == document.metadata["doc_id"]
  326. )
  327. # if 'dataset_id' in document.metadata:
  328. if "dataset_id" in document.metadata:
  329. query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
  330. # add hit count to document segment
  331. query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
  332. db.session.commit()
  333. # get tracing instance
  334. trace_manager: TraceQueueManager | None = (
  335. self.application_generate_entity.trace_manager if self.application_generate_entity else None
  336. )
  337. if trace_manager:
  338. trace_manager.add_trace_task(
  339. TraceTask(
  340. TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
  341. )
  342. )
  343. def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
  344. """
  345. Handle query.
  346. """
  347. if not query:
  348. return
  349. dataset_queries = []
  350. for dataset_id in dataset_ids:
  351. dataset_query = DatasetQuery(
  352. dataset_id=dataset_id,
  353. content=query,
  354. source="app",
  355. source_app_id=app_id,
  356. created_by_role=user_from,
  357. created_by=user_id,
  358. )
  359. dataset_queries.append(dataset_query)
  360. if dataset_queries:
  361. db.session.add_all(dataset_queries)
  362. db.session.commit()
  363. def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
  364. with flask_app.app_context():
  365. dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
  366. if not dataset:
  367. return []
  368. # get retrieval model , if the model is not setting , using default
  369. retrieval_model = dataset.retrieval_model or default_retrieval_model
  370. if dataset.indexing_technique == "economy":
  371. # use keyword table query
  372. documents = RetrievalService.retrieve(
  373. retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
  374. )
  375. if documents:
  376. all_documents.extend(documents)
  377. else:
  378. if top_k > 0:
  379. # retrieval source
  380. documents = RetrievalService.retrieve(
  381. retrieval_method=retrieval_model["search_method"],
  382. dataset_id=dataset.id,
  383. query=query,
  384. top_k=retrieval_model.get("top_k") or 2,
  385. score_threshold=retrieval_model.get("score_threshold", 0.0)
  386. if retrieval_model["score_threshold_enabled"]
  387. else 0.0,
  388. reranking_model=retrieval_model.get("reranking_model", None)
  389. if retrieval_model["reranking_enable"]
  390. else None,
  391. reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
  392. weights=retrieval_model.get("weights", None),
  393. )
  394. all_documents.extend(documents)
  395. def to_dataset_retriever_tool(
  396. self,
  397. tenant_id: str,
  398. dataset_ids: list[str],
  399. retrieve_config: DatasetRetrieveConfigEntity,
  400. return_resource: bool,
  401. invoke_from: InvokeFrom,
  402. hit_callback: DatasetIndexToolCallbackHandler,
  403. ) -> Optional[list[DatasetRetrieverBaseTool]]:
  404. """
  405. A dataset tool is a tool that can be used to retrieve information from a dataset
  406. :param tenant_id: tenant id
  407. :param dataset_ids: dataset ids
  408. :param retrieve_config: retrieve config
  409. :param return_resource: return resource
  410. :param invoke_from: invoke from
  411. :param hit_callback: hit callback
  412. """
  413. tools = []
  414. available_datasets = []
  415. for dataset_id in dataset_ids:
  416. # get dataset from dataset id
  417. dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  418. # pass if dataset is not available
  419. if not dataset:
  420. continue
  421. # pass if dataset is not available
  422. if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
  423. continue
  424. available_datasets.append(dataset)
  425. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  426. # get retrieval model config
  427. default_retrieval_model = {
  428. "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
  429. "reranking_enable": False,
  430. "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
  431. "top_k": 2,
  432. "score_threshold_enabled": False,
  433. }
  434. for dataset in available_datasets:
  435. retrieval_model_config = dataset.retrieval_model or default_retrieval_model
  436. # get top k
  437. top_k = retrieval_model_config["top_k"]
  438. # get score threshold
  439. score_threshold = None
  440. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  441. if score_threshold_enabled:
  442. score_threshold = retrieval_model_config.get("score_threshold")
  443. from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
  444. tool = DatasetRetrieverTool.from_dataset(
  445. dataset=dataset,
  446. top_k=top_k,
  447. score_threshold=score_threshold,
  448. hit_callbacks=[hit_callback],
  449. return_resource=return_resource,
  450. retriever_from=invoke_from.to_source(),
  451. )
  452. tools.append(tool)
  453. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  454. from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
  455. tool = DatasetMultiRetrieverTool.from_dataset(
  456. dataset_ids=[dataset.id for dataset in available_datasets],
  457. tenant_id=tenant_id,
  458. top_k=retrieve_config.top_k or 2,
  459. score_threshold=retrieve_config.score_threshold,
  460. hit_callbacks=[hit_callback],
  461. return_resource=return_resource,
  462. retriever_from=invoke_from.to_source(),
  463. reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
  464. reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
  465. )
  466. tools.append(tool)
  467. return tools
  468. def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
  469. """
  470. Calculate keywords scores
  471. :param query: search query
  472. :param documents: documents for reranking
  473. :return:
  474. """
  475. keyword_table_handler = JiebaKeywordTableHandler()
  476. query_keywords = keyword_table_handler.extract_keywords(query, None)
  477. documents_keywords = []
  478. for document in documents:
  479. # get the document keywords
  480. document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
  481. document.metadata["keywords"] = document_keywords
  482. documents_keywords.append(document_keywords)
  483. # Counter query keywords(TF)
  484. query_keyword_counts = Counter(query_keywords)
  485. # total documents
  486. total_documents = len(documents)
  487. # calculate all documents' keywords IDF
  488. all_keywords = set()
  489. for document_keywords in documents_keywords:
  490. all_keywords.update(document_keywords)
  491. keyword_idf = {}
  492. for keyword in all_keywords:
  493. # calculate include query keywords' documents
  494. doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
  495. # IDF
  496. keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
  497. query_tfidf = {}
  498. for keyword, count in query_keyword_counts.items():
  499. tf = count
  500. idf = keyword_idf.get(keyword, 0)
  501. query_tfidf[keyword] = tf * idf
  502. # calculate all documents' TF-IDF
  503. documents_tfidf = []
  504. for document_keywords in documents_keywords:
  505. document_keyword_counts = Counter(document_keywords)
  506. document_tfidf = {}
  507. for keyword, count in document_keyword_counts.items():
  508. tf = count
  509. idf = keyword_idf.get(keyword, 0)
  510. document_tfidf[keyword] = tf * idf
  511. documents_tfidf.append(document_tfidf)
  512. def cosine_similarity(vec1, vec2):
  513. intersection = set(vec1.keys()) & set(vec2.keys())
  514. numerator = sum(vec1[x] * vec2[x] for x in intersection)
  515. sum1 = sum(vec1[x] ** 2 for x in vec1)
  516. sum2 = sum(vec2[x] ** 2 for x in vec2)
  517. denominator = math.sqrt(sum1) * math.sqrt(sum2)
  518. if not denominator:
  519. return 0.0
  520. else:
  521. return float(numerator) / denominator
  522. similarities = []
  523. for document_tfidf in documents_tfidf:
  524. similarity = cosine_similarity(query_tfidf, document_tfidf)
  525. similarities.append(similarity)
  526. for document, score in zip(documents, similarities):
  527. # format document
  528. document.metadata["score"] = score
  529. documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
  530. return documents[:top_k] if top_k else documents
  531. def calculate_vector_score(
  532. self, all_documents: list[Document], top_k: int, score_threshold: float
  533. ) -> list[Document]:
  534. filter_documents = []
  535. for document in all_documents:
  536. if score_threshold is None or document.metadata["score"] >= score_threshold:
  537. filter_documents.append(document)
  538. if not filter_documents:
  539. return []
  540. filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
  541. return filter_documents[:top_k] if top_k else filter_documents