Procházet zdrojové kódy

Add Together.ai's OpenAI API-compatible inference endpoints (#1947)

Chenhe Gu před 1 rokem
rodič
revize
6075fee556

+ 1 - 1
api/core/model_runtime/model_providers/__base/model_provider.py

@@ -112,7 +112,7 @@ class ModelProvider(ABC):
         model_class = None
         for name, obj in vars(mod).items():
             if (isinstance(obj, type) and issubclass(obj, AIModel) and not obj.__abstractmethods__
-                    and obj != AIModel):
+                    and obj != AIModel and obj.__module__ == mod.__name__):
                 model_class = obj
                 break
 

+ 1 - 84
api/core/model_runtime/model_providers/openai_api_compatible/_common.py

@@ -40,87 +40,4 @@ class _CommonOAI_API_Compat:
                 requests.exceptions.ConnectTimeout,  # Timeout
                 requests.exceptions.ReadTimeout  # Timeout
             ]
-        }
-
-    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
-        """
-            generate custom model entities from credentials
-        """
-        model_type = ModelType.LLM if credentials.get('__model_type') == 'llm' else ModelType.TEXT_EMBEDDING
-        
-        entity = AIModelEntity(
-            model=model,
-            label=I18nObject(en_US=model),
-            model_type=model_type,
-            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
-            model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size', 16000),
-                ModelPropertyKey.MAX_CHUNKS: credentials.get('max_chunks', 1),
-            },
-            parameter_rules=[
-                ParameterRule(
-                    name=DefaultParameterName.TEMPERATURE.value,
-                    label=I18nObject(en_US="Temperature"),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get('temperature', 1)),
-                    min=0,
-                    max=2
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.TOP_P.value,
-                    label=I18nObject(en_US="Top P"),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get('top_p', 1)),
-                    min=0,
-                    max=1
-                ),
-                ParameterRule(
-                    name="top_k",
-                    label=I18nObject(en_US="Top K"),
-                    type=ParameterType.INT,
-                    default=int(credentials.get('top_k', 1)),
-                    min=1,
-                    max=100
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.FREQUENCY_PENALTY.value,
-                    label=I18nObject(en_US="Frequency Penalty"),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get('frequency_penalty', 0)),
-                    min=-2,
-                    max=2
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.PRESENCE_PENALTY.value,
-                    label=I18nObject(en_US="PRESENCE Penalty"),
-                    type=ParameterType.FLOAT,
-                    default=float(credentials.get('PRESENCE_penalty', 0)),
-                    min=-2,
-                    max=2
-                ),
-                ParameterRule(
-                    name=DefaultParameterName.MAX_TOKENS.value,
-                    label=I18nObject(en_US="Max Tokens"),
-                    type=ParameterType.INT,
-                    default=1024,
-                    min=1,
-                    max=int(credentials.get('max_tokens_to_sample', 4096)),
-                )
-            ],
-            pricing=PriceConfig(
-                input=Decimal(credentials.get('input_price', 0)),
-                output=Decimal(credentials.get('output_price', 0)),
-                unit=Decimal(credentials.get('unit', 0)),
-                currency=credentials.get('currency', "USD")
-            )
-        )
-
-        if model_type == ModelType.LLM:
-            if credentials['mode'] == 'chat':
-                entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
-            elif credentials['mode'] == 'completion':
-                entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
-            else:
-                raise ValueError(f"Unknown completion type {credentials['completion_type']}")
-        
-        return entity
+        }

+ 11 - 8
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -158,7 +158,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
                 ModelPropertyKey.MODE: credentials.get('mode'),
             },
             parameter_rules=[
@@ -196,9 +196,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 ),
                 ParameterRule(
                     name=DefaultParameterName.PRESENCE_PENALTY.value,
-                    label=I18nObject(en_US="PRESENCE Penalty"),
+                    label=I18nObject(en_US="Presence Penalty"),
                     type=ParameterType.FLOAT,
-                    default=float(credentials.get('PRESENCE_penalty', 0)),
+                    default=float(credentials.get('presence_penalty', 0)),
                     min=-2,
                     max=2
                 ),
@@ -219,6 +219,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             )
         )
 
+        if credentials['mode'] == 'chat':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
+        elif credentials['mode'] == 'completion':
+            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
+        else:
+            raise ValueError(f"Unknown completion mode {credentials['mode']}")
+    
         return entity
 
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
@@ -261,7 +268,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         if completion_type is LLMMode.CHAT:
             endpoint_url = urljoin(endpoint_url, 'chat/completions')
             data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
