retrieval_service.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from typing import Optional
  2. from flask import current_app, Flask
  3. from langchain.embeddings.base import Embeddings
  4. from core.index.vector_index.vector_index import VectorIndex
  5. from core.model_providers.model_factory import ModelFactory
  6. from extensions.ext_database import db
  7. from models.dataset import Dataset
  8. default_retrieval_model = {
  9. 'search_method': 'semantic_search',
  10. 'reranking_enable': False,
  11. 'reranking_model': {
  12. 'reranking_provider_name': '',
  13. 'reranking_model_name': ''
  14. },
  15. 'top_k': 2,
  16. 'score_threshold_enabled': False
  17. }
  18. class RetrievalService:
  19. @classmethod
  20. def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
  21. top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
  22. all_documents: list, search_method: str, embeddings: Embeddings):
  23. with flask_app.app_context():
  24. dataset = db.session.query(Dataset).filter(
  25. Dataset.id == dataset_id
  26. ).first()
  27. vector_index = VectorIndex(
  28. dataset=dataset,
  29. config=current_app.config,
  30. embeddings=embeddings
  31. )
  32. documents = vector_index.search(
  33. query,
  34. search_type='similarity_score_threshold',
  35. search_kwargs={
  36. 'k': top_k,
  37. 'score_threshold': score_threshold,
  38. 'filter': {
  39. 'group_id': [dataset.id]
  40. }
  41. }
  42. )
  43. if documents:
  44. if reranking_model and search_method == 'semantic_search':
  45. rerank = ModelFactory.get_reranking_model(
  46. tenant_id=dataset.tenant_id,
  47. model_provider_name=reranking_model['reranking_provider_name'],
  48. model_name=reranking_model['reranking_model_name']
  49. )
  50. all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
  51. else:
  52. all_documents.extend(documents)
  53. @classmethod
  54. def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
  55. top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
  56. all_documents: list, search_method: str, embeddings: Embeddings):
  57. with flask_app.app_context():
  58. dataset = db.session.query(Dataset).filter(
  59. Dataset.id == dataset_id
  60. ).first()
  61. vector_index = VectorIndex(
  62. dataset=dataset,
  63. config=current_app.config,
  64. embeddings=embeddings
  65. )
  66. documents = vector_index.search_by_full_text_index(
  67. query,
  68. search_type='similarity_score_threshold',
  69. top_k=top_k
  70. )
  71. if documents:
  72. if reranking_model and search_method == 'full_text_search':
  73. rerank = ModelFactory.get_reranking_model(
  74. tenant_id=dataset.tenant_id,
  75. model_provider_name=reranking_model['reranking_provider_name'],
  76. model_name=reranking_model['reranking_model_name']
  77. )
  78. all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
  79. else:
  80. all_documents.extend(documents)