dataset_retrieval.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. import math
  2. import threading
  3. from collections import Counter
  4. from typing import Optional, cast
  5. from flask import Flask, current_app
  6. from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
  7. from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
  8. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  9. from core.entities.agent_entities import PlanningStrategy
  10. from core.memory.token_buffer_memory import TokenBufferMemory
  11. from core.model_manager import ModelInstance, ModelManager
  12. from core.model_runtime.entities.message_entities import PromptMessageTool
  13. from core.model_runtime.entities.model_entities import ModelFeature, ModelType
  14. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  15. from core.ops.entities.trace_entity import TraceTaskName
  16. from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
  17. from core.ops.utils import measure_time
  18. from core.rag.data_post_processor.data_post_processor import DataPostProcessor
  19. from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
  20. from core.rag.datasource.retrieval_service import RetrievalService
  21. from core.rag.entities.context_entities import DocumentContext
  22. from core.rag.models.document import Document
  23. from core.rag.retrieval.retrieval_methods import RetrievalMethod
  24. from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
  25. from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
  26. from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
  27. from extensions.ext_database import db
  28. from models.dataset import Dataset, DatasetQuery, DocumentSegment
  29. from models.dataset import Document as DatasetDocument
  30. from services.external_knowledge_service import ExternalDatasetService
