rerank_factory.py 697 B

1234567891011121314151617
  1. from core.rag.rerank.rerank_base import BaseRerankRunner
  2. from core.rag.rerank.rerank_model import RerankModelRunner
  3. from core.rag.rerank.rerank_type import RerankMode
  4. from core.rag.rerank.weight_rerank import WeightRerankRunner
  class RerankRunnerFactory:
      """Build a rerank runner from the string value of a ``RerankMode`` member."""

      @staticmethod
      def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
          """Instantiate the rerank runner that corresponds to ``runner_type``.

          Args:
              runner_type: Either ``RerankMode.RERANKING_MODEL.value`` (selects
                  ``RerankModelRunner``) or ``RerankMode.WEIGHTED_SCORE.value``
                  (selects ``WeightRerankRunner``).
              *args: Positional arguments passed through unchanged to the
                  selected runner's constructor.
              **kwargs: Keyword arguments passed through unchanged to the
                  selected runner's constructor.

          Returns:
              A newly constructed runner of the selected class.

          Raises:
              ValueError: If ``runner_type`` equals neither mode value.
          """
          # Compare as `runner_type == <mode value>`, the same operand order a
          # `match` value pattern uses, so custom __eq__ behaviour is unchanged.
          if runner_type == RerankMode.RERANKING_MODEL.value:
              runner_cls = RerankModelRunner
          elif runner_type == RerankMode.WEIGHTED_SCORE.value:
              runner_cls = WeightRerankRunner
          else:
              raise ValueError(f"Unknown runner type: {runner_type}")
          return runner_cls(*args, **kwargs)