@@ -1,9 +1,22 @@
+import base64
 import json
 import logging
+import mimetypes
+import time
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 import boto3
+import requests
+from anthropic import AnthropicBedrock, Stream
+from anthropic.types import (
+    ContentBlockDeltaEvent,
+    Message,
+    MessageDeltaEvent,
+    MessageStartEvent,
+    MessageStopEvent,
+    MessageStreamEvent,
+)
 from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
@@ -13,14 +26,18 @@ from botocore.exceptions import (
     UnknownServiceError,
 )
 
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     UserPromptMessage,
 )
+from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -54,9 +71,268 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
+
+        # invoke claude 3 models via anthropic official SDK
+        if "anthropic.claude-3" in model:
+            return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream)
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
 
+    def _invoke_claude3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                        stop: Optional[list[str]] = None, stream: bool = True) -> Union[LLMResult, Generator]:
+        """
+        Invoke Claude3 large language model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param stop: stop words
+        :param stream: is stream response
+        :return: full response or stream response chunk generator result
+        """
+        # use Anthropic official SDK references
+        # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
+        # - https://github.com/anthropics/anthropic-sdk-python
+        client = AnthropicBedrock(
+            aws_access_key=credentials["aws_access_key_id"],
+            aws_secret_key=credentials["aws_secret_access_key"],
+            aws_region=credentials["aws_region"],
+        )
+
+        system, prompt_message_dicts = self._convert_claude3_prompt_messages(prompt_messages)
+
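+        # model_parameters are passed through to the Messages API unchanged; they
+        # are expected to include max_tokens, which the API requires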
+        response = client.messages.create(
+            model=model,
+            messages=prompt_message_dicts,
+            stop_sequences=stop if stop else [],
+            system=system,
+            stream=stream,
+            **model_parameters,
+        )
+
+        if stream is False:
+            return self._handle_claude3_response(model, credentials, response, prompt_messages)
+        else:
+            return self._handle_claude3_stream_response(model, credentials, response, prompt_messages)
+
+    def _handle_claude3_response(self, model: str, credentials: dict, response: Message,
+                                 prompt_messages: list[PromptMessage]) -> LLMResult:
+        """
+        Handle llm chat response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: full response
+        """
+
+        # transform assistant message to prompt message
+        assistant_prompt_message = AssistantPromptMessage(
+            content=response.content[0].text
+        )
+
+        # prefer the token usage reported by the API
+        if response.usage:
+            prompt_tokens = response.usage.input_tokens
+            completion_tokens = response.usage.output_tokens
+        else:
+            # otherwise estimate the token counts locally
+            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        # transform response
+        response = LLMResult(
+            model=response.model,
+            prompt_messages=prompt_messages,
+            message=assistant_prompt_message,
+            usage=usage
+        )
+
+        return response
+
+    def _handle_claude3_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
+                                        prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm chat stream response
+
+        :param model: model name
+        :param credentials: credentials
+        :param response: response
+        :param prompt_messages: prompt messages
+        :return: stream response chunk generator result
+        """
+
+        try:
+            full_assistant_content = ''
+            return_model = None
+            input_tokens = 0
+            output_tokens = 0
+            finish_reason = None
+            index = 0
+
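+            # the SDK yields typed stream events: message_start carries the model
+            # name and input token usage, content_block_delta carries incremental
+            # text, message_delta carries output token usage and the stop reason,
+            # and message_stop marks the end of the stream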
+            for chunk in response:
+                if isinstance(chunk, MessageStartEvent):
+                    return_model = chunk.message.model
+                    input_tokens = chunk.message.usage.input_tokens
+                elif isinstance(chunk, MessageDeltaEvent):
+                    output_tokens = chunk.usage.output_tokens
+                    finish_reason = chunk.delta.stop_reason
+                elif isinstance(chunk, MessageStopEvent):
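+                    # emit a final, empty chunk carrying the usage and finish reason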
+                    usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
+                    yield LLMResultChunk(
+                        model=return_model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index + 1,
+                            message=AssistantPromptMessage(
+                                content=''
+                            ),
+                            finish_reason=finish_reason,
+                            usage=usage
+                        )
+                    )
+                elif isinstance(chunk, ContentBlockDeltaEvent):
+                    chunk_text = chunk.delta.text if chunk.delta.text else ''
+                    full_assistant_content += chunk_text
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=chunk_text,
+                    )
+                    index = chunk.index
+                    yield LLMResultChunk(
+                        model=model,
+                        prompt_messages=prompt_messages,
+                        delta=LLMResultChunkDelta(
+                            index=index,
+                            message=assistant_prompt_message,
+                        )
+                    )
+        except Exception as ex:
+            raise InvokeError(str(ex))
+
+    def _calc_claude3_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_tokens: prompt tokens
+        :param completion_tokens: completion tokens
+        :return: usage
+        """
+        # get prompt price info
+        prompt_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=prompt_tokens,
+        )
+
+        # get completion price info
+        completion_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.OUTPUT,
+            tokens=completion_tokens
+        )
+
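+        # total price is the sum of the prompt and completion charges; latency is
+        # the perf_counter time elapsed since self.started_at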
+        # transform usage
+        usage = LLMUsage(
+            prompt_tokens=prompt_tokens,
+            prompt_unit_price=prompt_price_info.unit_price,
+            prompt_price_unit=prompt_price_info.unit,
+            prompt_price=prompt_price_info.total_amount,
+            completion_tokens=completion_tokens,
+            completion_unit_price=completion_price_info.unit_price,
+            completion_price_unit=completion_price_info.unit,
+            completion_price=completion_price_info.total_amount,
+            total_tokens=prompt_tokens + completion_tokens,
+            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
+            currency=prompt_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage
+
+    def _convert_claude3_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+        """
+        Convert prompt messages to dict list and system
+        """
+        system = ""
+        prompt_message_dicts = []
+
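+        # the Messages API takes the system prompt as a separate top-level
+        # parameter, so system messages are collected into `system` instead of
+        # being appended to the message list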
+        for message in prompt_messages:
+            if isinstance(message, SystemPromptMessage):
+                system += ("\n" if system else "") + message.content
+            else:
+                prompt_message_dicts.append(self._convert_claude3_prompt_message_to_dict(message))
+
+        return system, prompt_message_dicts
+
+    def _convert_claude3_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(TextPromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
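+                        # image content is either a remote URL or a data URI;
+                        # either way it is normalized to base64 data plus an
+                        # explicit media type for the Messages API image block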
+                        if not message_content.data.startswith("data:"):
+                            # fetch image data from url
+                            try:
+                                image_content = requests.get(message_content.data).content
+                                mime_type, _ = mimetypes.guess_type(message_content.data)
+                                base64_data = base64.b64encode(image_content).decode('utf-8')
+                            except Exception as ex:
+                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                        else:
+                            data_split = message_content.data.split(";base64,")
+                            mime_type = data_split[0].replace("data:", "")
+                            base64_data = data_split[1]
+
+                        if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
+                            raise ValueError(f"Unsupported image type {mime_type}, "
+                                             f"only image/jpeg, image/png, image/gif and image/webp are supported")
+
+                        sub_message_dict = {
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": mime_type,
+                                "data": base64_data
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_dict
+
     def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
@@ -101,7 +377,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param credentials: model credentials
         :return:
         """
-
+
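+        # Claude 3 models on Bedrock only support the Messages API, so validate
+        # them through the Anthropic SDK path instead of the legacy text path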
+        if "anthropic.claude-3" in model:
+            try:
+                self._invoke_claude3(model=model,
+                                     credentials=credentials,
+                                     prompt_messages=[UserPromptMessage(content="ping")],
+                                     # max_tokens is required by the Anthropic Messages API
+                                     model_parameters={"max_tokens": 32},
+                                     stop=None,
+                                     stream=False)
+            except Exception as ex:
+                raise CredentialsValidateFailedError(str(ex))
+
+            return
+
         try:
             ping_message = UserPromptMessage(content="ping")
             self._generate(model=model,