# Fallback retrieval configuration applied when a dataset has no
# `retrieval_model` of its own persisted (see usages in `single_retrieve`
# and `_retriever`).
default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 2,
    "score_threshold_enabled": False,
}
class DatasetRetrieval:
    """Orchestrates dataset (knowledge base) retrieval for an app invocation.

    Supports single-dataset routing via an LLM router and multi-dataset
    parallel retrieval with optional reranking, and reports query/trace
    bookkeeping to the database and trace manager.
    """

    def __init__(self, application_generate_entity=None):
        # Entity describing the current app invocation; only its optional
        # `trace_manager` attribute is read here (see `_on_retrieval_end`).
        self.application_generate_entity = application_generate_entity
    def retrieve(
        self,
        app_id: str,
        user_id: str,
        tenant_id: str,
        model_config: ModelConfigWithCredentialsEntity,
        config: DatasetEntity,
        query: str,
        invoke_from: InvokeFrom,
        show_retrieve_source: bool,
        hit_callback: DatasetIndexToolCallbackHandler,
        message_id: str,
        memory: Optional[TokenBufferMemory] = None,
    ) -> Optional[str]:
        """
        Retrieve context for *query* from the configured datasets.

        Dispatches to `single_retrieve` or `multiple_retrieve` depending on the
        configured strategy, then assembles the retrieved documents (both
        external-provider documents and local "dify" segments) into a single
        newline-joined context string, optionally reporting citation sources
        through *hit_callback*.

        :param app_id: app id
        :param user_id: user id
        :param tenant_id: tenant id
        :param model_config: model config (used to pick the routing strategy)
        :param config: dataset config (dataset ids + retrieve config)
        :param query: user query
        :param invoke_from: invoke source (affects user_from and citations)
        :param show_retrieve_source: whether to build citation source entries
        :param hit_callback: callback receiving citation resource info
        :param message_id: message id (for retrieval-end tracing)
        :param memory: conversation memory (currently unused here)
        :return: joined context string, "" when nothing retrieved,
            or None when there are no datasets / no usable model schema
        """
        dataset_ids = config.dataset_ids
        if len(dataset_ids) == 0:
            return None
        retrieve_config = config.retrieve_config
        # check model is support tool calling
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
        )
        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model_config.model, credentials=model_config.credentials
        )
        if not model_schema:
            return None
        # models with (multi-)tool-call support get the function-call router;
        # everything else falls back to the ReAct router
        planning_strategy = PlanningStrategy.REACT_ROUTER
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.ROUTER
        available_datasets = []
        for dataset_id in dataset_ids:
            # get dataset from dataset id
            dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
            # pass if dataset is not available
            if not dataset:
                continue
            # skip empty internal datasets (external ones have no local documents)
            if dataset and dataset.available_document_count == 0 and dataset.provider != "external":
                continue
            available_datasets.append(dataset)
        all_documents = []
        user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
        if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            all_documents = self.single_retrieve(
                app_id,
                tenant_id,
                user_id,
                user_from,
                available_datasets,
                query,
                model_instance,
                model_config,
                planning_strategy,
                message_id,
            )
        elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            all_documents = self.multiple_retrieve(
                app_id,
                tenant_id,
                user_id,
                user_from,
                available_datasets,
                query,
                retrieve_config.top_k,
                retrieve_config.score_threshold,
                retrieve_config.rerank_mode,
                retrieve_config.reranking_model,
                retrieve_config.weights,
                retrieve_config.reranking_enabled,
                message_id,
            )
        dify_documents = [item for item in all_documents if item.provider == "dify"]
        external_documents = [item for item in all_documents if item.provider == "external"]
        document_context_list = []
        retrieval_resource_list = []
        # deal with external documents: content and citation come straight
        # from the document metadata filled by the external knowledge service
        for item in external_documents:
            document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
            source = {
                "dataset_id": item.metadata.get("dataset_id"),
                "dataset_name": item.metadata.get("dataset_name"),
                "document_name": item.metadata.get("title"),
                "data_source_type": "external",
                "retriever_from": invoke_from.to_source(),
                "score": item.metadata.get("score"),
                "content": item.page_content,
            }
            retrieval_resource_list.append(source)
        document_score_list = {}
        # deal with dify documents: resolve their segments from the database
        if dify_documents:
            for item in dify_documents:
                if item.metadata.get("score"):
                    document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
            index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
            segments = DocumentSegment.query.filter(
                DocumentSegment.dataset_id.in_(dataset_ids),
                DocumentSegment.status == "completed",
                DocumentSegment.enabled == True,
                DocumentSegment.index_node_id.in_(index_node_ids),
            ).all()
            if segments:
                # keep segments in the same order the documents were retrieved in
                index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
                sorted_segments = sorted(
                    segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
                )
                for segment in sorted_segments:
                    if segment.answer:
                        # Q&A-style segment: include both question and answer
                        document_context_list.append(
                            DocumentContext(
                                content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
                                score=document_score_list.get(segment.index_node_id, None),
                            )
                        )
                    else:
                        document_context_list.append(
                            DocumentContext(
                                content=segment.get_sign_content(),
                                score=document_score_list.get(segment.index_node_id, None),
                            )
                        )
                if show_retrieve_source:
                    # build one citation entry per segment whose dataset and
                    # document are still enabled/unarchived
                    for segment in sorted_segments:
                        dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
                        document = DatasetDocument.query.filter(
                            DatasetDocument.id == segment.document_id,
                            DatasetDocument.enabled == True,
                            DatasetDocument.archived == False,
                        ).first()
                        if dataset and document:
                            source = {
                                "dataset_id": dataset.id,
                                "dataset_name": dataset.name,
                                "document_id": document.id,
                                "document_name": document.name,
                                "data_source_type": document.data_source_type,
                                "segment_id": segment.id,
                                "retriever_from": invoke_from.to_source(),
                                "score": document_score_list.get(segment.index_node_id, 0.0),
                            }
                            # extra debug fields only for dev-console invocations
                            if invoke_from.to_source() == "dev":
                                source["hit_count"] = segment.hit_count
                                source["word_count"] = segment.word_count
                                source["segment_position"] = segment.position
                                source["index_node_hash"] = segment.index_node_hash
                            if segment.answer:
                                source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
                            else:
                                source["content"] = segment.content
                            retrieval_resource_list.append(source)
        if hit_callback and retrieval_resource_list:
            # rank citations by score and number their positions from 1
            retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
            for position, item in enumerate(retrieval_resource_list, start=1):
                item["position"] = position
            hit_callback.return_retriever_resource_info(retrieval_resource_list)
        if document_context_list:
            document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
            return str("\n".join([document_context.content for document_context in document_context_list]))
        return ""
  222. def single_retrieve(
  223. self,
  224. app_id: str,
  225. tenant_id: str,
  226. user_id: str,
  227. user_from: str,
  228. available_datasets: list,
  229. query: str,
  230. model_instance: ModelInstance,
  231. model_config: ModelConfigWithCredentialsEntity,
  232. planning_strategy: PlanningStrategy,
  233. message_id: Optional[str] = None,
  234. ):
  235. tools = []
  236. for dataset in available_datasets:
  237. description = dataset.description
  238. if not description:
  239. description = "useful for when you want to answer queries about the " + dataset.name
  240. description = description.replace("\n", "").replace("\r", "")
  241. message_tool = PromptMessageTool(
  242. name=dataset.id,
  243. description=description,
  244. parameters={
  245. "type": "object",
  246. "properties": {},
  247. "required": [],
  248. },
  249. )
  250. tools.append(message_tool)
  251. dataset_id = None
  252. if planning_strategy == PlanningStrategy.REACT_ROUTER:
  253. react_multi_dataset_router = ReactMultiDatasetRouter()
  254. dataset_id = react_multi_dataset_router.invoke(
  255. query, tools, model_config, model_instance, user_id, tenant_id
  256. )
  257. elif planning_strategy == PlanningStrategy.ROUTER:
  258. function_call_router = FunctionCallMultiDatasetRouter()
  259. dataset_id = function_call_router.invoke(query, tools, model_config, model_instance)
  260. if dataset_id:
  261. # get retrieval model config
  262. dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
  263. if dataset:
  264. results = []
  265. if dataset.provider == "external":
  266. external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
  267. tenant_id=dataset.tenant_id,
  268. dataset_id=dataset_id,
  269. query=query,
  270. external_retrieval_parameters=dataset.retrieval_model,
  271. )
  272. for external_document in external_documents:
  273. document = Document(
  274. page_content=external_document.get("content"),
  275. metadata=external_document.get("metadata"),
  276. provider="external",
  277. )
  278. document.metadata["score"] = external_document.get("score")
  279. document.metadata["title"] = external_document.get("title")
  280. document.metadata["dataset_id"] = dataset_id
  281. document.metadata["dataset_name"] = dataset.name
  282. results.append(document)
  283. else:
  284. retrieval_model_config = dataset.retrieval_model or default_retrieval_model
  285. # get top k
  286. top_k = retrieval_model_config["top_k"]
  287. # get retrieval method
  288. if dataset.indexing_technique == "economy":
  289. retrieval_method = "keyword_search"
  290. else:
  291. retrieval_method = retrieval_model_config["search_method"]
  292. # get reranking model
  293. reranking_model = (
  294. retrieval_model_config["reranking_model"]
  295. if retrieval_model_config["reranking_enable"]
  296. else None
  297. )
  298. # get score threshold
  299. score_threshold = 0.0
  300. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  301. if score_threshold_enabled:
  302. score_threshold = retrieval_model_config.get("score_threshold")
  303. with measure_time() as timer:
  304. results = RetrievalService.retrieve(
  305. retrieval_method=retrieval_method,
  306. dataset_id=dataset.id,
  307. query=query,
  308. top_k=top_k,
  309. score_threshold=score_threshold,
  310. reranking_model=reranking_model,
  311. reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
  312. weights=retrieval_model_config.get("weights", None),
  313. )
  314. self._on_query(query, [dataset_id], app_id, user_from, user_id)
  315. if results:
  316. self._on_retrieval_end(results, message_id, timer)
  317. return results
  318. return []
    def multiple_retrieve(
        self,
        app_id: str,
        tenant_id: str,
        user_id: str,
        user_from: str,
        available_datasets: list,
        query: str,
        top_k: int,
        score_threshold: float,
        reranking_mode: str,
        reranking_model: Optional[dict] = None,
        weights: Optional[dict] = None,
        reranking_enable: bool = True,
        message_id: Optional[str] = None,
    ):
        """Query every available dataset in parallel and merge the results.

        One worker thread per dataset runs `_retriever`, appending into the
        shared *all_documents* list; results are then reranked (or re-scored
        locally when reranking is disabled) before query/trace bookkeeping.

        :param app_id: app id (recorded with the dataset queries)
        :param tenant_id: tenant id (used by the reranking post-processor)
        :param user_id: user id
        :param user_from: "account" or "end_user"
        :param available_datasets: candidate Dataset rows
        :param query: user query
        :param top_k: max documents to return
        :param score_threshold: minimum score for filtering/reranking
        :param reranking_mode: reranking mode passed to DataPostProcessor
        :param reranking_model: reranking model config (optional)
        :param weights: weighted-score config (optional)
        :param reranking_enable: whether to rerank merged results
        :param message_id: message id for retrieval-end tracing
        :return: merged (and possibly reranked) document list
        """
        threads = []
        all_documents = []
        dataset_ids = [dataset.id for dataset in available_datasets]
        index_type = None
        for dataset in available_datasets:
            # NOTE(review): index_type ends up as the indexing technique of the
            # *last* dataset in the loop; with mixed-technique datasets the
            # fallback scoring below uses only that one's path — confirm intended.
            index_type = dataset.indexing_technique
            retrieval_thread = threading.Thread(
                target=self._retriever,
                kwargs={
                    # pass the real app object so the worker can push an app context
                    "flask_app": current_app._get_current_object(),
                    "dataset_id": dataset.id,
                    "query": query,
                    "top_k": top_k,
                    "all_documents": all_documents,
                },
            )
            threads.append(retrieval_thread)
            retrieval_thread.start()
        for thread in threads:
            thread.join()
        with measure_time() as timer:
            if reranking_enable:
                # do rerank for searched documents
                data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
                all_documents = data_post_processor.invoke(
                    query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
                )
            else:
                # no reranker: fall back to local TF-IDF / vector-score ranking
                if index_type == "economy":
                    all_documents = self.calculate_keyword_score(query, all_documents, top_k)
                elif index_type == "high_quality":
                    all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
        self._on_query(query, dataset_ids, app_id, user_from, user_id)
        if all_documents:
            self._on_retrieval_end(all_documents, message_id, timer)
        return all_documents
    def _on_retrieval_end(
        self, documents: list[Document], message_id: Optional[str] = None, timer: Optional[dict] = None
    ) -> None:
        """Handle retrieval end: bump segment hit counts and emit a trace task.

        :param documents: retrieved documents; only provider == "dify" ones
            map to local DocumentSegment rows
        :param message_id: message the retrieval belongs to (for tracing)
        :param timer: timing captured via measure_time(), forwarded to the trace
        """
        dify_documents = [document for document in documents if document.provider == "dify"]
        for document in dify_documents:
            query = db.session.query(DocumentSegment).filter(
                DocumentSegment.index_node_id == document.metadata["doc_id"]
            )
            # narrow to the owning dataset when the document carries one
            if "dataset_id" in document.metadata:
                query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
            # bulk UPDATE of hit_count; no session sync needed since the rows
            # are not read back in this request
            query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
            db.session.commit()
        # get tracing instance (only present when the app invocation carries one)
        trace_manager: TraceQueueManager | None = (
            self.application_generate_entity.trace_manager if self.application_generate_entity else None
        )
        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
                )
            )
  396. def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str) -> None:
  397. """
  398. Handle query.
  399. """
  400. if not query:
  401. return
  402. dataset_queries = []
  403. for dataset_id in dataset_ids:
  404. dataset_query = DatasetQuery(
  405. dataset_id=dataset_id,
  406. content=query,
  407. source="app",
  408. source_app_id=app_id,
  409. created_by_role=user_from,
  410. created_by=user_id,
  411. )
  412. dataset_queries.append(dataset_query)
  413. if dataset_queries:
  414. db.session.add_all(dataset_queries)
  415. db.session.commit()
    def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
        """Thread worker for `multiple_retrieve`: retrieve one dataset's documents.

        Results are appended/extended into the shared *all_documents* list.

        :param flask_app: the Flask app object (a fresh app context is pushed
            because this runs in its own thread)
        :param dataset_id: dataset to query
        :param query: user query
        :param top_k: max documents for keyword search
        :param all_documents: shared output list, mutated in place
        """
        with flask_app.app_context():
            dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
            if not dataset:
                return []
            if dataset.provider == "external":
                external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
                    tenant_id=dataset.tenant_id,
                    dataset_id=dataset_id,
                    query=query,
                    external_retrieval_parameters=dataset.retrieval_model,
                )
                for external_document in external_documents:
                    document = Document(
                        page_content=external_document.get("content"),
                        metadata=external_document.get("metadata"),
                        provider="external",
                    )
                    document.metadata["score"] = external_document.get("score")
                    document.metadata["title"] = external_document.get("title")
                    document.metadata["dataset_id"] = dataset_id
                    document.metadata["dataset_name"] = dataset.name
                    all_documents.append(document)
            else:
                # get retrieval model , if the model is not setting , using default
                retrieval_model = dataset.retrieval_model or default_retrieval_model
                if dataset.indexing_technique == "economy":
                    # use keyword table query
                    documents = RetrievalService.retrieve(
                        retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
                    )
                    if documents:
                        all_documents.extend(documents)
                else:
                    if top_k > 0:
                        # retrieval source; thresholds/rerankers apply only when
                        # enabled in the dataset's retrieval model
                        documents = RetrievalService.retrieve(
                            retrieval_method=retrieval_model["search_method"],
                            dataset_id=dataset.id,
                            query=query,
                            top_k=retrieval_model.get("top_k") or 2,
                            score_threshold=retrieval_model.get("score_threshold", 0.0)
                            if retrieval_model["score_threshold_enabled"]
                            else 0.0,
                            reranking_model=retrieval_model.get("reranking_model", None)
                            if retrieval_model["reranking_enable"]
                            else None,
                            reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                            weights=retrieval_model.get("weights", None),
                        )
                        all_documents.extend(documents)
  467. def to_dataset_retriever_tool(
  468. self,
  469. tenant_id: str,
  470. dataset_ids: list[str],
  471. retrieve_config: DatasetRetrieveConfigEntity,
  472. return_resource: bool,
  473. invoke_from: InvokeFrom,
  474. hit_callback: DatasetIndexToolCallbackHandler,
  475. ) -> Optional[list[DatasetRetrieverBaseTool]]:
  476. """
  477. A dataset tool is a tool that can be used to retrieve information from a dataset
  478. :param tenant_id: tenant id
  479. :param dataset_ids: dataset ids
  480. :param retrieve_config: retrieve config
  481. :param return_resource: return resource
  482. :param invoke_from: invoke from
  483. :param hit_callback: hit callback
  484. """
  485. tools = []
  486. available_datasets = []
  487. for dataset_id in dataset_ids:
  488. # get dataset from dataset id
  489. dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
  490. # pass if dataset is not available
  491. if not dataset:
  492. continue
  493. # pass if dataset is not available
  494. if dataset and dataset.provider != "external" and dataset.available_document_count == 0:
  495. continue
  496. available_datasets.append(dataset)
  497. if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
  498. # get retrieval model config
  499. default_retrieval_model = {
  500. "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
  501. "reranking_enable": False,
  502. "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
  503. "top_k": 2,
  504. "score_threshold_enabled": False,
  505. }
  506. for dataset in available_datasets:
  507. retrieval_model_config = dataset.retrieval_model or default_retrieval_model
  508. # get top k
  509. top_k = retrieval_model_config["top_k"]
  510. # get score threshold
  511. score_threshold = None
  512. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  513. if score_threshold_enabled:
  514. score_threshold = retrieval_model_config.get("score_threshold")
  515. from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
  516. tool = DatasetRetrieverTool.from_dataset(
  517. dataset=dataset,
  518. top_k=top_k,
  519. score_threshold=score_threshold,
  520. hit_callbacks=[hit_callback],
  521. return_resource=return_resource,
  522. retriever_from=invoke_from.to_source(),
  523. )
  524. tools.append(tool)
  525. elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
  526. from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
  527. tool = DatasetMultiRetrieverTool.from_dataset(
  528. dataset_ids=[dataset.id for dataset in available_datasets],
  529. tenant_id=tenant_id,
  530. top_k=retrieve_config.top_k or 2,
  531. score_threshold=retrieve_config.score_threshold,
  532. hit_callbacks=[hit_callback],
  533. return_resource=return_resource,
  534. retriever_from=invoke_from.to_source(),
  535. reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
  536. reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
  537. )
  538. tools.append(tool)
  539. return tools
  540. def calculate_keyword_score(self, query: str, documents: list[Document], top_k: int) -> list[Document]:
  541. """
  542. Calculate keywords scores
  543. :param query: search query
  544. :param documents: documents for reranking
  545. :return:
  546. """
  547. keyword_table_handler = JiebaKeywordTableHandler()
  548. query_keywords = keyword_table_handler.extract_keywords(query, None)
  549. documents_keywords = []
  550. for document in documents:
  551. # get the document keywords
  552. document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
  553. document.metadata["keywords"] = document_keywords
  554. documents_keywords.append(document_keywords)
  555. # Counter query keywords(TF)
  556. query_keyword_counts = Counter(query_keywords)
  557. # total documents
  558. total_documents = len(documents)
  559. # calculate all documents' keywords IDF
  560. all_keywords = set()
  561. for document_keywords in documents_keywords:
  562. all_keywords.update(document_keywords)
  563. keyword_idf = {}
  564. for keyword in all_keywords:
  565. # calculate include query keywords' documents
  566. doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
  567. # IDF
  568. keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
  569. query_tfidf = {}
  570. for keyword, count in query_keyword_counts.items():
  571. tf = count
  572. idf = keyword_idf.get(keyword, 0)
  573. query_tfidf[keyword] = tf * idf
  574. # calculate all documents' TF-IDF
  575. documents_tfidf = []
  576. for document_keywords in documents_keywords:
  577. document_keyword_counts = Counter(document_keywords)
  578. document_tfidf = {}
  579. for keyword, count in document_keyword_counts.items():
  580. tf = count
  581. idf = keyword_idf.get(keyword, 0)
  582. document_tfidf[keyword] = tf * idf
  583. documents_tfidf.append(document_tfidf)
  584. def cosine_similarity(vec1, vec2):
  585. intersection = set(vec1.keys()) & set(vec2.keys())
  586. numerator = sum(vec1[x] * vec2[x] for x in intersection)
  587. sum1 = sum(vec1[x] ** 2 for x in vec1)
  588. sum2 = sum(vec2[x] ** 2 for x in vec2)
  589. denominator = math.sqrt(sum1) * math.sqrt(sum2)
  590. if not denominator:
  591. return 0.0
  592. else:
  593. return float(numerator) / denominator
  594. similarities = []
  595. for document_tfidf in documents_tfidf:
  596. similarity = cosine_similarity(query_tfidf, document_tfidf)
  597. similarities.append(similarity)
  598. for document, score in zip(documents, similarities):
  599. # format document
  600. document.metadata["score"] = score
  601. documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
  602. return documents[:top_k] if top_k else documents
  603. def calculate_vector_score(
  604. self, all_documents: list[Document], top_k: int, score_threshold: float
  605. ) -> list[Document]:
  606. filter_documents = []
  607. for document in all_documents:
  608. if score_threshold is None or document.metadata["score"] >= score_threshold:
  609. filter_documents.append(document)
  610. if not filter_documents:
  611. return []
  612. filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
  613. return filter_documents[:top_k] if top_k else filter_documents