Browse Source

azure add o1-mini、o1-preview models (#9088)

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Charlie.Wei 6 months ago
parent
commit
55679b4389

+ 74 - 1
api/core/model_runtime/model_providers/azure_openai/_constant.py

@@ -1081,8 +1081,81 @@ LLM_BASE_MODELS = [
             ),
         ),
     ),
+    AzureBaseModel(
+        base_model_name="o1-preview",
+        entity=AIModelEntity(
+            model="fake-deployment-name",
+            label=I18nObject(
+                en_US="fake-deployment-name-label",
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="response_format",
+                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
+                    type="string",
+                    help=I18nObject(
+                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
+                    ),
+                    required=False,
+                    options=["text", "json_object"],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=32768),
+            ],
+            pricing=PriceConfig(
+                input=15.00,
+                output=60.00,
+                unit=0.000001,
+                currency="USD",
+            ),
+        ),
+    ),
+    AzureBaseModel(
+        base_model_name="o1-mini",
+        entity=AIModelEntity(
+            model="fake-deployment-name",
+            label=I18nObject(
+                en_US="fake-deployment-name-label",
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="response_format",
+                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
+                    type="string",
+                    help=I18nObject(
+                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
+                    ),
+                    required=False,
+                    options=["text", "json_object"],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=65536),
+            ],
+            pricing=PriceConfig(
+                input=3.00,
+                output=12.00,
+                unit=0.000001,
+                currency="USD",
+            ),
+        ),
+    ),
 ]
-
 EMBEDDING_BASE_MODELS = [
     AzureBaseModel(
         base_model_name="text-embedding-ada-002",

+ 12 - 0
api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml

@@ -121,6 +121,18 @@ model_credential_schema:
             - variable: __model_type
               value: llm
         - label:
+            en_US: o1-mini
+          value: o1-mini
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: o1-preview
+          value: o1-preview
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
             en_US: gpt-4o-mini
           value: gpt-4o-mini
           show_on:

+ 110 - 12
api/core/model_runtime/model_providers/azure_openai/llm/llm.py

@@ -312,20 +312,118 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs["user"] = user
 
-        # chat model
-        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
-        response = client.chat.completions.create(
-            messages=messages,
-            model=model,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs,
+            # clear illegal prompt messages
+            prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
+            block_as_stream = False
+            if model.startswith("o1"):
+                if stream:
+                    block_as_stream = True
+                    stream = False
+
+                    if "stream_options" in extra_model_kwargs:
+                        del extra_model_kwargs["stream_options"]
+
+                if "stop" in extra_model_kwargs:
+                    del extra_model_kwargs["stop"]
+
+            # chat model
+            response = client.chat.completions.create(
+                messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+                model=model,
+                stream=stream,
+                **model_parameters,
+                **extra_model_kwargs,
+            )
+
+            if stream:
+                return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+
+            block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+            if block_as_stream:
+                return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+            return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=text),
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
         )
 
-        if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
+
+        if model in checklist:
+            # count how many user messages are there
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = "\n".join(
+                                [
+                                    item.data
+                                    if item.type == PromptMessageContentType.TEXT
+                                    else "[IMAGE]"
+                                    if item.type == PromptMessageContentType.IMAGE
+                                    else ""
+                                    for item in prompt_message.content
+                                ]
+                            )
+
+        if model.startswith("o1"):
+            system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
+            if system_message_count > 0:
+                new_prompt_messages = []
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, SystemPromptMessage):
+                        prompt_message = UserPromptMessage(
+                            content=prompt_message.content,
+                            name=prompt_message.name,
+                        )
+
+                    new_prompt_messages.append(prompt_message)
+                prompt_messages = new_prompt_messages
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        return prompt_messages
 
     def _handle_chat_generate_response(
         self,
@@ -560,7 +658,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
             tokens_per_message = 3
             tokens_per_name = 1
         else: