
feat: render the think tag output of SiliconFlow's DeepSeek-R1 (#13153)

非法操作 · 2 months ago · parent commit 840729afa5
1 changed file with 214 additions and 2 deletions

+ 214 - 2
api/core/model_runtime/model_providers/siliconflow/llm/llm.py

@@ -1,9 +1,16 @@
+import json
 from collections.abc import Generator
 from typing import Optional, Union
 
+import requests
+
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+)
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
     FetchFrom,
@@ -89,3 +96,208 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                 ),
             ],
         )
+
+    def _handle_generate_stream_response(
+        self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
+    ) -> Generator:
+        """
+        Handle the streamed LLM response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param response: streamed response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        full_assistant_content = ""
+        chunk_index = 0
+        is_reasoning_started = False  # whether we are currently inside a <think> block
+
+        def create_final_llm_result_chunk(
+            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
+        ) -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = usage and usage.get("prompt_tokens")
+            if prompt_tokens is None:
+                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = usage and usage.get("completion_tokens")
+            if completion_tokens is None:
+                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                id=id,
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
+            )
+
+        # the stream delimiter from credentials may contain escape sequences (e.g. "\\n\\n"), so decode it
+        import codecs
+
+        delimiter = credentials.get("stream_mode_delimiter", "\n\n")
+        delimiter = codecs.decode(delimiter, "unicode_escape")
+
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_call_id: str):
+                if not tool_call_id:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=tool_call_id,
+                        type="function",
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+                    )
+                    tools_calls.append(tool_call)
+
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.function.name)
+                # update tool call
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
+
+        finish_reason = None  # The default value of finish_reason is None
+        message_id, usage = None, None
+        for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
+            chunk = chunk.strip()
+            if chunk:
+                # ignore sse comments
+                if chunk.startswith(":"):
+                    continue
+                decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
+                if decoded_chunk == "[DONE]":  # some providers return "data: [DONE]"
+                    continue
+
+                try:
+                    chunk_json: dict = json.loads(decoded_chunk)
+                # a non-JSON payload means the stream ended or the provider returned an error body
+                except json.JSONDecodeError:
+                    yield create_final_llm_result_chunk(
+                        id=message_id,
+                        index=chunk_index + 1,
+                        message=AssistantPromptMessage(content=""),
+                        finish_reason="Non-JSON encountered.",
+                        usage=usage,
+                    )
+                    break
+                # surface errors that some providers return in the stream body (see issue #11629)
+                if chunk_json.get("error") and chunk_json.get("choices") is None:
+                    raise ValueError(chunk_json.get("error"))
+
+                if chunk_json:
+                    if u := chunk_json.get("usage"):
+                        usage = u
+                if not chunk_json or len(chunk_json["choices"]) == 0:
+                    continue
+
+                choice = chunk_json["choices"][0]
+                finish_reason = choice.get("finish_reason")
+                message_id = chunk_json.get("id")
+                chunk_index += 1
+
+                if "delta" in choice:
+                    delta = choice["delta"]
+                    delta_content = delta.get("content")
+
+                    assistant_message_tool_calls = None
+
+                    if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
+                        assistant_message_tool_calls = delta.get("tool_calls", None)
+                    elif (
+                        "function_call" in delta
+                        and credentials.get("function_calling_type", "no_call") == "function_call"
+                    ):
+                        assistant_message_tool_calls = [
+                            {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
+                        ]
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == "":
+                        continue
+
+                    # Check for think tags
+                    if "<think>" in delta_content:
+                        is_reasoning_started = True
+                        # Remove <think> tag and add markdown quote
+                        delta_content = "> 💭 " + delta_content.replace("<think>", "")
+                    elif "</think>" in delta_content:
+                        # Remove </think> tag and add newlines to end quote block
+                        delta_content = delta_content.replace("</think>", "") + "\n\n"
+                        is_reasoning_started = False
+                    elif is_reasoning_started:
+                        # Add quote markers for content within thinking block
+                        if "\n\n" in delta_content:
+                            delta_content = delta_content.replace("\n\n", "\n> ")
+                        elif "\n" in delta_content:
+                            delta_content = delta_content.replace("\n", "\n> ")
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta_content,
+                    )
+
+                    # reset tool calls
+                    tool_calls = []
+                    full_assistant_content += delta_content
+                elif "text" in choice:
+                    choice_text = choice.get("text", "")
+                    if choice_text == "":
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
+                else:
+                    continue
+
+                yield LLMResultChunk(
+                    id=message_id,
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                    ),
+                )
+
+            chunk_index += 1
+
+        if tools_calls:
+            yield LLMResultChunk(
+                id=message_id,
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
+                ),
+            )
+
+        yield create_final_llm_result_chunk(
+            id=message_id,
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason,
+            usage=usage,
+        )
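
For reference, a minimal standalone sketch of the <think>-to-markdown-quote rewriting this diff applies to streamed deltas. The function name rewrite_think_delta and the sample deltas are illustrative only, not part of the Dify codebase; the branching mirrors the handler above.

def rewrite_think_delta(delta_content: str, is_reasoning_started: bool) -> tuple[str, bool]:
    """Rewrite one streamed delta so reasoning content renders as a markdown blockquote."""
    if "<think>" in delta_content:
        # opening tag: drop it and start a quoted block with a reasoning marker
        return "> 💭 " + delta_content.replace("<think>", ""), True
    elif "</think>" in delta_content:
        # closing tag: drop it and end the quote with a blank line
        return delta_content.replace("</think>", "") + "\n\n", False
    elif is_reasoning_started:
        # inside the reasoning block, keep any new line quoted
        if "\n\n" in delta_content:
            return delta_content.replace("\n\n", "\n> "), True
        elif "\n" in delta_content:
            return delta_content.replace("\n", "\n> "), True
    return delta_content, is_reasoning_started


deltas = ["<think>Let me", " check.\n", "Looks fine.</think>", "The answer is 42."]
state = False
pieces = []
for d in deltas:
    piece, state = rewrite_think_delta(d, state)
    pieces.append(piece)
print("".join(pieces))
# > 💭 Let me check.
# > Looks fine.
#
# The answer is 42.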
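
And a small sketch of why the delimiter is passed through codecs.decode(..., "unicode_escape"): the stream_mode_delimiter credential may hold the literal characters backslash-n rather than real newlines. The variable names and the sample SSE buffer here are assumptions for illustration.

import codecs

raw = r"\n\n"                                  # as typed into the credentials form
delimiter = codecs.decode(raw, "unicode_escape")
print(repr(raw), "->", repr(delimiter))        # '\\n\\n' -> '\n\n'

buffer = 'data: {"answer": 42}\n\ndata: [DONE]\n\n'
print(buffer.split(delimiter))                 # ['data: {"answer": 42}', 'data: [DONE]', '']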