Browse Source

chore: refactor the OpenAICompatible and improve thinking display (#13299)

非法操作 2 months ago
parent
commit
3eb3db0663

+ 39 - 0
api/core/model_runtime/model_providers/__base/large_language_model.py

@@ -30,6 +30,11 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel
 
 logger = logging.getLogger(__name__)
 
+HTML_THINKING_TAG = (
+    '<details style="color:gray;background-color: #f5f5f5;padding: 8px;border-radius: 4px;" open> '
+    "<summary> Thinking... </summary>"
+)
+
 
 class LargeLanguageModel(AIModel):
     """
@@ -400,6 +405,40 @@ if you are not sure about the structure.
                     ),
                 )
 
+    def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]:
+        """
+        If the reasoning response is from delta.get("reasoning_content"), we wrap
+        it with HTML details tag.
+
+        :param delta: delta dictionary from LLM streaming response
+        :param is_reasoning: is reasoning
+        :return: tuple of (processed_content, is_reasoning)
+        """
+
+        content = delta.get("content") or ""
+        reasoning_content = delta.get("reasoning_content")
+
+        if reasoning_content:
+            if not is_reasoning:
+                content = HTML_THINKING_TAG + reasoning_content
+                is_reasoning = True
+            else:
+                content = reasoning_content
+        elif is_reasoning:
+            content = "</details>" + content
+            is_reasoning = False
+        return content, is_reasoning
+
+    def _wrap_thinking_by_tag(self, content: str) -> str:
+        """
+        if the reasoning response is a <think>...</think> block from delta.get("content"),
+        we replace <think> to <detail>.
+
+        :param content: delta.get("content")
+        :return: processed_content
+        """
+        return content.replace("<think>", HTML_THINKING_TAG).replace("</think>", "</details>")
+
     def _invoke_result_generator(
         self,
         model: str,

+ 90 - 110
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -1,6 +1,5 @@
+import codecs
 import json
-import logging
-import re
 from collections.abc import Generator
 from decimal import Decimal
 from typing import Optional, Union, cast
@@ -39,8 +38,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
 from core.model_runtime.utils import helper
 
-logger = logging.getLogger(__name__)
-
 
 class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
     """
@@ -100,7 +97,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
+        return self._num_tokens_from_messages(prompt_messages, tools, credentials)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -399,6 +396,73 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
+    def _create_final_llm_result_chunk(
+        self,
+        index: int,
+        message: AssistantPromptMessage,
+        finish_reason: str,
+        usage: dict,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        full_content: str,
+    ) -> LLMResultChunk:
+        # calculate num tokens
+        prompt_tokens = usage and usage.get("prompt_tokens")
+        if prompt_tokens is None:
+            prompt_tokens = self._num_tokens_from_string(text=prompt_messages[0].content)
+        completion_tokens = usage and usage.get("completion_tokens")
+        if completion_tokens is None:
+            completion_tokens = self._num_tokens_from_string(text=full_content)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        return LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
+        )
+
+    def _get_tool_call(self, tool_call_id: str, tools_calls: list[AssistantPromptMessage.ToolCall]):
+        """
+        Get or create a tool call by ID
+
+        :param tool_call_id: tool call ID
+        :param tools_calls: list of existing tool calls
+        :return: existing or new tool call, updated tools_calls
+        """
+        if not tool_call_id:
+            return tools_calls[-1], tools_calls
+
+        tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
+        if tool_call is None:
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_call_id,
+                type="function",
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+            )
+            tools_calls.append(tool_call)
+
+        return tool_call, tools_calls
+
+    def _increase_tool_call(
+        self, new_tool_calls: list[AssistantPromptMessage.ToolCall], tools_calls: list[AssistantPromptMessage.ToolCall]
+    ) -> list[AssistantPromptMessage.ToolCall]:
+        for new_tool_call in new_tool_calls:
+            # get tool call
+            tool_call, tools_calls = self._get_tool_call(new_tool_call.function.name, tools_calls)
+            # update tool call
+            if new_tool_call.id:
+                tool_call.id = new_tool_call.id
+            if new_tool_call.type:
+                tool_call.type = new_tool_call.type
+            if new_tool_call.function.name:
+                tool_call.function.name = new_tool_call.function.name
+            if new_tool_call.function.arguments:
+                tool_call.function.arguments += new_tool_call.function.arguments
+        return tools_calls
+
     def _handle_generate_stream_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> Generator:
