Browse Source

Add novita.ai as model provider (#4961)

Jason 10 months ago
parent
commit
b7ff765d8d

File diff suppressed because it is too large
+ 19 - 0
api/core/model_runtime/model_providers/novita/_assets/icon_l_en.svg


File diff suppressed because it is too large
+ 10 - 0
api/core/model_runtime/model_providers/novita/_assets/icon_s_en.svg


+ 36 - 0
api/core/model_runtime/model_providers/novita/llm/Nous-Hermes-2-Mixtral-8x7B-DPO.yaml

@@ -0,0 +1,36 @@
# Model schema for the Novita-hosted Nous-Hermes-2-Mixtral-8x7B-DPO chat model (32k context).
model: Nous-Hermes-2-Mixtral-8x7B-DPO
label:
  zh_Hans: Nous-Hermes-2-Mixtral-8x7B-DPO
  en_US: Nous-Hermes-2-Mixtral-8x7B-DPO
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32768
# Sampling/length controls exposed in the UI; ranges mirror the OpenAI-compatible API.
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0

+ 36 - 0
api/core/model_runtime/model_providers/novita/llm/llama-3-70b-instruct.yaml

@@ -0,0 +1,36 @@
# Model schema for Novita-hosted meta-llama/llama-3-70b-instruct (chat, 8k context).
model: meta-llama/llama-3-70b-instruct
label:
  zh_Hans: meta-llama/llama-3-70b-instruct
  en_US: meta-llama/llama-3-70b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
# Sampling/length controls exposed in the UI; ranges mirror the OpenAI-compatible API.
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0

+ 36 - 0
api/core/model_runtime/model_providers/novita/llm/llama-3-8b-instruct.yaml

@@ -0,0 +1,36 @@
# Model schema for Novita-hosted meta-llama/llama-3-8b-instruct (chat, 8k context).
# This is also the model used by NovitaProvider.validate_provider_credentials.
model: meta-llama/llama-3-8b-instruct
label:
  zh_Hans: meta-llama/llama-3-8b-instruct
  en_US: meta-llama/llama-3-8b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 8192
# Sampling/length controls exposed in the UI; ranges mirror the OpenAI-compatible API.
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0

+ 48 - 0
api/core/model_runtime/model_providers/novita/llm/llm.py

@@ -0,0 +1,48 @@
+from collections.abc import Generator
+from typing import Optional, Union
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    """Large-language-model implementation for the novita.ai provider.

    novita.ai exposes an OpenAI-compatible API, so every public entry point
    simply pins the credentials to Novita's endpoint (plus a source header)
    and delegates to :class:`OAIAPICompatLargeLanguageModel`.
    """

    def _update_endpoint_url(self, credentials: dict) -> dict:
        """Return a copy of *credentials* routed to Novita's OpenAI-compatible API.

        A copy is returned (instead of mutating in place, as before) so the
        caller's credentials dict is not changed as a hidden side effect.
        """
        return {
            **credentials,
            'endpoint_url': "https://api.novita.ai/v3/openai",
            # Lets Novita attribute this traffic to the Dify integration.
            'extra_headers': {'X-Novita-Source': 'dify.ai'},
        }

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """Invoke the model through the OpenAI-compatible base implementation."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate the API key against the Novita endpoint.

        :raises CredentialsValidateFailedError: if the credentials are rejected.
        """
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        # Augment the same dict that is actually validated. The original code
        # passed the un-augmented `credentials` here, which only worked because
        # _update_endpoint_url used to mutate it in place.
        self._add_custom_parameters(cred_with_endpoint, model)
        return super().validate_credentials(model, cred_with_endpoint)

    @classmethod
    def _add_custom_parameters(cls, credentials: dict, model: str) -> None:
        """Force chat mode: every Novita model in this provider is chat-style."""
        credentials['mode'] = 'chat'

    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """Generate a completion via the base class against the Novita endpoint."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """Return the customizable model schema with Novita-routed credentials."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super().get_customizable_model_schema(model, cred_with_endpoint)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """Count prompt tokens using the base class's tokenizer logic."""
        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
        return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)

