Browse Source

feat: add OpenAI o1 series models support (#8328)

takatost 7 months ago
parent
commit
e90d3c29ab

+ 4 - 0
api/core/model_runtime/model_providers/openai/llm/_position.yaml

@@ -5,6 +5,10 @@
 - chatgpt-4o-latest
 - gpt-4o-mini
 - gpt-4o-mini-2024-07-18
+- o1-preview
+- o1-preview-2024-09-12
+- o1-mini
+- o1-mini-2024-09-12
 - gpt-4-turbo
 - gpt-4-turbo-2024-04-09
 - gpt-4-turbo-preview

+ 42 - 3
api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -613,6 +613,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
+        block_as_stream = False
+        if model.startswith("o1"):
+            block_as_stream = True
+            stream = False
+            if "stream_options" in extra_model_kwargs:
+                del extra_model_kwargs["stream_options"]
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@@ -625,7 +632,39 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return: llm response chunk generator
+        """
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
 
     def _handle_chat_generate_response(
         self,
@@ -960,7 +999,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             model = model.split(":")[1]
 
         # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token.
-        if model == "chatgpt-4o-latest":
+        if model == "chatgpt-4o-latest" or model.startswith("o1"):
             model = "gpt-4o"
 
         try:
@@ -975,7 +1014,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
             tokens_per_message = 3
             tokens_per_name = 1
         else:

+ 33 - 0
api/core/model_runtime/model_providers/openai/llm/o1-mini-2024-09-12.yaml

@@ -0,0 +1,33 @@
+model: o1-mini-2024-09-12
+label:
+  zh_Hans: o1-mini-2024-09-12
+  en_US: o1-mini-2024-09-12
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 65563
+    min: 1
+    max: 65563
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '3.00'
+  output: '12.00'
+  unit: '0.000001'
+  currency: USD

+ 33 - 0
api/core/model_runtime/model_providers/openai/llm/o1-mini.yaml

@@ -0,0 +1,33 @@
+model: o1-mini
+label:
+  zh_Hans: o1-mini
+  en_US: o1-mini
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 65563
+    min: 1
+    max: 65563
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '3.00'
+  output: '12.00'
+  unit: '0.000001'
+  currency: USD

+ 33 - 0
api/core/model_runtime/model_providers/openai/llm/o1-preview-2024-09-12.yaml

@@ -0,0 +1,33 @@
+model: o1-preview-2024-09-12
+label:
+  zh_Hans: o1-preview-2024-09-12
+  en_US: o1-preview-2024-09-12
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 32768
+    min: 1
+    max: 32768
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '15.00'
+  output: '60.00'
+  unit: '0.000001'
+  currency: USD

+ 33 - 0
api/core/model_runtime/model_providers/openai/llm/o1-preview.yaml

@@ -0,0 +1,33 @@
+model: o1-preview
+label:
+  zh_Hans: o1-preview
+  en_US: o1-preview
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    default: 32768
+    min: 1
+    max: 32768
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: response_format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '15.00'
+  output: '60.00'
+  unit: '0.000001'
+  currency: USD

+ 2 - 1
api/pyproject.toml

@@ -60,7 +60,8 @@ ignore = [
     "SIM113", # eumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
-    "SIM300", # yoda-conditions
+    "SIM300", # yoda-conditions,
+    "PT004", # pytest-no-assert
 ]
 
 [tool.ruff.lint.per-file-ignores]