瀏覽代碼

fix: image token calc of OpenAI Compatible API (#3368)

takatost 1 年之前
父節點
當前提交
9a1ea9ac03
共有 1 個文件被更改,包括 12 次插入3 次删除
  1. 12 3
      api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

+ 12 - 3
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -416,7 +416,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 if chunk.startswith(':'):
                     continue
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
-                chunk_json = None
+
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
@@ -620,7 +620,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return message_dict
 
-    def _num_tokens_from_string(self, model: str, text: str,
+    def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]],
                                 tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
         Approximate num tokens for model with gpt2 tokenizer.
@@ -630,7 +630,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return: number of tokens
         """
-        num_tokens = self._get_num_tokens_by_gpt2(text)
+        if isinstance(text, str):
+            full_text = text
+        else:
+            full_text = ''
+            for message_content in text:
+                if message_content.type == PromptMessageContentType.TEXT:
+                    message_content = cast(PromptMessageContent, message_content)
+                    full_text += message_content.data
+
+        num_tokens = self._get_num_tokens_by_gpt2(full_text)
 
         if tools:
             num_tokens += self._num_tokens_for_tools(tools)