浏览代码

add jina-reranker-v1-base-en (#2676)

Joshua 1 年之前
父节点
当前提交
8523b34be7

+ 3 - 2
api/core/model_runtime/model_providers/jina/jina.yaml

@@ -2,7 +2,7 @@ provider: jina
 label:
   en_US: Jina
 description:
-  en_US: Embedding Model Supported
+  en_US: Embedding and Rerank Model Supported
 icon_small:
   en_US: icon_s_en.svg
 icon_large:
@@ -13,9 +13,10 @@ help:
     en_US: Get your API key from Jina AI
     zh_Hans: 从 Jina 获取 API Key
   url:
-    en_US: https://jina.ai/embeddings/
+    en_US: https://jina.ai/
 supported_model_types:
   - text-embedding
+  - rerank
 configurate_methods:
   - predefined-model
 provider_credential_schema:

+ 0 - 0
api/core/model_runtime/model_providers/jina/rerank/__init__.py


+ 4 - 0
api/core/model_runtime/model_providers/jina/rerank/jina-reranker-v1-base-en.yaml

@@ -0,0 +1,4 @@
+model: jina-reranker-v1-base-en
+model_type: rerank
+model_properties:
+  context_size: 8192

+ 105 - 0
api/core/model_runtime/model_providers/jina/rerank/rerank.py

@@ -0,0 +1,105 @@
+from typing import Optional
+
+import httpx
+
+from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.rerank_model import RerankModel
+
+
+class JinaRerankModel(RerankModel):
+    """
+    Model class for Jina rerank model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
+                user: Optional[str] = None) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n documents to return
+        :param user: unique user id
+        :return: rerank result
+        """
+        if len(docs) == 0:
+            return RerankResult(model=model, docs=[])
+
+        try:
+            response = httpx.post(
+                "https://api.jina.ai/v1/rerank",
+                json={
+                    "model": model,
+                    "query": query,
+                    "documents": docs,
+                    "top_n": top_n
+                },
+                headers={"Authorization": f"Bearer {credentials.get('api_key')}"}  
+            )
+            response.raise_for_status() 
+            results = response.json()
+
+            rerank_documents = []
+            for result in results['results']:  
+                rerank_document = RerankDocument(
+                    index=result['index'],
+                    text=result['document']['text'],
+                    score=result['relevance_score'],
+                )
+                if score_threshold is None or result['relevance_score'] >= score_threshold:
+                    rerank_documents.append(rerank_document)
+
+            return RerankResult(model=model, docs=rerank_documents)
+        except httpx.HTTPStatusError as e:
+            raise InvokeServerUnavailableError(str(e))  
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            
+            self._invoke(
+                model=model,
+                credentials=credentials,
+                query="What is the capital of the United States?",
+                docs=[
+                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
+                    "Census, Carson City had a population of 55,274.",
+                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
+                    "are a political division controlled by the United States. Its capital is Saipan.",
+                ],
+                score_threshold=0.8
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        """
+        return {
+            InvokeConnectionError: [httpx.ConnectError],
+            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
+            InvokeRateLimitError: [], 
+            InvokeAuthorizationError: [httpx.HTTPStatusError],  
+            InvokeBadRequestError: [httpx.RequestError]
+        }