Browse Source

feat: add claude3 function calling (#5889)

longzhihun 9 months ago
parent
commit
aecdfa2d5c

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-opus-v1.yaml

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.5.yaml

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml

@@ -5,6 +5,8 @@ model_type: llm
 features:
   - agent-thought
   - vision
+  - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

+ 106 - 39
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -68,7 +69,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # TODO: consolidate different invocation methods for models based on base model capabilities
         # invoke anthropic models via boto3 client
         if "anthropic" in model:
-            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke Cohere models via boto3 client
         if "cohere.command-r" in model:
             return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
@@ -151,7 +152,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
 
     def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
+                stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
         """
         Invoke Anthropic large language model
 
@@ -171,23 +172,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
         inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
 
+        parameters = {
+            'modelId': model,
+            'messages': prompt_message_dicts,
+            'inferenceConfig': inference_config,
+            'additionalModelRequestFields': additional_model_fields,
+        }
+
+        if system and len(system) > 0:
+            parameters['system'] = system
+
+        if tools:
+            parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)
+
         if stream:
-            response = bedrock_client.converse_stream(
-                modelId=model,
-                messages=prompt_message_dicts,
-                system=system,
-                inferenceConfig=inference_config,
-                additionalModelRequestFields=additional_model_fields
-            )
+            response = bedrock_client.converse_stream(**parameters)
             return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
         else:
-            response = bedrock_client.converse(
-                modelId=model,
-                messages=prompt_message_dicts,
-                system=system,
-                inferenceConfig=inference_config,
-                additionalModelRequestFields=additional_model_fields
-            )
+            response = bedrock_client.converse(**parameters)
             return self._handle_converse_response(model, credentials, response, prompt_messages)
 
     def _handle_converse_response(self, model: str, credentials: dict, response: dict,
@@ -246,12 +248,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             output_tokens = 0
             finish_reason = None
             index = 0
+            tool_calls: list[AssistantPromptMessage.ToolCall] = []
+            tool_use = {}
 
             for chunk in response['stream']:
                 if 'messageStart' in chunk:
                     return_model = model
                 elif 'messageStop' in chunk:
                     finish_reason = chunk['messageStop']['stopReason']
+                elif 'contentBlockStart' in chunk:
+                    tool = chunk['contentBlockStart']['start']['toolUse']
+                    tool_use['toolUseId'] = tool['toolUseId']
+                    tool_use['name'] = tool['name']
                 elif 'metadata' in chunk:
                     input_tokens = chunk['metadata']['usage']['inputTokens']
                     output_tokens = chunk['metadata']['usage']['outputTokens']
@@ -260,29 +268,49 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                         model=return_model,
                         prompt_messages=prompt_messages,
                         delta=LLMResultChunkDelta(
-                            index=index + 1,
+                            index=index,
                             message=AssistantPromptMessage(
-                                content=''
+                                content='',
+                                tool_calls=tool_calls
                             ),
                             finish_reason=finish_reason,
                             usage=usage
                         )
                     )
                 elif 'contentBlockDelta' in chunk:
-                    chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
-                    full_assistant_content += chunk_text
-                    assistant_prompt_message = AssistantPromptMessage(
-                        content=chunk_text if chunk_text else '',
-                    )
-                    index = chunk['contentBlockDelta']['contentBlockIndex']
-                    yield LLMResultChunk(
-                        model=model,
-                        prompt_messages=prompt_messages,
-                        delta=LLMResultChunkDelta(
-                            index=index,
-                            message=assistant_prompt_message,
+                    delta = chunk['contentBlockDelta']['delta']
+                    if 'text' in delta:
+                        chunk_text = delta['text'] if delta['text'] else ''
+                        full_assistant_content += chunk_text
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=chunk_text if chunk_text else '',
                         )
-                    )
+                        index = chunk['contentBlockDelta']['contentBlockIndex']
+                        yield LLMResultChunk(
+                            model=model,
+                            prompt_messages=prompt_messages,
+                            delta=LLMResultChunkDelta(
+                                index=index+1,
+                                message=assistant_prompt_message,
+                            )
+                        )
+                    elif 'toolUse' in delta:
+                        if 'input' not in tool_use:
+                            tool_use['input'] = ''
+                        tool_use['input'] += delta['toolUse']['input']
+                elif 'contentBlockStop' in chunk:
+                    if 'input' in tool_use:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=tool_use['toolUseId'],
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_use['name'],
+                                arguments=tool_use['input']
+                            )
+                        )
+                        tool_calls.append(tool_call)
+                        tool_use = {}
+
         except Exception as ex:
             raise InvokeError(str(ex))
     
@@ -312,16 +340,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """
 
         system = []
-        first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
                 message.content=message.content.strip()
-                if first_loop:
-                    system=message.content
-                    first_loop=False
-                else:
-                    system+="\n"
-                    system+=message.content
+                system.append({"text": message.content})
 
         prompt_message_dicts = []
         for message in prompt_messages:
@@ -330,6 +352,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         return system, prompt_message_dicts
 
+    def _convert_converse_tool_config(self, tools: Optional[list[PromptMessageTool]] = None) -> dict:
+        tool_config = {}
+        configs = []
+        if tools:
+            for tool in tools:
+                configs.append(
+                    {
+                        "toolSpec": {
+                            "name": tool.name,
+                            "description": tool.description,
+                            "inputSchema": {
+                                "json": tool.parameters
+                            }
+                        }
+                    }
+                )
+            tool_config["tools"] = configs
+            return tool_config
+    
     def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict
@@ -379,10 +420,32 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": [{'text': message.content}]}
+            if message.tool_calls:
+                message_dict = {
+                    "role": "assistant", "content":[{
+                        "toolUse": {
+                            "toolUseId": message.tool_calls[0].id,
+                            "name": message.tool_calls[0].function.name,
+                            "input": json.loads(message.tool_calls[0].function.arguments)
+                        }
+                    }]
+                }
+            else:
+                message_dict = {"role": "assistant", "content": [{'text': message.content}]}
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = [{'text': message.content}]
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {
+                "role": "user",
+                "content": [{
+                    "toolResult": {
+                        "toolUseId": message.tool_call_id,
+                        "content": [{"json": {"text": message.content}}]
+                    }                   
+                }]
+            }
         else:
             raise ValueError(f"Got unknown type {message}")
 
@@ -401,11 +464,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """
         prefix = model.split('.')[0]
         model_name = model.split('.')[1]
+        
         if isinstance(prompt_messages, str):
             prompt = prompt_messages
         else:
             prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
 
+
         return self._get_num_tokens_by_gpt2(prompt)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -494,6 +559,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             message_text = f"{ai_prompt} {content}"
         elif isinstance(message, SystemPromptMessage):
             message_text = content
+        elif isinstance(message, ToolPromptMessage):
+            message_text = f"{human_prompt_prefix} {message.content}"
         else:
             raise ValueError(f"Got unknown type {message}")