weight_rerank.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import math
  2. from collections import Counter
  3. from typing import Optional
  4. import numpy as np
  5. from core.embedding.cached_embedding import CacheEmbedding
  6. from core.model_manager import ModelManager
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
  9. from core.rag.models.document import Document
  10. from core.rag.rerank.entity.weight import VectorSetting, Weights
  11. class WeightRerankRunner:
  12. def __init__(self, tenant_id: str, weights: Weights) -> None:
  13. self.tenant_id = tenant_id
  14. self.weights = weights
  15. def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None,
  16. top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]:
  17. """
  18. Run rerank model
  19. :param query: search query
  20. :param documents: documents for reranking
  21. :param score_threshold: score threshold
  22. :param top_n: top n
  23. :param user: unique user id if needed
  24. :return:
  25. """
  26. docs = []
  27. doc_id = []
  28. unique_documents = []
  29. for document in documents:
  30. if document.metadata['doc_id'] not in doc_id:
  31. doc_id.append(document.metadata['doc_id'])
  32. docs.append(document.page_content)
  33. unique_documents.append(document)
  34. documents = unique_documents
  35. rerank_documents = []
  36. query_scores = self._calculate_keyword_score(query, documents)
  37. query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)
  38. for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):
  39. # format document
  40. score = self.weights.vector_setting.vector_weight * query_vector_score + \
  41. self.weights.keyword_setting.keyword_weight * query_score
  42. if score_threshold and score < score_threshold:
  43. continue
  44. document.metadata['score'] = score
  45. rerank_documents.append(document)
  46. rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata['score'], reverse=True)
  47. return rerank_documents[:top_n] if top_n else rerank_documents
  48. def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
  49. """
  50. Calculate BM25 scores
  51. :param query: search query
  52. :param documents: documents for reranking
  53. :return:
  54. """
  55. keyword_table_handler = JiebaKeywordTableHandler()
  56. query_keywords = keyword_table_handler.extract_keywords(query, None)
  57. documents_keywords = []
  58. for document in documents:
  59. # get the document keywords
  60. document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
  61. document.metadata['keywords'] = document_keywords
  62. documents_keywords.append(document_keywords)
  63. # Counter query keywords(TF)
  64. query_keyword_counts = Counter(query_keywords)
  65. # total documents
  66. total_documents = len(documents)
  67. # calculate all documents' keywords IDF
  68. all_keywords = set()
  69. for document_keywords in documents_keywords:
  70. all_keywords.update(document_keywords)
  71. keyword_idf = {}
  72. for keyword in all_keywords:
  73. # calculate include query keywords' documents
  74. doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
  75. # IDF
  76. keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
  77. query_tfidf = {}
  78. for keyword, count in query_keyword_counts.items():
  79. tf = count
  80. idf = keyword_idf.get(keyword, 0)
  81. query_tfidf[keyword] = tf * idf
  82. # calculate all documents' TF-IDF
  83. documents_tfidf = []
  84. for document_keywords in documents_keywords:
  85. document_keyword_counts = Counter(document_keywords)
  86. document_tfidf = {}
  87. for keyword, count in document_keyword_counts.items():
  88. tf = count
  89. idf = keyword_idf.get(keyword, 0)
  90. document_tfidf[keyword] = tf * idf
  91. documents_tfidf.append(document_tfidf)
  92. def cosine_similarity(vec1, vec2):
  93. intersection = set(vec1.keys()) & set(vec2.keys())
  94. numerator = sum(vec1[x] * vec2[x] for x in intersection)
  95. sum1 = sum(vec1[x] ** 2 for x in vec1.keys())
  96. sum2 = sum(vec2[x] ** 2 for x in vec2.keys())
  97. denominator = math.sqrt(sum1) * math.sqrt(sum2)
  98. if not denominator:
  99. return 0.0
  100. else:
  101. return float(numerator) / denominator
  102. similarities = []
  103. for document_tfidf in documents_tfidf:
  104. similarity = cosine_similarity(query_tfidf, document_tfidf)
  105. similarities.append(similarity)
  106. # for idx, similarity in enumerate(similarities):
  107. # print(f"Document {idx + 1} similarity: {similarity}")
  108. return similarities
  109. def _calculate_cosine(self, tenant_id: str, query: str, documents: list[Document],
  110. vector_setting: VectorSetting) -> list[float]:
  111. """
  112. Calculate Cosine scores
  113. :param query: search query
  114. :param documents: documents for reranking
  115. :return:
  116. """
  117. query_vector_scores = []
  118. model_manager = ModelManager()
  119. embedding_model = model_manager.get_model_instance(
  120. tenant_id=tenant_id,
  121. provider=vector_setting.embedding_provider_name,
  122. model_type=ModelType.TEXT_EMBEDDING,
  123. model=vector_setting.embedding_model_name
  124. )
  125. cache_embedding = CacheEmbedding(embedding_model)
  126. query_vector = cache_embedding.embed_query(query)
  127. for document in documents:
  128. # calculate cosine similarity
  129. if 'score' in document.metadata:
  130. query_vector_scores.append(document.metadata['score'])
  131. else:
  132. content_vector = document.metadata['vector']
  133. # transform to NumPy
  134. vec1 = np.array(query_vector)
  135. vec2 = np.array(document.metadata['vector'])
  136. # calculate dot product
  137. dot_product = np.dot(vec1, vec2)
  138. # calculate norm
  139. norm_vec1 = np.linalg.norm(vec1)
  140. norm_vec2 = np.linalg.norm(vec2)
  141. # calculate cosine similarity
  142. cosine_sim = dot_product / (norm_vec1 * norm_vec2)
  143. query_vector_scores.append(cosine_sim)
  144. return query_vector_scores