@@ -1,22 +1,14 @@
+# standard library imports
 import base64
 import json
 import logging
 import mimetypes
-import time
 from collections.abc import Generator
 from typing import Optional, Union, cast

+# third-party imports
 import boto3
 import requests
-from anthropic import AnthropicBedrock, Stream
-from anthropic.types import (
-    ContentBlockDeltaEvent,
-    Message,
-    MessageDeltaEvent,
-    MessageStartEvent,
-    MessageStopEvent,
-    MessageStreamEvent,
-)
 from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
@@ -27,7 +19,8 @@ from botocore.exceptions import (
 )
 from cohere import ChatMessage

-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+# local imports
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
@@ -38,7 +31,6 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -73,8 +65,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-
-        # invoke anthropic models via anthropic official SDK
+        # TODO: consolidate the different invocation methods based on base model capabilities
+        # invoke Anthropic models via boto3 client
         if "anthropic" in model:
             return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
         # invoke Cohere models via boto3 client
@@ -171,48 +163,39 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param stream: is stream response
         :return: full response or stream response chunk generator result
         """
-        # use Anthropic official SDK references
-        # - https://docs.anthropic.com/claude/reference/claude-on-amazon-bedrock
-        # - https://github.com/anthropics/anthropic-sdk-python
-        client = AnthropicBedrock(
-            aws_access_key=credentials.get("aws_access_key_id"),
-            aws_secret_key=credentials.get("aws_secret_access_key"),
-            aws_region=credentials["aws_region"],
-        )
+        bedrock_client = boto3.client(service_name='bedrock-runtime',
+                                      aws_access_key_id=credentials.get("aws_access_key_id"),
+                                      aws_secret_access_key=credentials.get("aws_secret_access_key"),
+                                      region_name=credentials["aws_region"])
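+        # NOTE: converse/converse_stream are part of the unified Bedrock
+        # Converse API, which needs a recent boto3/botocore; older releases
+        # only expose invoke_model on the bedrock-runtime client.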

-        extra_model_kwargs = {}
-        if stop:
-            extra_model_kwargs['stop_sequences'] = stop
-
-        # Notice: If you request the current version of the SDK to the bedrock server,
-        # you will get the following error message and you need to wait for the service or SDK to be updated.
-        # Response: Error code: 400
-        # {'message': 'Malformed input request: #: subject must not be valid against schema
-        # {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
-        # TODO: Open in the future when the interface is properly supported
-        # if user:
-        # ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
-        # extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)
-
-        system, prompt_message_dicts = self._convert_claude_prompt_messages(prompt_messages)
-
-        if system:
-            extra_model_kwargs['system'] = system
-
-        response = client.messages.create(
-            model=model,
-            messages=prompt_message_dicts,
-            stream=stream,
-            **model_parameters,
-            **extra_model_kwargs
-        )
+        system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
+        inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

         if stream:
-            return self._handle_claude_stream_response(model, credentials, response, prompt_messages)
-
-        return self._handle_claude_response(model, credentials, response, prompt_messages)
+            response = bedrock_client.converse_stream(
+                modelId=model,
+                messages=prompt_message_dicts,
+                system=system,
+                inferenceConfig=inference_config,
+                additionalModelRequestFields=additional_model_fields
+            )
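+            # converse_stream wraps the events as {'stream': EventStream};
+            # the stream handler iterates response['stream']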
+            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = bedrock_client.converse(
+                modelId=model,
+                messages=prompt_message_dicts,
+                system=system,
+                inferenceConfig=inference_config,
+                additionalModelRequestFields=additional_model_fields
+            )
+            return self._handle_converse_response(model, credentials, response, prompt_messages)

-    def _handle_claude_response(self, model: str, credentials: dict, response: Message,
+    def _handle_converse_response(self, model: str, credentials: dict, response: dict,
                                 prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm chat response
@@ -223,17 +201,19 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: full response chunk generator result
         """
-
         # transform assistant message to prompt message
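+        # Converse response shape (per the Converse API reference):
+        # {'output': {'message': {'content': [{'text': ...}]}},
+        #  'usage': {'inputTokens': ..., 'outputTokens': ...}}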
         assistant_prompt_message = AssistantPromptMessage(
-            content=response.content[0].text
+            content=response['output']['message']['content'][0]['text']
         )

         # calculate num tokens
-        if response.usage:
+        if response.get('usage'):
             # transform usage
-            prompt_tokens = response.usage.input_tokens
-            completion_tokens = response.usage.output_tokens
+            prompt_tokens = response['usage']['inputTokens']
+            completion_tokens = response['usage']['outputTokens']
         else:
             # calculate num tokens
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -242,17 +219,15 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

-        # transform response
-        response = LLMResult(
-            model=response.model,
+        result = LLMResult(
+            model=model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage
+            usage=usage,
         )
+        return result

-        return response
-
-    def _handle_claude_stream_response(self, model: str, credentials: dict, response: Stream[MessageStreamEvent],
+    def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
                                        prompt_messages: list[PromptMessage], ) -> Generator:
         """
         Handle llm chat stream response
@@ -272,14 +247,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             finish_reason = None
             index = 0

-            for chunk in response:
-                if isinstance(chunk, MessageStartEvent):
-                    return_model = chunk.message.model
-                    input_tokens = chunk.message.usage.input_tokens
-                elif isinstance(chunk, MessageDeltaEvent):
-                    output_tokens = chunk.usage.output_tokens
-                    finish_reason = chunk.delta.stop_reason
-                elif isinstance(chunk, MessageStopEvent):
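+            # Converse stream events are dicts keyed by event type; the ones
+            # handled here are messageStart, contentBlockDelta, messageStop,
+            # and the final metadata event that carries token usage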
+            for chunk in response['stream']:
+                if 'messageStart' in chunk:
+                    return_model = model
+                elif 'messageStop' in chunk:
+                    finish_reason = chunk['messageStop']['stopReason']
+                elif 'metadata' in chunk:
+                    input_tokens = chunk['metadata']['usage']['inputTokens']
+                    output_tokens = chunk['metadata']['usage']['outputTokens']
                    usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
                    yield LLMResultChunk(
                        model=return_model,
@@ -293,13 +268,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                            usage=usage
                        )
                    )
-                elif isinstance(chunk, ContentBlockDeltaEvent):
-                    chunk_text = chunk.delta.text if chunk.delta.text else ''
+                elif 'contentBlockDelta' in chunk:
+                    chunk_text = chunk['contentBlockDelta']['delta'].get('text') or ''
                    full_assistant_content += chunk_text
                    assistant_prompt_message = AssistantPromptMessage(
                        content=chunk_text if chunk_text else '',
                    )
-                    index = chunk.index
+                    index = chunk['contentBlockDelta']['contentBlockIndex']
                    yield LLMResultChunk(
                        model=model,
                        prompt_messages=prompt_messages,
@@ -310,57 +285,41 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                    )
        except Exception as ex:
            raise InvokeError(str(ex))
+
+    def _convert_converse_api_model_parameters(self, model_parameters: dict, stop: Optional[list[str]] = None) -> tuple[dict, dict]:
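+        """
+        Map generic model parameters onto the Converse API's camelCase
+        inferenceConfig; model-specific keys (here top_k) are passed through
+        additionalModelRequestFields instead. For example, {'max_tokens': 512,
+        'top_k': 250} becomes ({'maxTokens': 512}, {'top_k': 250}).
+        """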
+        inference_config = {}
+        additional_model_fields = {}
+        if 'max_tokens' in model_parameters:
+            inference_config['maxTokens'] = model_parameters['max_tokens']
+
+        if 'temperature' in model_parameters:
+            inference_config['temperature'] = model_parameters['temperature']
+
+        if 'top_p' in model_parameters:
+            inference_config['topP'] = model_parameters['top_p']

-    def _calc_claude_response_usage(self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> LLMUsage:
-        """
-        Calculate response usage
-
-        :param model: model name
-        :param credentials: model credentials
-        :param prompt_tokens: prompt tokens
-        :param completion_tokens: completion tokens
-        :return: usage
-        """
-        # get prompt price info
-        prompt_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.INPUT,
-            tokens=prompt_tokens,
-        )
-
-        # get completion price info
-        completion_price_info = self.get_price(
-            model=model,
-            credentials=credentials,
-            price_type=PriceType.OUTPUT,
-            tokens=completion_tokens
-        )
-
-        # transform usage
-        usage = LLMUsage(
-            prompt_tokens=prompt_tokens,
-            prompt_unit_price=prompt_price_info.unit_price,
-            prompt_price_unit=prompt_price_info.unit,
-            prompt_price=prompt_price_info.total_amount,
-            completion_tokens=completion_tokens,
-            completion_unit_price=completion_price_info.unit_price,
-            completion_price_unit=completion_price_info.unit,
-            completion_price=completion_price_info.total_amount,
-            total_tokens=prompt_tokens + completion_tokens,
-            total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
-            currency=prompt_price_info.currency,
-            latency=time.perf_counter() - self.started_at
-        )
-
-        return usage
+        if stop:
+            inference_config['stopSequences'] = stop
+
+        if 'top_k' in model_parameters:
+            additional_model_fields['top_k'] = model_parameters['top_k']
+
+        return inference_config, additional_model_fields

-    def _convert_claude_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+    def _convert_converse_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[list, list[dict]]:
         """
         Convert prompt messages to dict list and system
         """

-        system = ""
+        system = []
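+        # the Converse API expects system as a list of content blocks,
+        # e.g. [{'text': 'You are a helpful assistant.'}]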
         first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
@@ -375,25 +326,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         prompt_message_dicts = []
         for message in prompt_messages:
             if not isinstance(message, SystemPromptMessage):
-                prompt_message_dicts.append(self._convert_claude_prompt_message_to_dict(message))
+                prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))

         return system, prompt_message_dicts

-    def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
         """
         Convert PromptMessage to dict
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "user", "content": message.content}
+                message_dict = {"role": "user", "content": [{"text": message.content}]}
             else:
                 sub_messages = []
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
                         message_content = cast(TextPromptMessageContent, message_content)
                         sub_message_dict = {
-                            "type": "text",
                             "text": message_content.data
                         }
                         sub_messages.append(sub_message_dict)
@@ -404,24 +354,27 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                             try:
                                 image_content = requests.get(message_content.data).content
                                 mime_type, _ = mimetypes.guess_type(message_content.data)
-                                base64_data = base64.b64encode(image_content).decode('utf-8')
                             except Exception as ex:
                                 raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
                         else:
                             data_split = message_content.data.split(";base64,")
                             mime_type = data_split[0].replace("data:", "")
                             base64_data = data_split[1]
+                            image_content = base64.b64decode(base64_data)

                         if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
                             raise ValueError(f"Unsupported image type {mime_type}, "
                                              f"only support image/jpeg, image/png, image/gif, and image/webp")

                         sub_message_dict = {
-                            "type": "image",
-                            "source": {
-                                "type": "base64",
-                                "media_type": mime_type,
-                                "data": base64_data
+                            "image": {
+                                "format": mime_type.replace("image/", ""),
+                                "source": {
+                                    "bytes": image_content
+                                }
                             }
                         }
                         sub_messages.append(sub_message_dict)
@@ -429,10 +379,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "assistant", "content": message.content}
+            message_dict = {"role": "assistant", "content": [{"text": message.content}]}
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "system", "content": message.content}
+            message_dict = [{"text": message.content}]
         else:
             raise ValueError(f"Got unknown type {message}")