
[seanguo] modify bedrock Claude3 invoke method to converse API (#5768)

Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
longzhihun, 9 months ago
commit fdfbbde10d
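
This commit replaces the AnthropicBedrock SDK client with Amazon Bedrock's Converse API, called through boto3's bedrock-runtime client. A minimal sketch of the new request/response shape (region, model ID, prompt, and parameter values below are placeholders for illustration, not taken from the commit):

    import boto3

    # hypothetical values; the real code derives these from credentials and prompt messages
    client = boto3.client("bedrock-runtime", region_name="us-east-1")
    response = client.converse(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        system=[{"text": "You are a helpful assistant."}],
        messages=[{"role": "user", "content": [{"text": "Hello"}]}],
        inferenceConfig={"maxTokens": 256, "temperature": 0.5},
    )
    text = response["output"]["message"]["content"][0]["text"]
    usage = response["usage"]  # contains inputTokens / outputTokens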

+ 82 - 132
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -1,22 +1,14 @@
+# standard import
 import base64
 import json
 import logging
 import mimetypes
-import time
 from collections.abc import Generator
 from typing import Optional, Union, cast
 
+# 3rd import
 import boto3
 import requests
-from anthropic import AnthropicBedrock, Stream
-from anthropic.types import (
-    ContentBlockDeltaEvent,
-    Message,
-    MessageDeltaEvent,
-    MessageStartEvent,
-    MessageStopEvent,
-    MessageStreamEvent,
-)
 from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
@@ -27,7 +19,8 @@ from botocore.exceptions import (
 )
 from cohere import ChatMessage
 
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+# local import
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -38,7 +31,6 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -73,8 +65,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-
-        # invoke anthropic models via anthropic official SDK
+        # TODO: consolidate different invocation methods for models based on base model capabilities
+        # invoke anthropic models via boto3 client
         if "anthropic" in model:
             return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
         # invoke Cohere models via boto3 client
@@ -171,48 +163,34 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
-        # use Anthropic official SDK references
-        # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
-        # - https://github.com/anthropics/anthropic-sdk-python
-        client = AnthropicBedrock(
-            aws_access_key=credentials.get("aws_access_key_id"),
-            aws_secret_key=credentials.get("aws_secret_access_key"),
-            aws_region=credentials["aws_region"],
-        )
+        bedrock_client = boto3.client(service_name='bedrock-runtime',
+                                      aws_access_key_id=credentials.get("aws_access_key_id"),
+                                      aws_secret_access_key=credentials.get("aws_secret_access_key"),
+                                      region_name=credentials["aws_region"])
 
-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop
-
-        # Notice: If you request the current version of the SDK to the bedrock server,
-        #         you will get the following error message and you need to wait for the service or SDK to be updated.
-        #         Response:  Error code: 400
-        #                    {'message': 'Malformed input request: #: subject must not be valid against schema
-        #                        {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
-        # TODO: Open in the future when the interface is properly supported
-        # if user:
-            # ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
-            # extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)
-
-        system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)
-
-        if system:
-            extra_model_kwargs['system'] = system
-
-        response = client.messages.create(
-            model=model,
-            messages=prompt_message_dicts,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
+        system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
+        inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)
 
         if stream:
-            return self._handle_claude_stream_response(model, credentials, response, prompt_messages)
-
-        return self._handle_claude_response(model, credentials, response, prompt_messages)
+            response = bedrock_client.converse_stream(
+                modelId=model,
+                messages=prompt_message_dicts,
+                system=system,
+                inferenceConfig=inference_config,
+                additionalModelRequestFields=additional_model_fields
+            )
+            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = bedrock_client.converse(
+                modelId=model,
+                messages=prompt_message_dicts,
+                system=system,
+                inferenceConfig=inference_config,
+                additionalModelRequestFields=additional_model_fields
+            )
+            return self._handle_converse_response(model, credentials, response, prompt_messages)
 
-    def _handle_claude_response(self, model: str, credentials: dict, response: Message,
+    def _handle_converse_response(self, model: str, credentials: dict, response: dict,
                                 prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm chat response
@@ -223,17 +201,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: full response chunk generator result
         """
-
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.content[0].text
+            content=response['output']['message']['content'][0]['text']
         )
 
         # calculate num tokens
-        if response.usage:
+        if response['usage']:
             # transform usage
-            prompt_tokens = response.usage.input_tokens
-            completion_tokens = response.usage.output_tokens
+            prompt_tokens = response['usage']['inputTokens']
+            completion_tokens = response['usage']['outputTokens']
         else:
             # calculate num tokens
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -242,17 +219,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-        # transform response
-        response = LLMResult(
-            model=response.model,
+        result = LLMResult(
+            model=model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage
+            usage=usage,
         )
+        return result
 
-        return response
-
-    def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
+    def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
                                         prompt_messages: list[PromptMessage], ) -> Generator:
         """
         Handle llm chat stream response
@@ -272,14 +247,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             finish_reason = None
             index = 0
 
-            for chunk in response:
-                if isinstance(chunk, MessageStartEvent):
-                    return_model = chunk.message.model
-                    input_tokens = chunk.message.usage.input_tokens
-                elif isinstance(chunk, MessageDeltaEvent):
-                    output_tokens = chunk.usage.output_tokens
-                    finish_reason = chunk.delta.stop_reason
-                elif isinstance(chunk, MessageStopEvent):
+            for chunk in response['stream']:
+                if 'messageStart' in chunk:
+                    return_model = model
+                elif 'messageStop' in chunk:
+                    finish_reason = chunk['messageStop']['stopReason']
+                elif 'metadata' in chunk:
+                    input_tokens = chunk['metadata']['usage']['inputTokens']
+                    output_tokens = chunk['metadata']['usage']['outputTokens']
                     usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
                     yield LLMResultChunk(
                         model=return_model,
@@ -293,13 +268,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                             usage=usage
                         )
                     )
-                elif isinstance(chunk, ContentBlockDeltaEvent):
-                    chunk_text = chunk.delta.text if chunk.delta.text else ''
+                elif 'contentBlockDelta' in chunk:
+                    chunk_text = chunk['contentBlockDelta']['delta']['text'] if chunk['contentBlockDelta']['delta']['text'] else ''
                     full_assistant_content += chunk_text
                     assistant_prompt_message = AssistantPromptMessage(
                         content=chunk_text if chunk_text else '',
                     )
-                    index = chunk.index
+                    index = chunk['contentBlockDelta']['contentBlockIndex']
                     yield LLMResultChunk(
                         model=model,
                         prompt_messages=prompt_messages,
@@ -310,57 +285,33 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                     )
         except Exception as ex:
             raise InvokeError(str(ex))
+    
+    def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]:
+        inference_config = {}
+        additional_model_fields = {}
+        if 'max_tokens' in model_parameters:
+            inference_config['maxTokens'] = model_parameters['max_tokens']
+
+        if 'temperature' in model_parameters:
+            inference_config['temperature'] = model_parameters['temperature']
+        
+        if 'top_p' in model_parameters:
+            inference_config['topP'] = model_parameters['top_p']
 
-    def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
-        """
-        Calculate response usage
-
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_tokens: prompt tokens
-        :param completion_tokens: completion tokens
-        :return: usage
-        """
-        # get prompt price info
-        prompt_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.INPUT,
-            tokens=prompt_tokens,
-        )
-
-        # get completion price info
-        completion_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.OUTPUT,
-            tokens=completion_tokens
-        )
-
-        # transform usage
-        usage = LLMUsage(
-            prompt_tokens=prompt_tokens,
-            prompt_unit_price=prompt_price_info.unit_price,
-            prompt_price_unit=prompt_price_info.unit,
-            prompt_price=prompt_price_info.total_amount,
-            completion_tokens=completion_tokens,
-            completion_unit_price=completion_price_info.unit_price,
-            completion_price_unit=completion_price_info.unit,
-            completion_price=completion_price_info.total_amount,
-            total_tokens=prompt_tokens + completion_tokens,
-            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
-            currency=prompt_price_info.currency,
-            latency=time.perf_counter() - self.started_at
-        )
-
-        return usage
+        if stop:
+            inference_config['stopSequences'] = stop
+        
+        if 'top_k' in model_parameters:
+            additional_model_fields['top_k'] = model_parameters['top_k']
+        
+        return inference_config, additional_model_fields
 
-    def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+    def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
         """
         Convert prompt messages to dict list and system
         """
 
-        system = ""
+        system = []
         first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
@@ -375,25 +326,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         prompt_message_dicts = []
         for message in prompt_messages:
             if not isinstance(message, SystemPromptMessage):
-                prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))
+                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
 
         return system, prompt_message_dicts
 
-    def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "user", "content": message.content}
+                message_dict = {"role": "user", "content": [{'text': message.content}]}
             else:
                 sub_messages = []
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
                         sub_message_dict = {
-                            "type": "text",
                             "text": message_content.data
                         }
                         sub_messages.append(sub_message_dict)
@@ -404,24 +354,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                             try:
                                 image_content = requests.get(message_content.data).content
                                 mime_type, _ = mimetypes.guess_type(message_content.data)
-                                base64_data = base64.b64encode(image_content).decode('utf-8')
                             except Exception as ex:
                                 raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                         else:
                             data_split = message_content.data.split(";base64,")
                             mime_type = data_split[0].replace("data:", "")
                             base64_data = data_split[1]
+                            image_content = base64.b64decode(base64_data)
 
                         if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
                             raise ValueError(f"Unsupported image type {mime_type}, "
                                              f"only support image/jpeg, image/png, image/gif, and image/webp")
 
                         sub_message_dict = {
-                            "type": "image",
-                            "source": {
-                                "type": "base64",
-                                "media_type": mime_type,
-                                "data": base64_data
+                            "image": {
+                                "format": mime_type.replace('image/', ''),
+                                "source": {
+                                    "bytes": image_content
+                                }
                             }
                         }
                         sub_messages.append(sub_message_dict)
@@ -429,10 +379,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": message.content}
+            message_dict = {"role": "assistant", "content": [{'text': message.content}]}
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "system", "content": message.content}
+            message_dict = [{'text': message.content}]
         else:
             raise ValueError(f"Got unknown type {message}")
 

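The streaming path works the same way but consumes converse_stream's plain dict events instead of typed SDK event objects. A standalone sketch of how such a stream is walked (model ID and prompt are placeholders; the event names mirror the handler in the hunks above):

    import boto3

    client = boto3.client("bedrock-runtime", region_name="us-east-1")
    response = client.converse_stream(
        modelId="anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": [{"text": "Hello"}]}],
        inferenceConfig={"maxTokens": 256},
    )
    for event in response["stream"]:
        if "contentBlockDelta" in event:
            # incremental text for the current content block
            print(event["contentBlockDelta"]["delta"].get("text", ""), end="")
        elif "messageStop" in event:
            stop_reason = event["messageStop"]["stopReason"]
        elif "metadata" in event:
            # token usage arrives in the trailing metadata event
            usage = event["metadata"]["usage"]  # inputTokens / outputTokens
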
+ 19 - 19
api/poetry.lock

@@ -534,41 +534,41 @@ files = [
 
 [[package]]
 name = "boto3"
-version = "1.28.17"
+version = "1.34.136"
 description = "The AWS SDK for Python"
 optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "boto3-1.28.17-py3-none-any.whl", hash = "sha256:bca0526f819e0f19c0f1e6eba3e2d1d6b6a92a45129f98c0d716e5aab6d9444b"},
-    {file = "boto3-1.28.17.tar.gz", hash = "sha256:90f7cfb5e1821af95b1fc084bc50e6c47fa3edc99f32de1a2591faa0c546bea7"},
+    {file = "boto3-1.34.136-py3-none-any.whl", hash = "sha256:d41037e2c680ab8d6c61a0a4ee6bf1fdd9e857f43996672830a95d62d6f6fa79"},
+    {file = "boto3-1.34.136.tar.gz", hash = "sha256:0314e6598f59ee0f34eb4e6d1a0f69fa65c146d2b88a6e837a527a9956ec2731"},
 ]
 
 [package.dependencies]
-botocore = ">=1.31.17,<1.32.0"
+botocore = ">=1.34.136,<1.35.0"
 jmespath = ">=0.7.1,<2.0.0"
-s3transfer = ">=0.6.0,<0.7.0"
+s3transfer = ">=0.10.0,<0.11.0"
 
 [package.extras]
 crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
 
 [[package]]
 name = "botocore"
-version = "1.31.85"
+version = "1.34.136"
 description = "Low-level, data-driven core of boto 3."
 optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "botocore-1.31.85-py3-none-any.whl", hash = "sha256:b8f35d65f2b45af50c36fc25cc1844d6bd61d38d2148b2ef133b8f10e198555d"},
-    {file = "botocore-1.31.85.tar.gz", hash = "sha256:ce58e688222df73ec5691f934be1a2122a52c9d11d3037b586b3fff16ed6d25f"},
+    {file = "botocore-1.34.136-py3-none-any.whl", hash = "sha256:c63fe9032091fb9e9477706a3ebfa4d0c109b807907051d892ed574f9b573e61"},
+    {file = "botocore-1.34.136.tar.gz", hash = "sha256:7f7135178692b39143c8f152a618d2a3b71065a317569a7102d2306d4946f42f"},
 ]
 
 [package.dependencies]
 jmespath = ">=0.7.1,<2.0.0"
 python-dateutil = ">=2.1,<3.0.0"
-urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}
+urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}
 
 [package.extras]
-crt = ["awscrt (==0.19.12)"]
+crt = ["awscrt (==0.20.11)"]
 
 [[package]]
 name = "bottleneck"
@@ -7032,20 +7032,20 @@ files = [
 
 [[package]]
 name = "s3transfer"
-version = "0.6.2"
+version = "0.10.2"
 description = "An Amazon S3 Transfer Manager"
 optional = false
-python-versions = ">= 3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "s3transfer-0.6.2-py3-none-any.whl", hash = "sha256:b014be3a8a2aab98cfe1abc7229cc5a9a0cf05eb9c1f2b86b230fd8df3f78084"},
-    {file = "s3transfer-0.6.2.tar.gz", hash = "sha256:cab66d3380cca3e70939ef2255d01cd8aece6a4907a9528740f668c4b0611861"},
+    {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"},
+    {file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"},
 ]
 
 [package.dependencies]
-botocore = ">=1.12.36,<2.0a.0"
+botocore = ">=1.33.2,<2.0a.0"
 
 [package.extras]
-crt = ["botocore[crt] (>=1.20.29,<2.0a.0)"]
+crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
 
 [[package]]
 name = "safetensors"
@@ -9095,4 +9095,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "d40bed69caecf3a2bcd5ec054288d7cb36a9a231fff210d4f1a42745dd3bf604"
+content-hash = "90f0e77567fbe5100d15bf2bc9472007aafc53c2fd594b6a90dd8455dea58582"

+ 1 - 1
api/pyproject.toml

@@ -107,7 +107,7 @@ authlib = "1.3.1"
 azure-identity = "1.16.1"
 azure-storage-blob = "12.13.0"
 beautifulsoup4 = "4.12.2"
-boto3 = "1.28.17"
+boto3 = "1.34.136"
 bs4 = "~0.0.1"
 cachetools = "~5.3.0"
 celery = "~5.3.6"
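
The dependency bumps accompany the code change: the converse and converse_stream operations are only available in newer boto3/botocore releases, so the previous 1.28.x pin would not expose them. A quick sanity check of an installed environment (a sketch; no Bedrock call is made):

    import boto3

    print(boto3.__version__)  # 1.34.136 per the updated lock file
    client = boto3.client("bedrock-runtime", region_name="us-east-1")
    # on releases that predate the Converse API these client methods do not exist
    assert hasattr(client, "converse") and hasattr(client, "converse_stream")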