@@ -411,71 +475,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator
         """
-        full_assistant_content = ""
         chunk_index = 0
-
-        def create_final_llm_result_chunk(
-            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
-        ) -> LLMResultChunk:
-            # calculate num tokens
-            prompt_tokens = usage and usage.get("prompt_tokens")
-            if prompt_tokens is None:
-                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = usage and usage.get("completion_tokens")
-            if completion_tokens is None:
-                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
-
-            # transform usage
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            return LLMResultChunk(
-                id=id,
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
-            )
-
+        full_assistant_content = ""
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = None
+        usage = None
+        is_reasoning_started = False
         # delimiter for stream response, need unicode_escape
-        import codecs
-
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")
-
-        tools_calls: list[AssistantPromptMessage.ToolCall] = []
-
-        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
-            def get_tool_call(tool_call_id: str):
-                if not tool_call_id:
-                    return tools_calls[-1]
-
-                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
-                if tool_call is None:
-                    tool_call = AssistantPromptMessage.ToolCall(
-                        id=tool_call_id,
-                        type="function",
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
-                    )
-                    tools_calls.append(tool_call)
-
-                return tool_call
-
-            for new_tool_call in new_tool_calls:
-                # get tool call
-                tool_call = get_tool_call(new_tool_call.function.name)
-                # update tool call
-                if new_tool_call.id:
-                    tool_call.id = new_tool_call.id
-                if new_tool_call.type:
-                    tool_call.type = new_tool_call.type
-                if new_tool_call.function.name:
-                    tool_call.function.name = new_tool_call.function.name
-                if new_tool_call.function.arguments:
-                    tool_call.function.arguments += new_tool_call.function.arguments
-
-        finish_reason = None  # The default value of finish_reason is None
-        message_id, usage = None, None
-        is_reasoning_started = False
-        is_reasoning_started_tag = False
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
@@ -490,12 +498,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     chunk_json: dict = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
-                    yield create_final_llm_result_chunk(
-                        id=message_id,
+                    yield self._create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered.",
                         usage=usage,
+                        model=model,
+                        credentials=credentials,
+                        prompt_messages=prompt_messages,
+                        full_content=full_assistant_content,
                     )
                     break
                 # handle the error here. for issue #11629
@@ -510,42 +521,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
-                message_id = chunk_json.get("id")
                 chunk_index += 1
 
                 if "delta" in choice:
                     delta = choice["delta"]
-                    delta_content = delta.get("content")
-                    if not delta_content:
-                        delta_content = ""
-
-                    if not is_reasoning_started_tag and "<think>" in delta_content:
-                        is_reasoning_started_tag = True
-                        delta_content = "> 💭 " + delta_content.replace("<think>", "")
-                    elif is_reasoning_started_tag and "</think>" in delta_content:
-                        delta_content = delta_content.replace("</think>", "") + "\n\n"
-                        is_reasoning_started_tag = False
-                    elif is_reasoning_started_tag:
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-
-                    reasoning_content = delta.get("reasoning_content")
-                    if is_reasoning_started and not reasoning_content and not delta_content:
-                        delta_content = ""
-                    elif reasoning_content:
-                        if not is_reasoning_started:
-                            delta_content = "> 💭 " + reasoning_content
-                            is_reasoning_started = True
-                        else:
-                            delta_content = reasoning_content
-
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-                    elif is_reasoning_started:
-                        # If we were in reasoning mode but now getting regular content,
-                        # add \n\n to close the reasoning block
-                        delta_content = "\n\n" + delta_content
-                        is_reasoning_started = False
+                    delta_content, is_reasoning_started = self._wrap_thinking_by_reasoning_content(
+                        delta, is_reasoning_started
+                    )
+                    delta_content = self._wrap_thinking_by_tag(delta_content)
 
                     assistant_message_tool_calls = None
 
@@ -559,12 +542,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                             {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
                         ]
 
-                    # assistant_message_function_call = delta.delta.function_call
-
                     # extract tool calls from response
                     if assistant_message_tool_calls:
                         tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                        increase_tool_call(tool_calls)
+                        tools_calls = self._increase_tool_call(tool_calls, tools_calls)
 
                     if delta_content is None or delta_content == "":
                         continue
@@ -589,7 +570,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     continue
 
                 yield LLMResultChunk(
-                    id=message_id,
                     model=model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
@@ -602,7 +582,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         if tools_calls:
             yield LLMResultChunk(
-                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -611,12 +590,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                 ),
             )
 
-        yield create_final_llm_result_chunk(
-            id=message_id,
+        yield self._create_final_llm_result_chunk(
             index=chunk_index,
             message=AssistantPromptMessage(content=""),
             finish_reason=finish_reason,
             usage=usage,
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            full_content=full_assistant_content,
         )
 
     def _handle_generate_response(
@@ -730,12 +712,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         return message_dict
 
     def _num_tokens_from_string(
-        self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
+        self, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
     ) -> int:
         """
         Approximate num tokens for model with gpt2 tokenizer.
 
-        :param model: model name
         :param text: prompt text
         :param tools: tools for tool calling
         :return: number of tokens
@@ -758,7 +739,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
     def _num_tokens_from_messages(
         self,
-        model: str,
         messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
         credentials: Optional[dict] = None,