dataset_retrieval.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717
  1. import math
  2. import threading
  3. from collections import Counter
  4. from typing import Any, Optional, cast
  5. from flask import Flask, current_app
  6. from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
  7. from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
  8. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  9. from core.entities.agent_entities import PlanningStrategy
  10. from core.memory.token_buffer_memory import TokenBufferMemory
  11. from core.model_manager import ModelInstance, ModelManager
  12. from core.model_runtime.entities.message_entities import PromptMessageTool
  13. from core.model_runtime.entities.model_entities import ModelFeature, ModelType
  14. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  15. from core.ops.entities.trace_entity import TraceTaskName
  16. from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
  17. from core.ops.utils import measure_time
  18. from core.rag.data_post_processor.data_post_processor import DataPostProcessor
  19. from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
  20. from core.rag.datasource.retrieval_service import RetrievalService
  21. from core.rag.entities.context_entities import DocumentContext
  22. from core.rag.models.document import Document
  23. from core.rag.rerank.rerank_type import RerankMode
  24. from core.rag.retrieval.retrieval_methods import RetrievalMethod
  25. from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
  26. from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
  27. from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
  28. from extensions.ext_database import db
  29. from models.dataset import Dataset, DatasetQuery, DocumentSegment
  30. from models.dataset import Document as DatasetDocument
  31. from services.external_knowledge_service import ExternalDatasetService