-        elif completion_type == LLMMode.COMPLETION:
+        elif completion_type is LLMMode.COMPLETION:
             endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
         else:
@@ -291,10 +298,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             stream=stream
         )
 
-        # Debug: Print request headers and json data
-        logger.debug(f"Request headers: {headers}")
-        logger.debug(f"Request JSON data: {data}")
-
         if response.status_code != 200:
             raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
 

+ 2 - 2
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml

@@ -2,8 +2,8 @@ provider: openai_api_compatible
 label:
   en_US: OpenAI-API-compatible
 description:
-  en_US: All model providers compatible with OpenAI's API standard, such as Together.ai.
-  zh_Hans: 兼容 OpenAI API 的模型供应商,例如 Together.ai
+  en_US: Model providers compatible with OpenAI's API standard, such as LM Studio.
+  zh_Hans: 兼容 OpenAI API 的模型供应商,例如 LM Studio
 supported_model_types:
 - llm
 - text-embedding

+ 1 - 1
api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py

@@ -112,7 +112,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
             credentials=credentials,
             tokens=used_tokens
         )
-
+        
         return TextEmbeddingResult(
             embeddings=batched_embeddings,
             usage=usage,

docker/volumes/db/scripts/init_extension.sh → api/core/model_runtime/model_providers/togetherai/__init__.py


Rozdílová data souboru nebyla zobrazena, protože soubor je příliš velký
+ 13 - 0
api/core/model_runtime/model_providers/togetherai/_assets/togetherai.svg


+ 19 - 0
api/core/model_runtime/model_providers/togetherai/_assets/togetherai_square.svg

@@ -0,0 +1,19 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<g clip-path="url(#clip0_15960_46917)">
+<mask id="mask0_15960_46917" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="0" y="0" width="16" height="16">
+<path d="M16 0H0V16H16V0Z" fill="white"/>
+</mask>
+<g mask="url(#mask0_15960_46917)">
+<path d="M13.1765 0H2.82353C1.26414 0 0 1.26414 0 2.82353V13.1765C0 14.7359 1.26414 16 2.82353 16H13.1765C14.7359 16 16 14.7359 16 13.1765V2.82353C16 1.26414 14.7359 0 13.1765 0Z" fill="#F1EFED"/>
+<path d="M11.4119 7.64706C12.9713 7.64706 14.2354 6.38292 14.2354 4.82353C14.2354 3.26414 12.9713 2 11.4119 2C9.85252 2 8.58838 3.26414 8.58838 4.82353C8.58838 6.38292 9.85252 7.64706 11.4119 7.64706Z" fill="#D3D1D1"/>
+<path d="M11.4119 14.2354C12.9713 14.2354 14.2354 12.9713 14.2354 11.4119C14.2354 9.85252 12.9713 8.58838 11.4119 8.58838C9.85252 8.58838 8.58838 9.85252 8.58838 11.4119C8.58838 12.9713 9.85252 14.2354 11.4119 14.2354Z" fill="#D3D1D1"/>
+<path d="M4.82353 14.2354C6.38292 14.2354 7.64706 12.9713 7.64706 11.4119C7.64706 9.85252 6.38292 8.58838 4.82353 8.58838C3.26414 8.58838 2 9.85252 2 11.4119C2 12.9713 3.26414 14.2354 4.82353 14.2354Z" fill="#D3D1D1"/>
+<path d="M4.82353 7.64706C6.38292 7.64706 7.64706 6.38292 7.64706 4.82353C7.64706 3.26414 6.38292 2 4.82353 2C3.26414 2 2 3.26414 2 4.82353C2 6.38292 3.26414 7.64706 4.82353 7.64706Z" fill="#0F6FFF"/>
+</g>
+</g>
+<defs>
+<clipPath id="clip0_15960_46917">
+<rect width="16" height="16" fill="white"/>
+</clipPath>
+</defs>
+</svg>

+ 0 - 0
api/core/model_runtime/model_providers/togetherai/llm/__init__.py


+ 45 - 0
api/core/model_runtime/model_providers/togetherai/llm/llm.py

@@ -0,0 +1,45 @@
+from typing import Generator, List, Optional, Union
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
+
+    def _update_endpoint_url(self, credentials: dict):
+        credentials['endpoint_url'] = "https://api.together.xyz/v1"
+        return credentials
+
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super()._invoke(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().validate_credentials(model, cred_with_endpoint)
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super()._generate(model, cred_with_endpoint, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().get_customizable_model_schema(model, cred_with_endpoint)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: Optional[list[PromptMessageTool]] = None) -> int:
+        cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
+
+        return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)
+
+

+ 13 - 0
api/core/model_runtime/model_providers/togetherai/togetherai.py

@@ -0,0 +1,13 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class TogetherAIProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass

+ 75 - 0
api/core/model_runtime/model_providers/togetherai/togetherai.yaml

@@ -0,0 +1,75 @@
+provider: togetherai
+label:
+  en_US: together.ai
+icon_small:
+    en_US: togetherai_square.svg
+icon_large:
+    en_US: togetherai.svg
+background: "#F1EFED"
+help:
+  title:
+    en_US: Get your API key from together.ai
+    zh_Hans: 从 together.ai 获取 API Key
+  url:
+    en_US: https://api.together.xyz/
+supported_model_types:
+- llm
+configurate_methods:
+- customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter full model name
+      zh_Hans: 输入模型全称
+  credential_form_schemas:
+  - variable: api_key
+    label:
+      en_US: API Key
+    type: secret-input
+    required: false
+    placeholder:
+      zh_Hans: 在此输入您的 API Key
+      en_US: Enter your API Key
+  - variable: mode
+    show_on:
+      - variable: __model_type
+        value: llm
+    label:
+      en_US: Completion mode
+    type: select
+    required: false
+    default: chat
+    placeholder:
+      zh_Hans: 选择对话类型
+      en_US: Select completion mode
+    options:
+      - value: completion
+        label:
+          en_US: Completion
+          zh_Hans: 补全
+      - value: chat
+        label:
+          en_US: Chat
+          zh_Hans: 对话
+  - variable: context_size
+    label:
+      zh_Hans: 模型上下文长度
+      en_US: Model context size
+    required: true
+    type: text-input
+    default: '4096'
+    placeholder:
+      zh_Hans: 在此输入您的模型上下文长度
+      en_US: Enter your Model context size
+  - variable: max_tokens_to_sample
+    label:
+      zh_Hans: 最大 token 上限
+      en_US: Upper bound for max tokens
+    show_on:
+    - variable: __model_type
+      value: llm
+    default: '4096'
+    type: text-input

+ 4 - 2
api/tests/integration_tests/model_runtime/openai/test_text_embedding.py

@@ -39,13 +39,15 @@ def test_invoke_model(setup_openai_mock):
         },
         texts=[
             "hello",
-            "world"
+            "world",
+            " ".join(["long_text"] * 100),
+            " ".join(["another_long_text"] * 100)
         ],
         user="abc-123"
     )
 
     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 2
