ソースを参照

feat (new llm): add support for openrouter (#3042)

Salem Korayem 1 年間 前
コミット
6b4c8e76e6

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -6,6 +6,7 @@
 - cohere
 - bedrock
 - togetherai
+- openrouter
 - ollama
 - mistralai
 - groq

+ 0 - 0
api/core/model_runtime/model_providers/openrouter/__init__.py


File diff suppressed because it is too large
+ 11 - 0
api/core/model_runtime/model_providers/openrouter/_assets/openrouter.svg


File diff suppressed because it is too large
+ 10 - 0
api/core/model_runtime/model_providers/openrouter/_assets/openrouter_square.svg


+ 0 - 0
api/core/model_runtime/model_providers/openrouter/llm/__init__.py


+ 46 - 0
api/core/model_runtime/model_providers/openrouter/llm/llm.py

@@ -0,0 +1,46 @@
+from collections.abc import Generator
+from typing import Optional, Union
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
+
+    def _update_endpoint_url(self, credentials: dict):
+        credentials['endpoint_url'] = "https://openrouter.ai/api/v1"
+        return credentials
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().validate_credentials(model, cred_with_endpoint)
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().get_customizable_model_schema(model, cred_with_endpoint)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)

+ 11 - 0
api/core/model_runtime/model_providers/openrouter/openrouter.py

@@ -0,0 +1,11 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class OpenRouterProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass

+ 75 - 0
api/core/model_runtime/model_providers/openrouter/openrouter.yaml

@@ -0,0 +1,75 @@
+provider: openrouter
+label:
+  en_US: openrouter.ai
+icon_small:
+  en_US: openrouter_square.svg
+icon_large:
+  en_US: openrouter.svg
+background: "#F1EFED"
+help:
+  title:
+    en_US: Get your API key from openrouter.ai
+    zh_Hans: 从 openrouter.ai 获取 API Key
+  url:
+    en_US: https://openrouter.ai/keys
+supported_model_types:
+  - llm
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter full model name
+      zh_Hans: 输入模型全称
+  credential_form_schemas:
+    - variable: api_key
+      required: true
+      label:
+        en_US: API Key
+      type: secret-input
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: mode
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Completion mode
+      type: select
+      required: false
+      default: chat
+      placeholder:
+        zh_Hans: 选择对话类型
+        en_US: Select completion mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: "4096"
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens_to_sample
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      show_on:
+        - variable: __model_type
+          value: llm
+      default: "4096"
+      type: text-input

+ 0 - 0
api/tests/integration_tests/model_runtime/openrouter/__init__.py


+ 118 - 0
api/tests/integration_tests/model_runtime/openrouter/test_llm.py

@@ -0,0 +1,118 @@
+import os
+from typing import Generator
+
+import pytest
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool,
+                                                          SystemPromptMessage, UserPromptMessage)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.openrouter.llm.llm import OpenRouterLargeLanguageModel
+
+
+def test_validate_credentials():
+    model = OpenRouterLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='mistralai/mixtral-8x7b-instruct',
+            credentials={
+                'api_key': 'invalid_key',
+                'mode': 'chat'
+            }
+        )
+
+    model.validate_credentials(
+        model='mistralai/mixtral-8x7b-instruct',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'chat'
+        }
+    )
+
+
+def test_invoke_model():
+    model = OpenRouterLargeLanguageModel()
+
+    response = model.invoke(
+        model='mistralai/mixtral-8x7b-instruct',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'completion'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 1.0,
+            'top_k': 2,
+            'top_p': 0.5,
+        },
+        stop=['How'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+    model = OpenRouterLargeLanguageModel()
+
+    response = model.invoke(
+        model='mistralai/mixtral-8x7b-instruct',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'chat'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 1.0,
+            'top_k': 2,
+            'top_p': 0.5,
+        },
+        stop=['How'],
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+
+def test_get_num_tokens():
+    model = OpenRouterLargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='mistralai/mixtral-8x7b-instruct',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert isinstance(num_tokens, int)
+    assert num_tokens == 21