@@ -312,20 +312,118 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs["user"] = user
 
-        # chat model
-        messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
-        response = client.chat.completions.create(
-            messages=messages,
-            model=model,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs,
+            # clear illegal prompt messages
+            prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
+
+            block_as_stream = False
+            if model.startswith("o1"):
+                if stream:
+                    block_as_stream = True
+                    stream = False
+
+                    if "stream_options" in extra_model_kwargs:
+                        del extra_model_kwargs["stream_options"]
+
+                if "stop" in extra_model_kwargs:
+                    del extra_model_kwargs["stop"]
+
+            # chat model
+            response = client.chat.completions.create(
+                messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+                model=model,
+                stream=stream,
+                **model_parameters,
+                **extra_model_kwargs,
+            )
+
+            if stream:
+                return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+
+            block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+            if block_as_stream:
+                return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+            return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            text = self.enforce_stop_tokens(text, stop)
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=text),
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
         )
 
-        if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
+    def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
+        """
+        Clear illegal prompt messages for OpenAI API
+
+        :param model: model name
+        :param prompt_messages: prompt messages
+        :return: cleaned prompt messages
+        """
+        checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
+
+        if model in checklist:
+            # count how many user messages are there
+            user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
+            if user_message_count > 1:
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, UserPromptMessage):
+                        if isinstance(prompt_message.content, list):
+                            prompt_message.content = "\n".join(
+                                [
+                                    item.data
+                                    if item.type == PromptMessageContentType.TEXT
+                                    else "[IMAGE]"
+                                    if item.type == PromptMessageContentType.IMAGE
+                                    else ""
+                                    for item in prompt_message.content
+                                ]
+                            )
+
+        if model.startswith("o1"):
+            system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
+            if system_message_count > 0:
+                new_prompt_messages = []
+                for prompt_message in prompt_messages:
+                    if isinstance(prompt_message, SystemPromptMessage):
+                        prompt_message = UserPromptMessage(
+                            content=prompt_message.content,
+                            name=prompt_message.name,
+                        )
+
+                    new_prompt_messages.append(prompt_message)
+                prompt_messages = new_prompt_messages
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        return prompt_messages
 
     def _handle_chat_generate_response(
         self,
@@ -560,7 +658,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
             tokens_per_message = 3
             tokens_per_name = 1
         else:
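For o1-series models, the hunks above disable streaming, drop the stop parameter before calling the API, and then replay the blocking result as a single stream chunk via _handle_chat_block_as_stream_response. A minimal, self-contained sketch of that block-as-stream pattern, using simplified stand-in types (the Chunk class and block_as_stream function below are hypothetical, not the repository's LLMResultChunk API):

from collections.abc import Generator
from dataclasses import dataclass
from typing import Optional


@dataclass
class Chunk:
    # Simplified stand-in for a streamed chat-completion chunk.
    text: str
    finish_reason: Optional[str] = None


def block_as_stream(full_text: str, stop: Optional[list[str]] = None) -> Generator[Chunk, None, None]:
    # Re-emit a blocking completion as a one-chunk stream; since the API never saw
    # the stop words, trim the text at the first occurrence client-side.
    if stop:
        for token in stop:
            index = full_text.find(token)
            if index != -1:
                full_text = full_text[:index]
    yield Chunk(text=full_text, finish_reason="stop")


# Callers that expect a stream can iterate as usual.
for chunk in block_as_stream("Hello world. END extra", stop=["END"]):
    print(chunk.text)  # prints "Hello world. "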