+ 36 - 0
api/core/model_runtime/model_providers/novita/llm/lzlv_70b.yaml

@@ -0,0 +1,36 @@
# Model schema for Novita-hosted lzlv_70b (chat, 4k context).
model: lzlv_70b
label:
  zh_Hans: lzlv_70b
  en_US: lzlv_70b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
# Sampling/length controls exposed in the UI; ranges mirror the OpenAI-compatible API.
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0

+ 36 - 0
api/core/model_runtime/model_providers/novita/llm/mythomax-l2-13b.yaml

@@ -0,0 +1,36 @@
# Model schema for Novita-hosted gryphe/mythomax-l2-13b (chat, 4k context).
model: gryphe/mythomax-l2-13b
label:
  zh_Hans: gryphe/mythomax-l2-13b
  en_US: gryphe/mythomax-l2-13b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
# Sampling/length controls exposed in the UI; ranges mirror the OpenAI-compatible API.
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0

+ 36 - 0
api/core/model_runtime/model_providers/novita/llm/nous-hermes-llama2-13b.yaml

@@ -0,0 +1,36 @@
# Model schema for Novita-hosted nousresearch/nous-hermes-llama2-13b (chat, 4k context).
model: nousresearch/nous-hermes-llama2-13b
label:
  zh_Hans: nousresearch/nous-hermes-llama2-13b
  en_US: nousresearch/nous-hermes-llama2-13b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
# Sampling/length controls exposed in the UI; ranges mirror the OpenAI-compatible API.
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0

+ 36 - 0
api/core/model_runtime/model_providers/novita/llm/openhermes-2.5-mistral-7b.yaml

@@ -0,0 +1,36 @@
# Model schema for Novita-hosted teknium/openhermes-2.5-mistral-7b (chat, 4k context).
model: teknium/openhermes-2.5-mistral-7b
label:
  zh_Hans: teknium/openhermes-2.5-mistral-7b
  en_US: teknium/openhermes-2.5-mistral-7b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
# Sampling/length controls exposed in the UI; ranges mirror the OpenAI-compatible API.
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0

+ 36 - 0
api/core/model_runtime/model_providers/novita/llm/wizardlm-2-8x22b.yaml

@@ -0,0 +1,36 @@
# Model schema for Novita-hosted microsoft/wizardlm-2-8x22b (chat model).
model: microsoft/wizardlm-2-8x22b
label:
  zh_Hans: microsoft/wizardlm-2-8x22b
  en_US: microsoft/wizardlm-2-8x22b
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  # NOTE(review): 65535 is one short of 64Ki (65536) — confirm against
  # Novita's published context window for this model before changing.
  context_size: 65535
# Sampling/length controls exposed in the UI; ranges mirror the OpenAI-compatible API.
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0
    max: 2
    default: 1
  - name: top_p
    use_template: top_p
    min: 0
    max: 1
    default: 1
  - name: max_tokens
    use_template: max_tokens
    min: 1
    max: 2048
    default: 512
  - name: frequency_penalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: presence_penalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0

+ 31 - 0
api/core/model_runtime/model_providers/novita/novita.py

