12345678910111213141516171819202122232425262728293031323334353637 |
- from abc import abstractmethod
- from typing import Any, Optional, List
- from langchain.schema import Document
- from core.model_providers.models.base import BaseProviderModel
- from core.model_providers.models.entity.model_params import ModelType
- from core.model_providers.providers.base import BaseModelProvider
- import logging
- logger = logging.getLogger(__name__)
- class BaseReranking(BaseProviderModel):
- name: str
- type: ModelType = ModelType.RERANKING
- def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
- super().__init__(model_provider, client)
- self.name = name
- @property
- def base_model_name(self) -> str:
- """
- get base model name
-
- :return: str
- """
- return self.name
- @abstractmethod
- def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
- raise NotImplementedError
- @abstractmethod
- def handle_exceptions(self, ex: Exception) -> Exception:
- raise NotImplementedError
|