Browse Source

feat: Add fireworks custom llm integration (#9333)

ice yao 6 months ago
parent
commit
aba70207ab

+ 73 - 0
api/core/model_runtime/model_providers/fireworks/fireworks.yaml

@@ -18,6 +18,7 @@ supported_model_types:
   - text-embedding
 configurate_methods:
   - predefined-model
+  - customizable-model
 provider_credential_schema:
   credential_form_schemas:
     - variable: fireworks_api_key
@@ -28,3 +29,75 @@ provider_credential_schema:
       placeholder:
         zh_Hans: 在此输入您的 API Key
         en_US: Enter your API Key
+model_credential_schema:  # form shown when a user registers a customizable Fireworks model
+  model:
+    label:
+      en_US: Model URL
+      zh_Hans: 模型URL
+    placeholder:
+      en_US: Enter your Model URL
+      zh_Hans: 输入模型URL
+  credential_form_schemas:
+    - variable: model_label_zh_Hanns  # NOTE(review): "zh_Hanns" looks misspelled ("zh_Hans"?) but must match the key read in llm.py — rename in both files together
+      label:
+        zh_Hans: 模型中文名称
+        en_US: The zh_Hans of Model
+      required: true
+      type: text-input
+      placeholder:
+        zh_Hans: 在此输入您的模型中文名称
+        en_US: Enter your zh_Hans of Model
+    - variable: model_label_en_US  # optional English display name; llm.py falls back to the model id when absent
+      label:
+        zh_Hans: 模型英文名称
+        en_US: The en_US of Model
+      required: true
+      type: text-input
+      placeholder:
+        zh_Hans: 在此输入您的模型英文名称
+        en_US: Enter your en_US of Model
+    - variable: fireworks_api_key  # NOTE(review): label has en_US only; sibling fields also provide zh_Hans — consider adding for consistency
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: context_size  # consumed via int(...) in llm.py; free text-input, so non-numeric entries will fail there
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: '4096'
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens  # NOTE(review): no "required:" here, unlike context_size — confirm the omission is intended
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      default: '4096'
+      type: text-input
+      show_on:  # only rendered for LLM-type models
+        - variable: __model_type
+          value: llm
+    - variable: function_calling_type  # "function_call" enables tool-call features in llm.py; anything else disables them
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: no_call
+          label:
+            en_US: Not Support
+            zh_Hans: 不支持
+        - value: function_call
+          label:
+            en_US: Support
+            zh_Hans: 支持
+      show_on:
+        - variable: __model_type
+          value: llm

+ 58 - 1
api/core/model_runtime/model_providers/fireworks/llm/llm.py

@@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
 from openai.types.chat.chat_completion_message import FunctionCall
 
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
@@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
                     num_tokens += self._get_num_tokens_by_gpt2(required_field)
 
         return num_tokens
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:  # build the AIModelEntity for a user-configured (customizable) Fireworks model
+        return AIModelEntity(
+            model=model,
+            label=I18nObject(  # display labels fall back to the raw model id when not supplied
+                en_US=credentials.get("model_label_en_US", model),
+                zh_Hans=credentials.get("model_label_zh_Hanns", model),  # NOTE(review): "zh_Hanns" matches the variable name in fireworks.yaml — fix the spelling in both places together
+            ),
+            model_type=ModelType.LLM,
+            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]  # tool-call features only when the user selected "function_call"
+            if credentials.get("function_calling_type") == "function_call"
+            else [],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),  # raises ValueError if the stored value is not numeric (schema uses a free text-input)
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,  # chat mode is hard-coded for custom models
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="temperature",
+                    use_template="temperature",
+                    label=I18nObject(en_US="Temperature", zh_Hans="温度"),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name="max_tokens",
+                    use_template="max_tokens",
+                    default=512,
+                    min=1,
+                    max=int(credentials.get("max_tokens", 4096)),  # user-entered upper bound; same ValueError caveat as context_size
+                    label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
+                    type=ParameterType.INT,
+                ),
+                ParameterRule(
+                    name="top_p",
+                    use_template="top_p",
+                    label=I18nObject(en_US="Top P", zh_Hans="Top P"),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name="top_k",
+                    use_template="top_k",
+                    label=I18nObject(en_US="Top K", zh_Hans="Top K"),
+                    type=ParameterType.FLOAT,  # NOTE(review): top_k is conventionally an integer — confirm FLOAT is intended
+                ),
+            ],
+        )