@@ -1,17 +1,36 @@
 import json
 import logging
-from collections.abc import Generator
-from typing import Any, Optional, Union
+import re
+from collections.abc import Generator, Iterator
+from typing import Any, Optional, Union, cast

+# from openai.types.chat import ChatCompletion, ChatCompletionChunk
 import boto3
+from sagemaker import Predictor, serializers
+from sagemaker.session import Session

-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
     PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    I18nObject,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
 )
-from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -25,12 +44,140 @@ from core.model_runtime.model_providers.__base.large_language_model import Large

 logger = logging.getLogger(__name__)

+def inference(predictor, messages: list[dict[str, Any]], params: dict[str, Any], stop: list, stream=False):
+    """
+    params:
+        predictor: SageMaker Predictor
+        messages (list[dict[str, Any]]): message list, e.g.
+            messages = [
+                {"role": "system", "content": "please answer in Chinese"},
+                {"role": "user", "content": "who are you? what are you doing?"},
+            ]
+        params (dict[str, Any]): model parameters for the LLM
+        stop (list): stop sequences
+        stream (bool): False by default
+
+    returns:
+        the inference result if stream is False,
+        an iterator of chunks if stream is True
+    """
+    payload = {
+        "model": params.get('model_name'),
+        "stop": stop,
+        "messages": messages,
+        "stream": stream,
+        "max_tokens": params.get('max_new_tokens', params.get('max_tokens', 2048)),
+        "temperature": params.get('temperature', 0.1),
+        "top_p": params.get('top_p', 0.9),
+    }
+
+    if not stream:
+        response = predictor.predict(payload)
+        return response
+    else:
+        response_stream = predictor.predict_stream(payload)
+        return response_stream
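+
+# Illustrative usage sketch only (not part of this change); the endpoint name,
+# model name, and message content below are placeholder assumptions:
+#
+#   predictor = Predictor(
+#       endpoint_name="my-sagemaker-endpoint",
+#       sagemaker_session=Session(),
+#       serializer=serializers.JSONSerializer(),
+#   )
+#   raw = inference(
+#       predictor=predictor,
+#       messages=[{"role": "user", "content": "Hello"}],
+#       params={"model_name": "my-model", "max_tokens": 256},
+#       stop=[],
+#       stream=False,
+#   )
+#   print(json.loads(raw.decode('utf-8'))["choices"][0]["message"]["content"])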

 class SageMakerLargeLanguageModel(LargeLanguageModel):
     """
-    Model class for Cohere large language model.
+    Model class for SageMaker large language model.
     """
     sagemaker_client: Any = None
+    sagemaker_sess: Any = None
+    predictor: Any = None
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                       tools: list[PromptMessageTool],
+                                       resp: bytes) -> LLMResult:
+        """
+        Handle a non-streaming chat completion response.
+        """
+        resp_obj = json.loads(resp.decode('utf-8'))
+        resp_str = resp_obj.get('choices')[0].get('message').get('content')
+
+        if len(resp_str) == 0:
+            raise InvokeServerUnavailableError("Empty response")
+
+        assistant_prompt_message = AssistantPromptMessage(
+            content=resp_str,
+            tool_calls=[]
+        )
+
+        prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+        completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)
+
+        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
+                                          completion_tokens=completion_tokens)
+
+        response = LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=None,
+            usage=usage,
+            message=assistant_prompt_message,
+        )
+
+        return response
+
+    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                     tools: list[PromptMessageTool],
+                                     resp: Iterator[bytes]) -> Generator:
+        """
+        Handle a streaming chat completion response.
+        """
+        full_response = ''
+        buffer = ""
+        for chunk_bytes in resp:
+            buffer += chunk_bytes.decode('utf-8')
+            last_idx = 0
+            for match in re.finditer(r'^data:\s*(.+?)(\n\n)', buffer):
+                try:
+                    data = json.loads(match.group(1).strip())
+                    last_idx = match.span()[1]
+
+                    if "content" in data["choices"][0]["delta"]:
+                        chunk_content = data["choices"][0]["delta"]["content"]
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=chunk_content,
+                            tool_calls=[]
+                        )
+
+                        if data["choices"][0]['finish_reason'] is not None:
+                            temp_assistant_prompt_message = AssistantPromptMessage(
+                                content=full_response,
+                                tool_calls=[]
+                            )
+                            prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
+                            completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
+                            usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message,
+                                    finish_reason=data["choices"][0]['finish_reason'],
+                                    usage=usage
+                                ),
+                            )
+                        else:
+                            yield LLMResultChunk(
+                                model=model,
+                                prompt_messages=prompt_messages,
+                                system_fingerprint=None,
+                                delta=LLMResultChunkDelta(
+                                    index=0,
+                                    message=assistant_prompt_message
+                                ),
+                            )
+
+                        full_response += chunk_content
+                except (json.JSONDecodeError, KeyError, IndexError) as e:
+                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
+                    pass
+
+            buffer = buffer[last_idx:]
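+
+    # For reference (assumed frame shape, mirroring what the parser above expects),
+    # the endpoint streams OpenAI-style SSE events such as:
+    #
+    #   data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}\n\n
+    #   data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}\n\n
+    #
+    # Each parsed frame yields one LLMResultChunk; token usage is attached only to
+    # the final chunk, whose finish_reason is non-null.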

     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
@@ -50,9 +197,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        # get model mode
-        model_mode = self.get_model_mode(model, credentials)
-
         if not self.sagemaker_client:
             access_key = credentials.get('access_key')
             secret_key = credentials.get('secret_key')
@@ -68,37 +212,132 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
             else:
                 self.sagemaker_client = boto3.client("sagemaker-runtime")

+            sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client)
+            self.predictor = Predictor(
+                endpoint_name=credentials.get('sagemaker_endpoint'),
+                sagemaker_session=sagemaker_session,
+                serializer=serializers.JSONSerializer(),
+            )

-        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
-        response_model = self.sagemaker_client.invoke_endpoint(
-            EndpointName=sagemaker_endpoint,
-            Body=json.dumps(
-                {
-                    "inputs": prompt_messages[0].content,
-                    "parameters": {"stop": stop},
-                    "history": []
-                }
-            ),
-            ContentType="application/json",
-        )
-
-        assistant_text = response_model['Body'].read().decode('utf8')
+        messages: list[dict[str, Any]] = [{"role": p.role.value, "content": p.content} for p in prompt_messages]
+        response = inference(predictor=self.predictor, messages=messages, params=model_parameters, stop=stop, stream=stream)

-        # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(
-            content=assistant_text
-        )
+        if stream:
+            if tools and len(tools) > 0:
+                raise InvokeBadRequestError(f"{model}'s tool calls do not support stream mode")

-        usage = self._calc_response_usage(model, credentials, 0, 0)
+            return self._handle_chat_stream_response(model=model, credentials=credentials,
+                                                     prompt_messages=prompt_messages,
+                                                     tools=tools, resp=response)
+        return self._handle_chat_generate_response(model=model, credentials=credentials,
+                                                   prompt_messages=prompt_messages,
+                                                   tools=tools, resp=response)

-        response = LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=assistant_prompt_message,
-            usage=usage
-        )
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+        """
+        Convert PromptMessage to dict for OpenAI Compatibility API
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                                "detail": message_content.detail.value
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls and len(message.tool_calls) > 0:
+                message_dict["function_call"] = {
+                    "name": message.tool_calls[0].function.name,
+                    "arguments": message.tool_calls[0].function.arguments
+                }
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
+        else:
+            raise ValueError(f"Unknown message type {type(message)}")
+
+        return message_dict
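+
+    # For illustration (hypothetical input), a UserPromptMessage carrying a text part
+    # and an image part converts to the OpenAI-style dict:
+    #
+    #   {"role": "user", "content": [
+    #       {"type": "text", "text": "describe this image"},
+    #       {"type": "image_url", "image_url": {"url": "data:image/png;base64,...", "detail": "low"}},
+    #   ]}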
+
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
+                                  is_completion_model: bool = False) -> int:
+        def tokens(text: str):
+            return self._get_num_tokens_by_gpt2(text)
+
+        if is_completion_model:
+            return sum(tokens(str(message.content)) for message in messages)
+
+        tokens_per_message = 3
+        tokens_per_name = 1
+
+        num_tokens = 0
+        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                if isinstance(value, list):
+                    text = ''
+                    for item in value:
+                        if isinstance(item, dict) and item['type'] == 'text':
+                            text += item['text']
+
+                    value = text
+
+                if key == "tool_calls":
+                    for tool_call in value:
+                        for t_key, t_value in tool_call.items():
+                            num_tokens += tokens(t_key)
+                            if t_key == "function":
+                                for f_key, f_value in t_value.items():
+                                    num_tokens += tokens(f_key)
+                                    num_tokens += tokens(f_value)
+                            else:
+                                num_tokens += tokens(t_key)
+                                num_tokens += tokens(t_value)
+                if key == "function_call":
+                    for t_key, t_value in value.items():
+                        num_tokens += tokens(t_key)
+                        if t_key == "function":
+                            for f_key, f_value in t_value.items():
+                                num_tokens += tokens(f_key)
+                                num_tokens += tokens(f_value)
+                        else:
+                            num_tokens += tokens(t_key)
+                            num_tokens += tokens(t_value)
+                else:
+                    num_tokens += tokens(str(value))

-        return response
+                if key == "name":
+                    num_tokens += tokens_per_name
+        num_tokens += 3
+
+        if tools:
+            num_tokens += self._num_tokens_for_tools(tools)
+
+        return num_tokens
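+
+    # Rough accounting sketch (illustrative): for a plain-text chat, the estimate is
+    #   sum over messages of (3 + tokens(role) + tokens(content)) + a final 3,
+    # where tokens(...) is approximated with the GPT-2 tokenizer via _get_num_tokens_by_gpt2
+    # and the trailing 3 follows the usual OpenAI reply-priming convention.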

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -112,10 +351,8 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         :return:
         """
         # get model mode
-        model_mode = self.get_model_mode(model)
-
         try:
-            return 0
+            return self._num_tokens_from_messages(prompt_messages, tools)
         except Exception as e:
             raise self._transform_invoke_error(e)

@@ -129,7 +366,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
         """
         try:
             # get model mode
-            model_mode = self.get_model_mode(model)
+            pass
        except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))

@@ -200,13 +437,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
             )
         ]

-        completion_type = LLMMode.value_of(credentials["mode"])
-
-        if completion_type == LLMMode.CHAT:
-            print(f"completion_type : {LLMMode.CHAT.value}")
-
-        if completion_type == LLMMode.COMPLETION:
-            print(f"completion_type : {LLMMode.COMPLETION.value}")
+        completion_type = LLMMode.value_of(credentials["mode"]).value

         features = []