Browse Source

feat:add wenxin rerank (#9431)

Co-authored-by: cuihz <cuihz@knowbox.cn>
Co-authored-by: crazywoola <427733928@qq.com>
chzphoenix 6 months ago
parent
commit
211f416806

+ 1 - 0
api/core/model_runtime/model_providers/wenxin/_common.py

@@ -120,6 +120,7 @@ class _CommonWenxin:
         "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en",
         "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh",
         "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k",
+        "bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base",
     }
 
     function_calling_supports = [

+ 0 - 0
api/core/model_runtime/model_providers/wenxin/rerank/__init__.py


+ 8 - 0
api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml

@@ -0,0 +1,8 @@
+model: bce-reranker-base_v1
+model_type: rerank
+model_properties:
+  context_size: 4096
+pricing:
+  input: '0.0005'
+  unit: '0.001'
+  currency: RMB

+ 147 - 0
api/core/model_runtime/model_providers/wenxin/rerank/rerank.py

@@ -0,0 +1,147 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
+
+
+class WenxinRerank(_CommonWenxin):
+    def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
+        access_token = self._get_access_token()
+        url = f"{self.api_bases[model]}?access_token={access_token}"
+
+        try:
+            response = httpx.post(
+                url,
+                json={"model": model, "query": query, "documents": docs, "top_n": top_n},
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return response.json()
+        except httpx.HTTPStatusError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+
+class WenxinRerankModel(RerankModel):
+    """
+    Model class for wenxin rerank model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        query: str,
+        docs: list[str],
+        score_threshold: Optional[float] = None,
+        top_n: Optional[int] = None,
+        user: Optional[str] = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n documents to return
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=[])
+
+        api_key = credentials["api_key"]
+        secret_key = credentials["secret_key"]
+
+        wenxin_rerank: WenxinRerank = WenxinRerank(api_key, secret_key)
+
+        try:
+            results = wenxin_rerank.rerank(model, query, docs, top_n)
+
+            rerank_documents = []
+            for result in results["results"]:
+                index = result["index"]
+                if "document" in result:
+                    text = result["document"]
+                else:
+                    # llama.cpp rerank maynot return original documents
+                    text = docs[index]
+
+                rerank_document = RerankDocument(
+                    index=index,
+                    text=text,
+                    score=result["relevance_score"],
+                )
+
+                if score_threshold is None or result["relevance_score"] >= score_threshold:
+                    rerank_documents.append(rerank_document)
+
+            return RerankResult(model=model, docs=rerank_documents)
+        except httpx.HTTPStatusError as e:
+            raise InvokeServerUnavailableError(str(e))
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8,
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        """
+        return {
+            InvokeConnectionError: [httpx.ConnectError],
+            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [httpx.HTTPStatusError],
+            InvokeBadRequestError: [httpx.RequestError],
+        }
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        """
+        generate custom model entities from credentials
+        """
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model),
+            model_type=ModelType.RERANK,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
+        )
+
+        return entity

+ 1 - 0
api/core/model_runtime/model_providers/wenxin/wenxin.yaml

@@ -18,6 +18,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - rerank
 configurate_methods:
   - predefined-model
 provider_credential_schema:

+ 21 - 0
api/tests/integration_tests/model_runtime/wenxin/test_rerank.py

@@ -0,0 +1,21 @@
+import os
+from time import sleep
+
+from core.model_runtime.entities.rerank_entities import RerankResult
+from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel
+
+
+def test_invoke_bce_reranker_base_v1():
+    sleep(3)
+    model = WenxinRerankModel()
+
+    response = model.invoke(
+        model="bce-reranker-base_v1",
+        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
+        query="What is Deep Learning?",
+        docs=["Deep Learning is ...", "My Book is ..."],
+        user="abc-123",
+    )
+
+    assert isinstance(response, RerankResult)
+    assert len(response.docs) == 2