@@ -0,0 +1,31 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
class NovitaProvider(ModelProvider):
    """Model provider entry point for novita.ai."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials; raise if validation fails.

        :param credentials: provider credentials, credentials form defined in
            `provider_credential_schema`.
        :raises CredentialsValidateFailedError: when the credentials are rejected.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use `meta-llama/llama-3-8b-instruct` model for validate,
            # no matter what model you pass in, text completion model or chat model
            model_instance.validate_credentials(
                model='meta-llama/llama-3-8b-instruct',
                credentials=credentials
            )
        except CredentialsValidateFailedError:
            # Expected validation failure: propagate unchanged, no logging noise.
            raise
        except Exception:
            # Unexpected failure: log with stack trace, then re-raise.
            # Lazy %-style args avoid formatting when the log level is disabled.
            logger.exception('%s credentials validate failed',
                             self.get_provider_schema().provider)
            raise

+ 28 - 0
api/core/model_runtime/model_providers/novita/novita.yaml

@@ -0,0 +1,28 @@
# Provider manifest for novita.ai: UI assets, help links, supported model
# types, and the credential form (a single secret API key).
provider: novita
label:
  en_US: novita.ai
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#eadeff"
help:
  title:
    en_US: Get your API key from novita.ai
    zh_Hans: 从 novita.ai 获取 API Key
  url:
    en_US: https://novita.ai/dashboard/key?utm_source=dify
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      required: true
      label:
        en_US: API Key
      type: secret-input
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key

+ 7 - 1
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -74,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             tools=tools,
             stop=stop,
             stream=stream,
-            user=user
+            user=user,
         )
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
@@ -280,6 +280,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             'Content-Type': 'application/json',
             'Accept-Charset': 'utf-8',
         }
+        extra_headers = credentials.get('extra_headers')
+        if extra_headers is not None:
+            headers = {
+              **headers,
+              **extra_headers,
+            }
 
         api_key = credentials.get('api_key')
         if api_key:

+ 0 - 0
api/tests/integration_tests/model_runtime/novita/__init__.py


+ 123 - 0
api/tests/integration_tests/model_runtime/novita/test_llm.py

@@ -0,0 +1,123 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.novita.llm.llm import NovitaLargeLanguageModel
+
+
def test_validate_credentials():
    """Integration test: a bogus key must fail validation; the real key
    (from the NOVITA_API_KEY env var) must pass. Requires network access."""
    model = NovitaLargeLanguageModel()

    # An obviously invalid key should be rejected by the Novita endpoint.
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='meta-llama/llama-3-8b-instruct',
            credentials={
                'api_key': 'invalid_key',
                'mode': 'chat'
            }
        )

    # The genuine key from the environment should validate without raising.
    model.validate_credentials(
        model='meta-llama/llama-3-8b-instruct',
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
            'mode': 'chat'
        }
    )
+
+
def test_invoke_model():
    """Integration test: non-streaming invoke returns an LLMResult with
    non-empty content. Requires network access and NOVITA_API_KEY."""
    model = NovitaLargeLanguageModel()

    response = model.invoke(
        model='meta-llama/llama-3-8b-instruct',
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
            # NOTE(review): 'completion' here (vs 'chat' elsewhere) presumably
            # exercises the completion-mode payload path — confirm intended.
            'mode': 'completion'
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_p': 0.5,
            'max_tokens': 10,
        },
        stop=['How'],
        stream=False,
        user="novita"
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
+
+
def test_invoke_stream_model():
    """Integration test: streaming invoke yields LLMResultChunk items whose
    deltas carry assistant messages. Requires network and NOVITA_API_KEY."""
    model = NovitaLargeLanguageModel()

    response = model.invoke(
        model='meta-llama/llama-3-8b-instruct',
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
            'mode': 'chat'
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            # NOTE(review): 'top_k' is not declared in the model's
            # parameter_rules — verify the endpoint accepts/ignores it.
            'top_k': 2,
            'top_p': 0.5,
            'max_tokens': 100
        },
        stream=True,
        user="novita"
    )

    assert isinstance(response, Generator)

    # Each streamed chunk must be well-formed down to the assistant message.
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+
def test_get_num_tokens():
    """Token counting for a fixed two-message prompt."""
    model = NovitaLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='meta-llama/llama-3-8b-instruct',
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert isinstance(num_tokens, int)
    # NOTE(review): exact count is tied to the tokenizer used by the
    # OpenAI-compatible base class; this assert breaks if that changes.
    assert num_tokens == 21

+ 21 - 0
api/tests/integration_tests/model_runtime/novita/test_provider.py

@@ -0,0 +1,21 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.novita.novita import NovitaProvider
+
+
def test_validate_provider_credentials():
    """Integration test: empty credentials must fail provider validation;
    the real NOVITA_API_KEY must pass. Requires network access."""
    provider = NovitaProvider()

    # No api_key at all should be rejected.
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={}
        )

    # The genuine key from the environment should validate without raising.
    provider.validate_provider_credentials(
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
        }
    )