+    assert len(result.embeddings) == 4
     assert result.usage.total_tokens == 2
 
 

+ 5 - 3
api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py

@@ -46,14 +46,16 @@ def test_invoke_model():
         },
         texts=[
             "hello",
-            "world"
+            "world",
+            " ".join(["long_text"] * 100),
+            " ".join(["another_long_text"] * 100)
         ],
         user="abc-123"
     )
 
     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 2
-    assert result.usage.total_tokens == 2
+    assert len(result.embeddings) == 4
+    assert result.usage.total_tokens == 502
 
 
 def test_get_num_tokens():

+ 0 - 0
api/tests/integration_tests/model_runtime/togetherai/__init__.py


+ 117 - 0
api/tests/integration_tests/model_runtime/togetherai/test_llm.py

@@ -0,0 +1,117 @@
+import os
+from typing import Generator
+
+import pytest
+
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
+    SystemPromptMessage, PromptMessageTool
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
+    LLMResultChunk
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel
+
+
+def test_validate_credentials():
+    model = TogetherAILargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+            credentials={
+                'api_key': 'invalid_key',
+                'mode': 'chat'
+            }
+        )
+
+    model.validate_credentials(
+        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'chat'
+        }
+    )
+
+def test_invoke_model():
+    model = TogetherAILargeLanguageModel()
+
+    response = model.invoke(
+        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'completion'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 1.0,
+            'top_k': 2,
+            'top_p': 0.5,
+        },
+        stop=['How'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+def test_invoke_stream_model():
+    model = TogetherAILargeLanguageModel()
+
+    response = model.invoke(
+        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'mode': 'chat'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 1.0,
+            'top_k': 2,
+            'top_p': 0.5,
+        },
+        stop=['How'],
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+def test_get_num_tokens():
+    model = TogetherAILargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert isinstance(num_tokens, int)
+    assert num_tokens == 21