# Fallback retrieval configuration applied when a dataset has no
# retrieval_model of its own: plain semantic search, no reranking,
# top 2 results, no score threshold.
default_retrieval_model: dict[str, Any] = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}
class DatasetRetrieval:
    """Orchestrates knowledge-base (dataset) retrieval for app invocations.

    Supports routing a query to a single dataset via an LLM router
    (``single_retrieve``), parallel multi-dataset retrieval with optional
    reranking (``multiple_retrieve``), and exposing datasets as agent tools
    (``to_dataset_retriever_tool``).
    """

    def __init__(self, application_generate_entity=None):
        # Kept so _on_retrieval_end can reach the optional trace manager.
        self.application_generate_entity = application_generate_entity
  42. def retrieve(
  43. self,
  44. app_id: str,
  45. user_id: str,
  46. tenant_id: str,
  47. model_config: ModelConfigWithCredentialsEntity,
  48. config: DatasetEntity,
  49. query: str,
  50. invoke_from: InvokeFrom,
  51. show_retrieve_source: bool,
  52. hit_callback: DatasetIndexToolCallbackHandler,
  53. message_id: str,
  54. memory: Optional[TokenBufferMemory] = None,
  55. ) -> Optional[str]:
  56. """
  57. Retrieve dataset.
  58. :param app_id: app_id
  59. :param user_id: user_id
  60. :param tenant_id: tenant id
  61. :param model_config: model config
  62. :param config: dataset config
  63. :param query: query
  64. :param invoke_from: invoke from
  65. :param show_retrieve_source: show retrieve source
  66. :param hit_callback: hit callback
  67. :param message_id: message id
  68. :param memory: memory
  69. :return:
  70. """
  71. dataset_ids = config.dataset_ids
  72. if len(dataset_ids) == 0:
  73. return None
  74. retrieve_config = config.retrieve_config
  75. # check model is support tool calling
  76. model_type_instance = model_config.provider_model_bundle.model_type_instance
  77. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  78. model_manager = ModelManager()
  79. model_instance = model_manager.get_model_instance(
  80. tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
  81. )
  82. # get model schema
  83. model_schema = model_type_instance.get_model_schema(
  84. model=model_config.model, credentials=model_config.credentials
  85. )
  86. if not model_schema:
  87. return None
  88. planning_strategy = PlanningStrategy.REACT_ROUTER
  89. features = model_schema.features
  90. if features:
  91. if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
  92. planning_strategy = PlanningStrategy.ROUTER
  93. available_datasets = []
  94. for dataset_id in dataset_ids:
  95. # get dataset from dataset id
  96. dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  97. # pass if dataset is not available
  98. if not dataset:
  99. continue
  100. # pass if dataset is not available
  101. if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
  102. continue
  103. available_datasets.append(dataset)
  104. all_documents = []
  105. user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
  106. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  107. all_documents = self.single_retrieve(
  108. app_id,
  109. tenant_id,
  110. user_id,
  111. user_from,
  112. available_datasets,
  113. query,
  114. model_instance,
  115. model_config,
  116. planning_strategy,
  117. message_id,
  118. )
  119. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  120. all_documents = self.multiple_retrieve(
  121. app_id,
  122. tenant_id,
  123. user_id,
  124. user_from,
  125. available_datasets,
  126. query,
  127. retrieve_config.top_k or 0,
  128. retrieve_config.score_threshold or 0,
  129. retrieve_config.rerank_mode or "reranking_model",
  130. retrieve_config.reranking_model,
  131. retrieve_config.weights,
  132. retrieve_config.reranking_enabled or True,
  133. message_id,
  134. )
  135. dify_documents = [item for item in all_documents if item.provider == "dify"]
  136. external_documents = [item for item in all_documents if item.provider == "external"]
  137. document_context_list = []
  138. retrieval_resource_list = []
  139. # deal with external documents
  140. for item in external_documents:
  141. document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
  142. source = {
  143. "dataset_id": item.metadata.get("dataset_id"),
  144. "dataset_name": item.metadata.get("dataset_name"),
  145. "document_name": item.metadata.get("title"),
  146. "data_source_type": "external",
  147. "retriever_from": invoke_from.to_source(),
  148. "score": item.metadata.get("score"),
  149. "content": item.page_content,
  150. }
  151. retrieval_resource_list.append(source)
  152. # deal with dify documents
  153. if dify_documents:
  154. records = RetrievalService.format_retrieval_documents(dify_documents)
  155. if records:
  156. for record in records:
  157. segment = record.segment
  158. if segment.answer:
  159. document_context_list.append(
  160. DocumentContext(
  161. content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
  162. score=record.score,
  163. )
  164. )
  165. else:
  166. document_context_list.append(
  167. DocumentContext(
  168. content=segment.get_sign_content(),
  169. score=record.score,
  170. )
  171. )
  172. if show_retrieve_source:
  173. for record in records:
  174. segment = record.segment
  175. dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
  176. document = DatasetDocument.query.filter(
  177. DatasetDocument.id == segment.document_id,
  178. DatasetDocument.enabled == True,
  179. DatasetDocument.archived == False,
  180. ).first()
  181. if dataset and document:
  182. source = {
  183. "dataset_id": dataset.id,
  184. "dataset_name": dataset.name,
  185. "document_id": document.id,
  186. "document_name": document.name,
  187. "data_source_type": document.data_source_type,
  188. "segment_id": segment.id,
  189. "retriever_from": invoke_from.to_source(),
  190. "score": record.score or 0.0,
  191. "doc_metadata": document.doc_metadata,
  192. }
  193. if invoke_from.to_source() == "dev":
  194. source["hit_count"] = segment.hit_count
  195. source["word_count"] = segment.word_count
  196. source["segment_position"] = segment.position
  197. source["index_node_hash"] = segment.index_node_hash
  198. if segment.answer:
  199. source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
  200. else:
  201. source["content"] = segment.content
  202. retrieval_resource_list.append(source)
  203. if hit_callback and retrieval_resource_list:
  204. retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
  205. for position, item in enumerate(retrieval_resource_list, start=1):
  206. item["position"] = position
  207. hit_callback.return_retriever_resource_info(retrieval_resource_list)
  208. if document_context_list:
  209. document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
  210. return str("\n".join([document_context.content for document_context in document_context_list]))
  211. return ""
    def single_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        model_instance: ModelInstance,
        model_config: ModelConfigWithCredentialsEntity,
        planning_strategy: PlanningStrategy,
        message_id: Optional[str] = None,
    ):
        """
        Route the query to ONE dataset via an LLM router, then retrieve from it.

        Each dataset is presented to the router as a zero-argument tool whose
        name is the dataset id; the router's choice decides where to search.

        :param app_id: app id (used for query logging)
        :param tenant_id: tenant id
        :param user_id: id of the requesting user
        :param user_from: "account" or "end_user"
        :param available_datasets: candidate datasets
        :param query: user query
        :param model_instance: LLM instance used by the router
        :param model_config: model config used by the router
        :param planning_strategy: ROUTER (function call) or REACT_ROUTER
        :param message_id: message id for retrieval-end tracing
        :return: list of retrieved documents, or [] when no dataset was chosen
        """
        tools = []
        for dataset in available_datasets:
            description = dataset.description
            if not description:
                description = "useful for when you want to answer queries about the " + dataset.name
            # Router prompts work best with single-line descriptions.
            description = description.replace("\n", "").replace("\r", "")
            message_tool = PromptMessageTool(
                name=dataset.id,
                description=description,
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            )
            tools.append(message_tool)
        dataset_id = None
        if planning_strategy == PlanningStrategy.REACT_ROUTER:
            react_multi_dataset_router = ReactMultiDatasetRouter()
            dataset_id = react_multi_dataset_router.invoke(
                query, tools, model_config, model_instance, user_id, tenant_id
            )
        elif planning_strategy == PlanningStrategy.ROUTER:
            function_call_router = FunctionCallMultiDatasetRouter()
            dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
        if dataset_id:
            # get retrieval model config
            dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
            if dataset:
                results = []
                if dataset.provider == "external":
                    # External knowledge base: delegate to the external service.
                    external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                        tenant_id=dataset.tenant_id,
                        dataset_id=dataset_id,
                        query=query,
                        external_retrieval_parameters=dataset.retrieval_model,
                    )
                    for external_document in external_documents:
                        document = Document(
                            page_content=external_document.get("content"),
                            metadata=external_document.get("metadata"),
                            provider="external",
                        )
                        if document.metadata is not None:
                            document.metadata["score"] = external_document.get("score")
                            document.metadata["title"] = external_document.get("title")
                            document.metadata["dataset_id"] = dataset_id
                            document.metadata["dataset_name"] = dataset.name
                        results.append(document)
                else:
                    # Fall back to the module-level default when the dataset
                    # has no retrieval model configured.
                    retrieval_model_config = dataset.retrieval_model or default_retrieval_model
                    # get top k
                    top_k = retrieval_model_config["top_k"]
                    # get retrieval method: economy datasets only support keyword search
                    if dataset.indexing_technique == "economy":
                        retrieval_method = "keyword_search"
                    else:
                        retrieval_method = retrieval_model_config["search_method"]
                    # get reranking model
                    reranking_model = (
                        retrieval_model_config["reranking_model"]
                        if retrieval_model_config["reranking_enable"]
                        else None
                    )
                    # get score threshold
                    score_threshold = 0.0
                    score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
                    if score_threshold_enabled:
                        score_threshold = retrieval_model_config.get("score_threshold", 0.0)
                    with measure_time() as timer:
                        results = RetrievalService.retrieve(
                            retrieval_method=retrieval_method,
                            dataset_id=dataset.id,
                            query=query,
                            top_k=top_k,
                            score_threshold=score_threshold,
                            reranking_model=reranking_model,
                            reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
                            weights=retrieval_model_config.get("weights", None),
                        )
                self._on_query(query, [dataset_id], app_id, user_from, user_id)
                if results:
                    # NOTE(review): `timer` is only bound in the non-external branch;
                    # an external dataset returning results would raise NameError
                    # here — confirm whether that path is reachable.
                    self._on_retrieval_end(results, message_id, timer)
                return results
        return []
    def multiple_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        top_k: int,
        score_threshold: float,
        reranking_mode: str,
        reranking_model: Optional[dict] = None,
        weights: Optional[dict[str, Any]] = None,
        reranking_enable: bool = True,
        message_id: Optional[str] = None,
    ):
        """
        Retrieve from several datasets in parallel and merge the results.

        One worker thread per dataset runs ``self._retriever``; all threads
        append into the shared ``all_documents`` list. Afterwards the merged
        list is either reranked (when ``reranking_enable``) or re-scored
        locally by keyword TF-IDF / vector score.

        :param app_id: app id (used for query logging)
        :param tenant_id: tenant id (used by the rerank post-processor)
        :param user_id: id of the requesting user
        :param user_from: "account" or "end_user"
        :param available_datasets: datasets to search
        :param query: user query
        :param top_k: maximum number of documents to return
        :param score_threshold: minimum score applied during post-processing
        :param reranking_mode: "reranking_model" or "weighted_score"
        :param reranking_model: reranking model configuration
        :param weights: weighted-score settings (mutated below with embedding info)
        :param reranking_enable: whether to run the rerank post-processor
        :param message_id: message id for retrieval-end tracing
        :return: merged (and possibly reranked) list of documents
        """
        if not available_datasets:
            return []
        threads = []
        all_documents: list[Document] = []
        dataset_ids = [dataset.id for dataset in available_datasets]
        # Mixing indexing techniques is only allowed when a reranking model
        # will normalize the scores afterwards.
        index_type_check = all(
            item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
        )
        if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
            raise ValueError(
                "The configured knowledge base list have different indexing technique, please set reranking model."
            )
        index_type = available_datasets[0].indexing_technique
        if index_type == "high_quality":
            # Weighted-score reranking requires one shared embedding model
            # across all datasets.
            embedding_model_check = all(
                item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
            )
            embedding_model_provider_check = all(
                item.embedding_model_provider == available_datasets[0].embedding_model_provider
                for item in available_datasets
            )
            if (
                reranking_enable
                and reranking_mode == "weighted_score"
                and (not embedding_model_check or not embedding_model_provider_check)
            ):
                raise ValueError(
                    "The configured knowledge base list have different embedding model, please set reranking model."
                )
            if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
                if weights is not None:
                    # Inject the shared embedding model so the weighted-score
                    # reranker can embed the query consistently.
                    weights["vector_setting"]["embedding_provider_name"] = available_datasets[
                        0
                    ].embedding_model_provider
                    weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
        for dataset in available_datasets:
            # NOTE(review): index_type is overwritten on every iteration, so the
            # fallback scoring below uses the LAST dataset's indexing technique —
            # confirm this is intended.
            index_type = dataset.indexing_technique
            retrieval_thread = threading.Thread(
                target=self._retriever,
                kwargs={
                    # Worker threads need the real app object to push a context.
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "dataset_id": dataset.id,
                    "query": query,
                    "top_k": top_k,
                    "all_documents": all_documents,
                },
            )
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()
        with measure_time() as timer:
            if reranking_enable:
                # do rerank for searched documents
                data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
                all_documents = data_post_processor.invoke(
                    query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
                )
            else:
                if index_type == "economy":
                    all_documents = self.calculate_keyword_score(query, all_documents, top_k)
                elif index_type == "high_quality":
                    all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
        self._on_query(query, dataset_ids, app_id, user_from, user_id)
        if all_documents:
            self._on_retrieval_end(all_documents, message_id, timer)
        return all_documents
  393. def _on_retrieval_end(
  394. self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
  395. ) -> None:
  396. """Handle retrieval end."""
  397. dify_documents = [document for document in documents if document.provider == "dify"]
  398. for document in dify_documents:
  399. if document.metadata is not None:
  400. query = db.session.query(DocumentSegment).filter(
  401. DocumentSegment.index_node_id == document.metadata["doc_id"]
  402. )
  403. # if 'dataset_id' in document.metadata:
  404. if "dataset_id" in document.metadata:
  405. query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
  406. # add hit count to document segment
  407. query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
  408. db.session.commit()
  409. # get tracing instance
  410. trace_manager: TraceQueueManager | None = (
  411. self.application_generate_entity.trace_manager if self.application_generate_entity else None
  412. )
  413. if trace_manager:
  414. trace_manager.add_trace_task(
  415. TraceTask(
  416. TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
  417. )
  418. )
  419. def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
  420. """
  421. Handle query.
  422. """
  423. if not query:
  424. return
  425. dataset_queries = []
  426. for dataset_id in dataset_ids:
  427. dataset_query = DatasetQuery(
  428. dataset_id=dataset_id,
  429. content=query,
  430. source="app",
  431. source_app_id=app_id,
  432. created_by_role=user_from,
  433. created_by=user_id,
  434. )
  435. dataset_queries.append(dataset_query)
  436. if dataset_queries:
  437. db.session.add_all(dataset_queries)
  438. db.session.commit()
    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
        """
        Thread worker: retrieve documents from one dataset and append them to
        the shared ``all_documents`` list.

        Runs inside a fresh Flask app context because it executes on a worker
        thread spawned by ``multiple_retrieve``.

        :param flask_app: Flask app used to push an app context for DB access
        :param dataset_id: id of the dataset to query
        :param query: user query
        :param top_k: result cap for the economy path; also gates the semantic path
        :param all_documents: shared output list results are appended to
        """
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
            if not dataset:
                return []
            if dataset.provider == "external":
                # External knowledge base: delegate to the external service and
                # tag each document with score/title/dataset info.
                external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset_id,
                    query=query,
                    external_retrieval_parameters=dataset.retrieval_model,
                )
                for external_document in external_documents:
                    document = Document(
                        page_content=external_document.get("content"),
                        metadata=external_document.get("metadata"),
                        provider="external",
                    )
                    if document.metadata is not None:
                        document.metadata["score"] = external_document.get("score")
                        document.metadata["title"] = external_document.get("title")
                        document.metadata["dataset_id"] = dataset_id
                        document.metadata["dataset_name"] = dataset.name
                    all_documents.append(document)
            else:
                # get retrieval model; if none is configured, use the default
                retrieval_model = dataset.retrieval_model or default_retrieval_model
                if dataset.indexing_technique == "economy":
                    # use keyword table query
                    documents = RetrievalService.retrieve(
                        retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
                    )
                    if documents:
                        all_documents.extend(documents)
                else:
                    if top_k > 0:
                        # retrieval source
                        # NOTE(review): this path uses the dataset's own top_k
                        # (default 2), not the `top_k` argument — confirm intended.
                        documents = RetrievalService.retrieve(
                            retrieval_method=retrieval_model["search_method"],
                            dataset_id=dataset.id,
                            query=query,
                            top_k=retrieval_model.get("top_k") or 2,
                            score_threshold=retrieval_model.get("score_threshold", 0.0)
                            if retrieval_model["score_threshold_enabled"]
                            else 0.0,
                            reranking_model=retrieval_model.get("reranking_model", None)
                            if retrieval_model["reranking_enable"]
                            else None,
                            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                            weights=retrieval_model.get("weights", None),
                        )
                        all_documents.extend(documents)
  491. def to_dataset_retriever_tool(
  492. self,
  493. tenant_id: str,
  494. dataset_ids: list[str],
  495. retrieve_config: DatasetRetrieveConfigEntity,
  496. return_resource: bool,
  497. invoke_from: InvokeFrom,
  498. hit_callback: DatasetIndexToolCallbackHandler,
  499. ) -> Optional[list[DatasetRetrieverBaseTool]]:
  500. """
  501. A dataset tool is a tool that can be used to retrieve information from a dataset
  502. :param tenant_id: tenant id
  503. :param dataset_ids: dataset ids
  504. :param retrieve_config: retrieve config
  505. :param return_resource: return resource
  506. :param invoke_from: invoke from
  507. :param hit_callback: hit callback
  508. """
  509. tools = []
  510. available_datasets = []
  511. for dataset_id in dataset_ids:
  512. # get dataset from dataset id
  513. dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  514. # pass if dataset is not available
  515. if not dataset:
  516. continue
  517. # pass if dataset is not available
  518. if dataset and dataset.provider != "external" and dataset.available_document_count == 0:
  519. continue
  520. available_datasets.append(dataset)
  521. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  522. # get retrieval model config
  523. default_retrieval_model = {
  524. "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
  525. "reranking_enable": False,
  526. "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
  527. "top_k": 2,
  528. "score_threshold_enabled": False,
  529. }
  530. for dataset in available_datasets:
  531. retrieval_model_config = dataset.retrieval_model or default_retrieval_model
  532. # get top k
  533. top_k = retrieval_model_config["top_k"]
  534. # get score threshold
  535. score_threshold = None
  536. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  537. if score_threshold_enabled:
  538. score_threshold = retrieval_model_config.get("score_threshold")
  539. from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
  540. tool = DatasetRetrieverTool.from_dataset(
  541. dataset=dataset,
  542. top_k=top_k,
  543. score_threshold=score_threshold,
  544. hit_callbacks=[hit_callback],
  545. return_resource=return_resource,
  546. retriever_from=invoke_from.to_source(),
  547. )
  548. tools.append(tool)
  549. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  550. from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
  551. if retrieve_config.reranking_model is None:
  552. raise ValueError("Reranking model is required for multiple retrieval")
  553. tool = DatasetMultiRetrieverTool.from_dataset(
  554. dataset_ids=[dataset.id for dataset in available_datasets],
  555. tenant_id=tenant_id,
  556. top_k=retrieve_config.top_k or 2,
  557. score_threshold=retrieve_config.score_threshold,
  558. hit_callbacks=[hit_callback],
  559. return_resource=return_resource,
  560. retriever_from=invoke_from.to_source(),
  561. reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
  562. reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
  563. )
  564. tools.append(tool)
  565. return tools
  566. def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
  567. """
  568. Calculate keywords scores
  569. :param query: search query
  570. :param documents: documents for reranking
  571. :return:
  572. """
  573. keyword_table_handler = JiebaKeywordTableHandler()
  574. query_keywords = keyword_table_handler.extract_keywords(query, None)
  575. documents_keywords = []
  576. for document in documents:
  577. if document.metadata is not None:
  578. # get the document keywords
  579. document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
  580. document.metadata["keywords"] = document_keywords
  581. documents_keywords.append(document_keywords)
  582. # Counter query keywords(TF)
  583. query_keyword_counts = Counter(query_keywords)
  584. # total documents
  585. total_documents = len(documents)
  586. # calculate all documents' keywords IDF
  587. all_keywords = set()
  588. for document_keywords in documents_keywords:
  589. all_keywords.update(document_keywords)
  590. keyword_idf = {}
  591. for keyword in all_keywords:
  592. # calculate include query keywords' documents
  593. doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
  594. # IDF
  595. keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
  596. query_tfidf = {}
  597. for keyword, count in query_keyword_counts.items():
  598. tf = count
  599. idf = keyword_idf.get(keyword, 0)
  600. query_tfidf[keyword] = tf * idf
  601. # calculate all documents' TF-IDF
  602. documents_tfidf = []
  603. for document_keywords in documents_keywords:
  604. document_keyword_counts = Counter(document_keywords)
  605. document_tfidf = {}
  606. for keyword, count in document_keyword_counts.items():
  607. tf = count
  608. idf = keyword_idf.get(keyword, 0)
  609. document_tfidf[keyword] = tf * idf
  610. documents_tfidf.append(document_tfidf)
  611. def cosine_similarity(vec1, vec2):
  612. intersection = set(vec1.keys()) & set(vec2.keys())
  613. numerator = sum(vec1[x] * vec2[x] for x in intersection)
  614. sum1 = sum(vec1[x] ** 2 for x in vec1)
  615. sum2 = sum(vec2[x] ** 2 for x in vec2)
  616. denominator = math.sqrt(sum1) * math.sqrt(sum2)
  617. if not denominator:
  618. return 0.0
  619. else:
  620. return float(numerator) / denominator
  621. similarities = []
  622. for document_tfidf in documents_tfidf:
  623. similarity = cosine_similarity(query_tfidf, document_tfidf)
  624. similarities.append(similarity)
  625. for document, score in zip(documents, similarities):
  626. # format document
  627. if document.metadata is not None:
  628. document.metadata["score"] = score
  629. documents = sorted(documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True)
  630. return documents[:top_k] if top_k else documents
  631. def calculate_vector_score(
  632. self, all_documents: list[Document], top_k: int, score_threshold: float
  633. ) -> list[Document]:
  634. filter_documents = []
  635. for document in all_documents:
  636. if score_threshold is None or (document.metadata and document.metadata.get("score", 0) >= score_threshold):
  637. filter_documents.append(document)
  638. if not filter_documents:
  639. return []
  640. filter_documents = sorted(
  641. filter_documents, key=lambda x: x.metadata.get("score", 0) if x.metadata else 0, reverse=True
  642. )
  643. return filter_documents[:top_k] if top_k else filter_documents