@@ -613,6 +613,13 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

+        block_as_stream = False
+        if model.startswith("o1"):
+            block_as_stream = True
+            stream = False
+            if "stream_options" in extra_model_kwargs:
+                del extra_model_kwargs["stream_options"]
+
         # chat model
         response = client.chat.completions.create(
             messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
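The hunk above is the crux of the o1 workaround: the o1-* models reject streamed completions, so the call is forced into blocking mode and `stream_options` (which the API only accepts alongside `stream=True`) is dropped first. A minimal standalone sketch of the same guard, assuming the official `openai` client; `o1_safe_create` is a hypothetical wrapper, not part of the patch:

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

def o1_safe_create(model: str, messages: list[dict], stream: bool = True, **kwargs):
    """Hypothetical wrapper mirroring the guard in the hunk above."""
    if model.startswith("o1"):
        # o1 models do not support streaming; fall back to a blocking call.
        stream = False
        # stream_options is only valid together with stream=True, so drop it.
        kwargs.pop("stream_options", None)
    return client.chat.completions.create(model=model, messages=messages, stream=stream, **kwargs)
```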
@@ -625,7 +632,36 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)

-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle a blocking chat response as a single-chunk stream
+
+        :param block_result: blocking llm result
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )

     def _handle_chat_generate_response(
         self,
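`_handle_chat_block_as_stream_response` lets callers that expect a `Generator` keep working: the blocking result is replayed as exactly one terminal chunk. A self-contained sketch of the pattern, using a stand-in `Chunk` type (Dify's `LLMResultChunk` carries more fields):

```python
from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class Chunk:
    text: str
    finish_reason: str | None

def block_as_stream(full_text: str) -> Generator[Chunk, None, None]:
    # The whole blocking completion goes out as one terminal chunk,
    # so stream consumers see a normal (if very short) stream.
    yield Chunk(text=full_text, finish_reason="stop")

for chunk in block_as_stream("o1 answers arrive in one piece"):
    print(chunk.text, chunk.finish_reason)  # exactly one iteration
```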
@@ -960,7 +999,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             model = model.split(":")[1]

         # Currently, we can use gpt4o to calculate chatgpt-4o-latest's token.
-        if model == "chatgpt-4o-latest":
+        if model == "chatgpt-4o-latest" or model.startswith("o1"):
             model = "gpt-4o"

         try:
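The substitution works because the o1 family shares gpt-4o's o200k_base tokenizer, so counting with the gpt-4o encoding stays accurate. A sketch of the remapping with `tiktoken`; the `cl100k_base` fallback for unknown model names is an assumption, not part of the patch:

```python
import tiktoken

def encoding_for(model: str) -> tiktoken.Encoding:
    # chatgpt-4o-latest and o1-* may be missing from tiktoken's model map,
    # but they share gpt-4o's tokenizer, so remap before the lookup.
    if model == "chatgpt-4o-latest" or model.startswith("o1"):
        model = "gpt-4o"
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        return tiktoken.get_encoding("cl100k_base")
```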
@@ -975,7 +1014,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             tokens_per_message = 4
             # if there's a name, the role is omitted
             tokens_per_name = -1
-        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
+        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
             tokens_per_message = 3
             tokens_per_name = 1
         else:
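These constants follow OpenAI's published counting recipe: each message carries a fixed per-message overhead, a `name` field costs one extra token in the gpt-3.5-turbo/gpt-4/o1 format (minus one in the old gpt-3.5-turbo-0301 format, where the name replaces the role), and every reply is primed with three tokens. A worked sketch under those assumptions:

```python
import tiktoken

def num_tokens(messages: list[dict[str, str]], model: str = "gpt-4o") -> int:
    enc = tiktoken.encoding_for_model(model)
    tokens_per_message, tokens_per_name = 3, 1  # gpt-3.5-turbo / gpt-4 / o1 families
    total = 3  # every reply is primed with <|start|>assistant<|message|>
    for message in messages:
        total += tokens_per_message
        for key, value in message.items():
            total += len(enc.encode(value))
            if key == "name":
                total += tokens_per_name
    return total

print(num_tokens([{"role": "user", "content": "Hello, o1!"}]))
```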