
feat: support configurable OpenAI-compatible stream tool call (#3467)

Yeuoly · 1 year ago · commit 8f8e9de601

+ 71 - 55
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -170,13 +170,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         features = []
 
         function_calling_type = credentials.get('function_calling_type', 'no_call')
-        if function_calling_type == 'function_call':
+        if function_calling_type in ['function_call']:
             features.append(ModelFeature.TOOL_CALL)
-        endpoint_url = credentials["endpoint_url"]
-        # if not endpoint_url.endswith('/'):
-        #     endpoint_url += '/'
-        # if 'https://api.openai.com/v1/' == endpoint_url:
-        #     features.append(ModelFeature.STREAM_TOOL_CALL)
+        elif function_calling_type in ['tool_call']:
+            features.append(ModelFeature.MULTI_TOOL_CALL)
+
+        stream_function_calling = credentials.get('stream_function_calling', 'supported')
+        if stream_function_calling == 'supported':
+            features.append(ModelFeature.STREAM_TOOL_CALL)
 
         vision_support = credentials.get('vision_support', 'not_support')
         if vision_support == 'support':
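
This hunk replaces the old hard-coded endpoint check with two independent credentials: `function_calling_type` now also accepts `tool_call` (mapped to `MULTI_TOOL_CALL`), and a new `stream_function_calling` credential gates `STREAM_TOOL_CALL`. A minimal sketch of that mapping, using a hypothetical stand-in for Dify's `ModelFeature` enum; note the in-code fallback here is `'supported'`, while the YAML form below defaults to `not_supported`:

```python
# Hypothetical stand-in for Dify's ModelFeature enum, for illustration only.
from enum import Enum

class ModelFeature(Enum):
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"
    STREAM_TOOL_CALL = "stream-tool-call"

def detect_features(credentials: dict) -> list[ModelFeature]:
    """Mirror of the hunk above: map credential values to feature flags."""
    features = []

    # 'function_call' -> legacy single function calling; 'tool_call' -> parallel tool calls.
    function_calling_type = credentials.get("function_calling_type", "no_call")
    if function_calling_type == "function_call":
        features.append(ModelFeature.TOOL_CALL)
    elif function_calling_type == "tool_call":
        features.append(ModelFeature.MULTI_TOOL_CALL)

    # Streaming tool calls are now opt-in via a credential
    # instead of the old commented-out endpoint check.
    if credentials.get("stream_function_calling", "supported") == "supported":
        features.append(ModelFeature.STREAM_TOOL_CALL)

    return features

assert detect_features({"function_calling_type": "tool_call"}) == [
    ModelFeature.MULTI_TOOL_CALL,
    ModelFeature.STREAM_TOOL_CALL,
]
```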
@@ -386,29 +387,37 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
             def get_tool_call(tool_call_id: str):
-                tool_call = next(
-                    (tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None
-                )
+                if not tool_call_id:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
                 if tool_call is None:
                     tool_call = AssistantPromptMessage.ToolCall(
-                        id='', 
-                        type='function', 
+                        id=tool_call_id,
+                        type="function",
                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name='',
-                            arguments=''
+                            name="",
+                            arguments=""
                         )
                     )
                     tools_calls.append(tool_call)
+
                 return tool_call
 
             for new_tool_call in new_tool_calls:
                 # get tool call
-                tool_call = get_tool_call(new_tool_call.id)
+                tool_call = get_tool_call(new_tool_call.function.name)
                 # update tool call
-                tool_call.id = new_tool_call.id
-                tool_call.type = new_tool_call.type
-                tool_call.function.name = new_tool_call.function.name
-                tool_call.function.arguments += new_tool_call.function.arguments
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
+
+        finish_reason = 'Unknown'
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
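
OpenAI-compatible streams deliver each tool call as a series of fragments: typically the first delta carries the `id` and function `name`, and later deltas only append to `arguments`. The rewritten `increase_tool_call` accumulates those fragments, and notably keys the lookup on `function.name` (falling back to the last entry when the key is empty), since loosely compatible providers may omit `id` on continuation deltas. A self-contained sketch of that accumulation, with a simplified `ToolCall` dataclass standing in for `AssistantPromptMessage.ToolCall`:

```python
from dataclasses import dataclass, field

@dataclass
class ToolCallFunction:
    name: str = ""
    arguments: str = ""

@dataclass
class ToolCall:
    id: str = ""
    type: str = "function"
    function: ToolCallFunction = field(default_factory=ToolCallFunction)

def merge_deltas(deltas: list[ToolCall]) -> list[ToolCall]:
    """Accumulate streamed tool-call fragments, keyed by function name."""
    calls: list[ToolCall] = []

    def get_call(key: str) -> ToolCall:
        if not key:                  # continuation delta with no name: extend the last call
            return calls[-1]
        for call in calls:
            if call.id == key:
                return call
        call = ToolCall(id=key)      # as in the hunk, the key seeds the id
        calls.append(call)
        return call

    for delta in deltas:
        call = get_call(delta.function.name)
        if delta.id:                 # only overwrite fields actually present in the delta
            call.id = delta.id
        if delta.function.name:
            call.function.name = delta.function.name
        if delta.function.arguments: # arguments arrive piecemeal and are concatenated
            call.function.arguments += delta.function.arguments
    return calls

# Fragments as a provider might stream them:
fragments = [
    ToolCall(id="call_1", function=ToolCallFunction(name="get_weather", arguments='{"ci')),
    ToolCall(function=ToolCallFunction(arguments='ty": "Paris"}')),
]
merged = merge_deltas(fragments)
assert merged[0].function.arguments == '{"city": "Paris"}'
```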
@@ -438,7 +447,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     delta = choice['delta']
                     delta_content = delta.get('content')
 
-                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    assistant_message_tool_calls = None
+
+                    if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call':
+                        assistant_message_tool_calls = delta.get('tool_calls', None)
+                    elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call':
+                        assistant_message_tool_calls = [{
+                            'id': 'tool_call_id',
+                            'type': 'function',
+                            'function': delta.get('function_call', {})
+                        }]
+
                     # assistant_message_function_call = delta.delta.function_call
 
                     # extract tool calls from response
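
The delta handling now branches on the configured calling style: native `tool_calls` deltas pass through, while legacy `function_call` deltas are wrapped into a single-element tool-call list (with the placeholder id `'tool_call_id'`) so the rest of the pipeline only sees one shape. A rough sketch of that normalization, assuming plain dict deltas as decoded from the SSE payload:

```python
def normalize_tool_calls(delta: dict, function_calling_type: str) -> list[dict] | None:
    """Coerce both streaming styles into the tool_calls shape used downstream."""
    if "tool_calls" in delta and function_calling_type == "tool_call":
        return delta.get("tool_calls")
    if "function_call" in delta and function_calling_type == "function_call":
        # Legacy single-function style: wrap in a one-element tool_calls list.
        return [{
            "id": "tool_call_id",  # placeholder id, as in the hunk above
            "type": "function",
            "function": delta.get("function_call", {}),
        }]
    return None

# Legacy-style delta as an OpenAI-compatible server might stream it:
delta = {"function_call": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}
assert normalize_tool_calls(delta, "function_call")[0]["type"] == "function"
```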
@@ -449,15 +468,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     if delta_content is None or delta_content == '':
                         continue
 
-                    # function_call = self._extract_response_function_call(assistant_message_function_call)
-                    # tool_calls = [function_call] if function_call else []
-
                     # transform assistant message to prompt message
                     assistant_prompt_message = AssistantPromptMessage(
                         content=delta_content,
-                        tool_calls=tool_calls if assistant_message_tool_calls else []
                     )
 
+                    # reset tool calls
+                    tool_calls = []
                     full_assistant_content += delta_content
                 elif 'text' in choice:
                     choice_text = choice.get('text', '')
@@ -470,37 +487,36 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 else:
                     continue
 
-                # check payload indicator for completion
-                if finish_reason is not None:
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=chunk_index,
-                            message=AssistantPromptMessage(
-                                tool_calls=tools_calls,
-                            ),
-                            finish_reason=finish_reason
-                        )
-                    )
-
-                    yield create_final_llm_result_chunk(
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
                         index=chunk_index,
                         message=assistant_prompt_message,
-                        finish_reason=finish_reason
-                    )
-                else:
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=chunk_index,
-                            message=assistant_prompt_message,
-                        )
                     )
+                )
 
             chunk_index += 1
 
+        if tools_calls:
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(
+                        tool_calls=tools_calls,
+                        content=""
+                    ),
+                )
+            )
+
+        yield create_final_llm_result_chunk(
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason
+        )
+
     def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
 
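The restructured loop stops interleaving tool calls with content chunks: content deltas are yielded as they arrive, the accumulated `tools_calls` list is flushed as one dedicated chunk after the stream ends, and a final chunk carries `finish_reason` (initialized to `'Unknown'` before the loop). A schematic sketch of that emission order, with plain dicts standing in for `LLMResultChunk`:

```python
from collections.abc import Iterator

def emit_chunks(content_deltas: list[str], tools_calls: list, finish_reason: str) -> Iterator[dict]:
    """Emission order after this commit: content first, tool calls once, finish last."""
    index = 0
    for delta in content_deltas:
        yield {"index": index, "content": delta}           # one chunk per content delta
        index += 1
    if tools_calls:
        yield {"index": index, "tool_calls": tools_calls}  # single flush of merged tool calls
    yield {"index": index, "finish_reason": finish_reason} # terminal chunk

chunks = list(emit_chunks(["Hel", "lo"], [{"name": "get_weather"}], "tool_calls"))
assert chunks[-1]["finish_reason"] == "tool_calls"
```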
@@ -757,13 +773,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         if response_tool_calls:
             for response_tool_call in response_tool_calls:
                 function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                    name=response_tool_call["function"]["name"],
-                    arguments=response_tool_call["function"]["arguments"]
+                    name=response_tool_call.get("function", {}).get("name", ""),
+                    arguments=response_tool_call.get("function", {}).get("arguments", "")
                 )
 
                 tool_call = AssistantPromptMessage.ToolCall(
-                    id=response_tool_call["id"],
-                    type=response_tool_call["type"],
+                    id=response_tool_call.get("id", ""),
+                    type=response_tool_call.get("type", ""),
                     function=function
                 )
                 tool_calls.append(tool_call)
@@ -781,12 +797,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         tool_call = None
         if response_function_call:
             function = AssistantPromptMessage.ToolCall.ToolCallFunction(
-                name=response_function_call['name'],
-                arguments=response_function_call['arguments']
+                name=response_function_call.get('name', ''),
+                arguments=response_function_call.get('arguments', '')
             )
 
             tool_call = AssistantPromptMessage.ToolCall(
-                id=response_function_call['name'],
+                id=response_function_call.get('id', ''),
                 type="function",
                 function=function
             )
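
Both extraction helpers now tolerate providers that omit `id`, `type`, or even the `function` object on tool-call payloads, defaulting missing fields to empty strings instead of raising `KeyError`. A condensed sketch of the defensive pattern:

```python
def extract_tool_call(response_tool_call: dict) -> dict:
    """Tolerate partial tool-call objects from loosely compatible providers."""
    function = response_tool_call.get("function", {})
    return {
        "id": response_tool_call.get("id", ""),    # some providers omit id entirely
        "type": response_tool_call.get("type", ""),
        "name": function.get("name", ""),
        "arguments": function.get("arguments", ""),
    }

assert extract_tool_call({"function": {"name": "f"}})["id"] == ""
```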

+ 23 - 5
api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml

@@ -87,13 +87,31 @@ model_credential_schema:
       options:
         - value: function_call
           label:
+            en_US: Function Call
+            zh_Hans: Function Call
+        - value: tool_call
+          label:
+            en_US: Tool Call
+            zh_Hans: Tool Call
+        - value: no_call
+          label:
+            en_US: Not Support
+            zh_Hans: 不支持
+    - variable: stream_function_calling
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Stream function calling
+      type: select
+      required: false
+      default: not_supported
+      options:
+        - value: supported
+          label:
             en_US: Support
             zh_Hans: 支持
-#        - value: tool_call
-#          label:
-#            en_US: Tool Call
-#            zh_Hans: Tool Call
-        - value: no_call
+        - value: not_supported
           label:
             en_US: Not Support
             zh_Hans: 不支持
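
Taken together, the YAML change surfaces two independent selects on the credential form: the previously commented-out `tool_call` option is enabled, and `stream_function_calling` is added with a conservative `not_supported` default. A hypothetical credentials dict, as the form above might produce it for a provider that supports parallel tool calls but cannot stream them:

```python
# Hypothetical credentials for illustration; endpoint_url is a placeholder.
credentials = {
    "endpoint_url": "https://example.com/v1",
    "function_calling_type": "tool_call",        # function_call | tool_call | no_call
    "stream_function_calling": "not_supported",  # supported | not_supported (form default)
}
```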