base.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. from abc import abstractmethod
  2. from typing import Any, Optional, List
  3. from langchain.schema import Document
  4. from core.model_providers.models.base import BaseProviderModel
  5. from core.model_providers.models.entity.model_params import ModelType
  6. from core.model_providers.providers.base import BaseModelProvider
  7. import logging
  8. logger = logging.getLogger(__name__)
  9. class BaseReranking(BaseProviderModel):
  10. name: str
  11. type: ModelType = ModelType.RERANKING
  12. def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
  13. super().__init__(model_provider, client)
  14. self.name = name
  15. @property
  16. def base_model_name(self) -> str:
  17. """
  18. get base model name
  19. :return: str
  20. """
  21. return self.name
  22. @abstractmethod
  23. def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
  24. raise NotImplementedError
  25. @abstractmethod
  26. def handle_exceptions(self, ex: Exception) -> Exception:
  27. raise NotImplementedError