|
@@ -416,7 +416,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
if chunk.startswith(':'):
|
|
|
continue
|
|
|
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
|
|
|
- chunk_json = None
|
|
|
+
|
|
|
try:
|
|
|
chunk_json = json.loads(decoded_chunk)
|
|
|
# stream ended
|
|
@@ -620,7 +620,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
|
|
|
return message_dict
|
|
|
|
|
|
- def _num_tokens_from_string(self, model: str, text: str,
|
|
|
+ def _num_tokens_from_string(self, model: str, text: Union[str, list[PromptMessageContent]],
|
|
|
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
|
|
"""
|
|
|
Approximate num tokens for model with gpt2 tokenizer.
|
|
@@ -630,7 +630,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
|
|
:param tools: tools for tool calling
|
|
|
:return: number of tokens
|
|
|
"""
|
|
|
- num_tokens = self._get_num_tokens_by_gpt2(text)
|
|
|
+ if isinstance(text, str):
|
|
|
+ full_text = text
|
|
|
+ else:
|
|
|
+ full_text = ''
|
|
|
+ for message_content in text:
|
|
|
+ if message_content.type == PromptMessageContentType.TEXT:
|
|
|
+ message_content = cast(PromptMessageContent, message_content)
|
|
|
+ full_text += message_content.data
|
|
|
+
|
|
|
+ num_tokens = self._get_num_tokens_by_gpt2(full_text)
|
|
|
|
|
|
if tools:
|
|
|
num_tokens += self._num_tokens_for_tools(tools)
|