@@ -312,20 +312,118 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs["user"] = user

-        # chat model
-        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
-        response = client.chat.completions.create(
-            messages=messages,
-            model=model,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs,
+        # clear illegal prompt messages
+        prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
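+        # o1 models do not support streaming, stream_options, or stop, so fall back
+        # to a blocking call and replay the result as a stream afterwards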
+        block_as_stream = False
+        if model.startswith("o1"):
+            if stream:
+                block_as_stream = True
+                stream = False
+
+                if "stream_options" in extra_model_kwargs:
+                    del extra_model_kwargs["stream_options"]
+
+            if "stop" in extra_model_kwargs:
+                del extra_model_kwargs["stop"]
+
+        # chat model
+        response = client.chat.completions.create(
+            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            model=model,
+            stream=stream,
+            **model_parameters,
+            **extra_model_kwargs,
+        )
+
+        if stream:
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
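+        # the caller asked for a stream but the o1 call was forced to block:
+        # replay the blocking result as a single-chunk stream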
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle a blocking llm chat response as a stream
+
+        :param block_result: blocking llm chat result
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=text),
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
         )

-        if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
+
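+        # these models only get plain-text user content when the conversation
+        # carries more than one user message: multi-part content is flattened,
+        # with "[IMAGE]" standing in for image parts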
+        if model in checklist:
+            # count how many user messages there are
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = "\n".join(
+                                [
+                                    item.data
+                                    if item.type == PromptMessageContentType.TEXT
+                                    else "[IMAGE]"
+                                    if item.type == PromptMessageContentType.IMAGE
+                                    else ""
+                                    for item in prompt_message.content
+                                ]
+                            )
+
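+        # o1 models do not accept the system role, so system prompts are
+        # downgraded to user prompts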
+        if model.startswith("o1"):
+            system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
+            if system_message_count > 0:
+                new_prompt_messages = []
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, SystemPromptMessage):
+                        prompt_message = UserPromptMessage(
+                            content=prompt_message.content,
+                            name=prompt_message.name,
+                        )
+
+                    new_prompt_messages.append(prompt_message)
+                prompt_messages = new_prompt_messages

-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        return prompt_messages

     def _handle_chat_generate_response(
         self,
@@ -560,7 +658,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
             tokens_per_message = 3
             tokens_per_name = 1
         else: