
feat: AWS Bedrock Claude3 (#2864)

Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
Su Yang committed 1 year ago · commit 45e51e7730

+ 2 - 0
api/core/model_runtime/model_providers/bedrock/llm/_position.yaml

@@ -4,6 +4,8 @@
 - anthropic.claude-v1
 - anthropic.claude-v2
 - anthropic.claude-v2:1
+- anthropic.claude-3-sonnet-20240229-v1:0
+- anthropic.claude-3-haiku-20240307-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
 - meta.llama2-13b-chat-v1

+ 57 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-haiku-v1.yaml

@@ -0,0 +1,57 @@
+model: anthropic.claude-3-haiku-20240307-v1:0
+label:
+  en_US: Claude 3 Haiku
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 4096
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  # docs: https://docs.anthropic.com/claude/docs/system-prompts
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # note: the tip in the AWS docs is wrong; the actual maximum is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
+pricing:
+  input: '0.00025'
+  output: '0.00125'
+  unit: '0.001'
+  currency: USD
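
For context, the parameter_rules above describe the keyword arguments that llm.py (later in this commit) passes straight through to client.messages.create(**model_parameters). A minimal sketch of a conforming parameter dict, with values inside the declared bounds (the concrete numbers are illustrative only):

# model_parameters as validated by the rules in this YAML; max_tokens is the only required key
model_parameters = {
    "max_tokens": 4096,    # required, 1..4096
    "temperature": 1.0,    # optional, 0.0..1.0
    "top_p": 0.999,        # optional, 0.0..1.0 -- alter temperature or top_p, not both
    "top_k": 250,          # optional, 0..500
}
# llm.py then calls: client.messages.create(model=..., messages=..., **model_parameters)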

+ 56 - 0
api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-sonnet-v1.yaml

@@ -0,0 +1,56 @@
+model: anthropic.claude-3-sonnet-20240229-v1:0
+label:
+  en_US: Claude 3 Sonnet
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 200000
+# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+parameter_rules:
+  - name: max_tokens
+    use_template: max_tokens
+    required: true
+    type: int
+    default: 4096
+    min: 1
+    max: 4096
+    help:
+      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
+      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
+  - name: temperature
+    use_template: temperature
+    required: false
+    type: float
+    default: 1
+    min: 0.0
+    max: 1.0
+    help:
+      zh_Hans: 生成内容的随机性。
+      en_US: The amount of randomness injected into the response.
+  - name: top_p
+    required: false
+    type: float
+    default: 0.999
+    min: 0.000
+    max: 1.000
+    help:
+      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
+      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
+  - name: top_k
+    required: false
+    type: int
+    default: 0
+    min: 0
+    # note: the tip in the AWS docs is wrong; the actual maximum is 500
+    max: 500
+    help:
+      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
+      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
+pricing:
+  input: '0.003'
+  output: '0.015'
+  unit: '0.001'
+  currency: USD
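
Taking unit: '0.001' to mean the input/output figures are priced per 1,000 tokens (an assumption about how the runtime combines these fields, not something stated in this diff), the two pricing blocks work out to the published per-million-token rates:

# assumed formula: cost = tokens * unit * price_per_unit  (unit = 0.001 -> price is per 1K tokens)
def cost(tokens: int, price_per_unit: float, unit: float = 0.001) -> float:
    return tokens * unit * price_per_unit

print(cost(1_000_000, 0.00025), cost(1_000_000, 0.00125))  # Haiku:  0.25  1.25  (USD per 1M tokens)
print(cost(1_000_000, 0.003), cost(1_000_000, 0.015))      # Sonnet: 3.0   15.0  (USD per 1M tokens)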

+ 291 - 3
api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -1,9 +1,22 @@
+import base64
 import json
 import logging
+import mimetypes
+import time
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 import boto3
+import requests
+from anthropic import AnthropicBedrock, Stream
+from anthropic.types import (
+    ContentBlockDeltaEvent,
+    Message,
+    MessageDeltaEvent,
+    MessageStartEvent,
+    MessageStopEvent,
+    MessageStreamEvent,
+)
 from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
@@ -13,14 +26,18 @@ from botocore.exceptions import (
     UnknownServiceError,
 )
 
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -54,9 +71,268 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+
+        # invoke Claude 3 models via the official Anthropic SDK
+        if "anthropic.claude-3" in model:
+            return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream)
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
 
+    def _invoke_claude3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                stop: Optional[list[str]] = None, stream: bool = True) -> Union[LLMResult, Generator]:
+        """
+        Invoke Claude3 large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :return: full response or stream response chunk generator result
+        """
+        # uses the official Anthropic SDK; references:
+        # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
+        # - https://github.com/anthropics/anthropic-sdk-python
+        client = AnthropicBedrock(
+            aws_access_key=credentials["aws_access_key_id"],
+            aws_secret_key=credentials["aws_secret_access_key"],
+            aws_region=credentials["aws_region"],
+        )
+
+        system, prompt_message_dicts = self._convert_claude3_prompt_messages(prompt_messages)
+
+        response = client.messages.create(
+            model=model,
+            messages=prompt_message_dicts,
+            stop_sequences=stop if stop else [],
+            system=system,
+            stream=stream,
+            **model_parameters,
+        )
+
+        if stream is False:
+            return self._handle_claude3_response(model, credentials, response, prompt_messages)
+        else:
+            return self._handle_claude3_stream_response(model, credentials, response, prompt_messages)
+
+    def _handle_claude3_response(self, model: str, credentials: dict, response: Message,
+                                prompt_messages: list[PromptMessage]) -> LLMResult:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: full response
+        """
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=response.content[0].text
+        )
+
+        # calculate num tokens
+        if response.usage:
+            # transform usage
+            prompt_tokens = response.usage.input_tokens
+            completion_tokens = response.usage.output_tokens
+        else:
+            # calculate num tokens
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        response = LLMResult(
+            model=response.model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage
+        )
+
+        return response
+
+    def _handle_claude3_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
+                                        prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm chat stream response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: full response or stream response chunk generator result
+        """
+
+        try:
+            full_assistant_content = ''
+            return_model = None
+            input_tokens = 0
+            output_tokens = 0
+            finish_reason = None
+            index = 0
+
+            for chunk in response:
+                if isinstance(chunk, MessageStartEvent):
+                    return_model = chunk.message.model
+                    input_tokens = chunk.message.usage.input_tokens
+                elif isinstance(chunk, MessageDeltaEvent):
+                    output_tokens = chunk.usage.output_tokens
+                    finish_reason = chunk.delta.stop_reason
+                elif isinstance(chunk, MessageStopEvent):
+                    usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
+                    yield LLMResultChunk(
+                        model=return_model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index + 1,
+                            message=AssistantPromptMessage(
+                                content=''
+                            ),
+                            finish_reason=finish_reason,
+                            usage=usage
+                        )
+                    )
+                elif isinstance(chunk, ContentBlockDeltaEvent):
+                    chunk_text = chunk.delta.text or ''
+                    full_assistant_content += chunk_text
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=chunk_text,
+                    )
+                    index = chunk.index
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                        )
+                    )
+        except Exception as ex:
+            raise InvokeError(str(ex))
+
+    def _calc_claude3_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_tokens: prompt tokens
+        :param completion_tokens: completion tokens
+        :return: usage
+        """
+        # get prompt price info
+        prompt_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=prompt_tokens,
+        )
+
+        # get completion price info
+        completion_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.OUTPUT,
+            tokens=completion_tokens
+        )
+
+        # transform usage
+        usage = LLMUsage(
+            prompt_tokens=prompt_tokens,
+            prompt_unit_price=prompt_price_info.unit_price,
+            prompt_price_unit=prompt_price_info.unit,
+            prompt_price=prompt_price_info.total_amount,
+            completion_tokens=completion_tokens,
+            completion_unit_price=completion_price_info.unit_price,
+            completion_price_unit=completion_price_info.unit,
+            completion_price=completion_price_info.total_amount,
+            total_tokens=prompt_tokens + completion_tokens,
+            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
+            currency=prompt_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    def _convert_claude3_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+        """
+        Convert prompt messages to a system prompt string and a list of message dicts
+        """
+        system = ""
+        prompt_message_dicts = []
+
+        for message in prompt_messages:
+            if isinstance(message, SystemPromptMessage):
+                # join multiple system prompts with newlines
+                system += ("\n" if system else "") + message.content
+            else:
+                prompt_message_dicts.append(self._convert_claude3_prompt_message_to_dict(message))
+
+        return system, prompt_message_dicts
+
+    def _convert_claude3_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        if not message_content.data.startswith("data:"):
+                            # fetch image data from url
+                            try:
+                                image_content = requests.get(message_content.data).content
+                                mime_type, _ = mimetypes.guess_type(message_content.data)
+                                base64_data = base64.b64encode(image_content).decode('utf-8')
+                            except Exception as ex:
+                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                        else:
+                            data_split = message_content.data.split(";base64,")
+                            mime_type = data_split[0].replace("data:", "")
+                            base64_data = data_split[1]
+
+                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                            raise ValueError(f"Unsupported image type {mime_type}, "
+                                             f"only support image/jpeg, image/png, image/gif, and image/webp")
+
+                        sub_message_dict = {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": mime_type,
+                                "data": base64_data
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
     def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
@@ -101,7 +377,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param credentials: model credentials
         :return:
         """
-        
+
+        if "anthropic.claude-3" in model:
+            try:
+                # claude 3 only supports the messages API; max_tokens is required by messages.create
+                self._invoke_claude3(model=model,
+                                     credentials=credentials,
+                                     prompt_messages=[UserPromptMessage(content="ping")],
+                                     model_parameters={"max_tokens": 32},
+                                     stop=None,
+                                     stream=False)
+            except Exception as ex:
+                raise CredentialsValidateFailedError(str(ex))
+            return  # the legacy text-completion check below does not apply to claude 3
+
         try:
             ping_message = UserPromptMessage(content="ping")
             self._generate(model=model,
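
To make the message conversion concrete, the following standalone sketch builds the same content-block structure that _convert_claude3_prompt_message_to_dict produces for a text-plus-image user message. The image bytes and model id below are placeholders for illustration; sending the request mirrors _invoke_claude3 and needs valid Bedrock credentials:

import base64
import json

# stand-in bytes just to show the payload shape; a real request needs actual JPEG/PNG/GIF/WebP data
base64_data = base64.b64encode(b"<image bytes go here>").decode("utf-8")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": base64_data,
                },
            },
        ],
    }
]
print(json.dumps(messages, indent=2))

# the converted list is then handed to the SDK exactly as in _invoke_claude3:
# client = AnthropicBedrock(aws_access_key=..., aws_secret_key=..., aws_region=...)
# client.messages.create(model="anthropic.claude-3-haiku-20240307-v1:0",
#                        max_tokens=1024, messages=messages)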

+ 1 - 1
api/requirements.txt

@@ -36,7 +36,7 @@ python-docx~=1.1.0
 pypdfium2==4.16.0
 resend~=0.7.0
 pyjwt~=2.8.0
-anthropic~=0.17.0
+anthropic~=0.20.0
 newspaper3k==0.2.8
 google-api-python-client==2.90.0
 wikipedia==1.4.0
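
The anthropic~=0.20.0 bump is what provides the AnthropicBedrock client and the streaming event types imported in llm.py. Below is a minimal end-to-end sketch of the streaming path, assuming AWS credentials in the environment and access to Claude 3 Haiku in the chosen region (region and token counts are illustrative):

import os

from anthropic import AnthropicBedrock
from anthropic.types import ContentBlockDeltaEvent, MessageDeltaEvent, MessageStartEvent

client = AnthropicBedrock(
    aws_access_key=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    aws_region=os.environ.get("AWS_REGION", "us-east-1"),
)

stream = client.messages.create(
    model="anthropic.claude-3-haiku-20240307-v1:0",
    max_tokens=256,
    messages=[{"role": "user", "content": "ping"}],
    stream=True,
)

input_tokens = output_tokens = 0
for event in stream:
    if isinstance(event, MessageStartEvent):
        input_tokens = event.message.usage.input_tokens    # prompt token count
    elif isinstance(event, ContentBlockDeltaEvent):
        print(event.delta.text or "", end="", flush=True)  # incremental text
    elif isinstance(event, MessageDeltaEvent):
        output_tokens = event.usage.output_tokens          # completion token count

print(f"\nusage: {input_tokens} prompt + {output_tokens} completion tokens")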