|
@@ -57,7 +57,7 @@ class BaiduAccessToken:
|
|
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
|
|
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
|
|
else:
|
|
else:
|
|
raise Exception(f'Unknown error: {resp["error_description"]}')
|
|
raise Exception(f'Unknown error: {resp["error_description"]}')
|
|
-
|
|
|
|
|
|
+
|
|
return resp['access_token']
|
|
return resp['access_token']
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
@@ -114,7 +114,7 @@ class ErnieMessage:
|
|
'role': self.role,
|
|
'role': self.role,
|
|
'content': self.content,
|
|
'content': self.content,
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
def __init__(self, content: str, role: str = 'user') -> None:
|
|
def __init__(self, content: str, role: str = 'user') -> None:
|
|
self.content = content
|
|
self.content = content
|
|
self.role = role
|
|
self.role = role
|
|
@@ -131,6 +131,7 @@ class ErnieBotModel:
|
|
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
|
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
|
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
|
|
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
|
|
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
|
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
|
|
|
+ 'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
|
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
|
|
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
|
|
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
|
|
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
|
|
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
|
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
|
@@ -157,7 +158,7 @@ class ErnieBotModel:
|
|
self.api_key = api_key
|
|
self.api_key = api_key
|
|
self.secret_key = secret_key
|
|
self.secret_key = secret_key
|
|
|
|
|
|
- def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
|
|
|
|
|
|
+ def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
|
|
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
|
|
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
|
|
stop: list[str], user: str) \
|
|
stop: list[str], user: str) \
|
|
-> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
|
|
-> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
|
|
@@ -189,7 +190,7 @@ class ErnieBotModel:
|
|
if stream:
|
|
if stream:
|
|
return self._handle_chat_stream_generate_response(resp)
|
|
return self._handle_chat_stream_generate_response(resp)
|
|
return self._handle_chat_generate_response(resp)
|
|
return self._handle_chat_generate_response(resp)
|
|
-
|
|
|
|
|
|
+
|
|
def _handle_error(self, code: int, msg: str):
|
|
def _handle_error(self, code: int, msg: str):
|
|
error_map = {
|
|
error_map = {
|
|
1: InternalServerError,
|
|
1: InternalServerError,
|
|
@@ -234,15 +235,15 @@ class ErnieBotModel:
|
|
def _get_access_token(self) -> str:
|
|
def _get_access_token(self) -> str:
|
|
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
|
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
|
return token.access_token
|
|
return token.access_token
|
|
-
|
|
|
|
|
|
+
|
|
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
|
|
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
|
|
return [ErnieMessage(message.content, message.role) for message in messages]
|
|
return [ErnieMessage(message.content, message.role) for message in messages]
|
|
|
|
|
|
- def _check_parameters(self, model: str, parameters: dict[str, Any],
|
|
|
|
|
|
+ def _check_parameters(self, model: str, parameters: dict[str, Any],
|
|
tools: list[PromptMessageTool], stop: list[str]) -> None:
|
|
tools: list[PromptMessageTool], stop: list[str]) -> None:
|
|
if model not in self.api_bases:
|
|
if model not in self.api_bases:
|
|
raise BadRequestError(f'Invalid model: {model}')
|
|
raise BadRequestError(f'Invalid model: {model}')
|
|
-
|
|
|
|
|
|
+
|
|
# if model not in self.function_calling_supports and tools is not None and len(tools) > 0:
|
|
# if model not in self.function_calling_supports and tools is not None and len(tools) > 0:
|
|
# raise BadRequestError(f'Model {model} does not support calling function.')
|
|
# raise BadRequestError(f'Model {model} does not support calling function.')
|
|
# ErnieBot supports function calling, however, there is lots of limitations.
|
|
# ErnieBot supports function calling, however, there is lots of limitations.
|
|
@@ -259,32 +260,32 @@ class ErnieBotModel:
|
|
for s in stop:
|
|
for s in stop:
|
|
if len(s) > 20:
|
|
if len(s) > 20:
|
|
raise BadRequestError('stop item should not exceed 20 characters.')
|
|
raise BadRequestError('stop item should not exceed 20 characters.')
|
|
-
|
|
|
|
|
|
+
|
|
def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any],
|
|
def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any],
|
|
tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]:
|
|
tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]:
|
|
# if model in self.function_calling_supports:
|
|
# if model in self.function_calling_supports:
|
|
# return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user)
|
|
# return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user)
|
|
return self._build_chat_request_body(model, messages, stream, parameters, stop, user)
|
|
return self._build_chat_request_body(model, messages, stream, parameters, stop, user)
|
|
-
|
|
|
|
|
|
+
|
|
def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
|
|
def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
|
|
- parameters: dict[str, Any], tools: list[PromptMessageTool],
|
|
|
|
|
|
+ parameters: dict[str, Any], tools: list[PromptMessageTool],
|
|
stop: list[str], user: str) \
|
|
stop: list[str], user: str) \
|
|
-> dict[str, Any]:
|
|
-> dict[str, Any]:
|
|
if len(messages) % 2 == 0:
|
|
if len(messages) % 2 == 0:
|
|
raise BadRequestError('The number of messages should be odd.')
|
|
raise BadRequestError('The number of messages should be odd.')
|
|
if messages[0].role == 'function':
|
|
if messages[0].role == 'function':
|
|
raise BadRequestError('The first message should be user message.')
|
|
raise BadRequestError('The first message should be user message.')
|
|
-
|
|
|
|
|
|
+
|
|
"""
|
|
"""
|
|
TODO: implement function calling
|
|
TODO: implement function calling
|
|
"""
|
|
"""
|
|
|
|
|
|
- def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
|
|
|
|
|
|
+ def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
|
|
parameters: dict[str, Any], stop: list[str], user: str) \
|
|
parameters: dict[str, Any], stop: list[str], user: str) \
|
|
-> dict[str, Any]:
|
|
-> dict[str, Any]:
|
|
if len(messages) == 0:
|
|
if len(messages) == 0:
|
|
raise BadRequestError('The number of messages should not be zero.')
|
|
raise BadRequestError('The number of messages should not be zero.')
|
|
-
|
|
|
|
|
|
+
|
|
# check if the first element is system, shift it
|
|
# check if the first element is system, shift it
|
|
system_message = ''
|
|
system_message = ''
|
|
if messages[0].role == 'system':
|
|
if messages[0].role == 'system':
|
|
@@ -313,7 +314,7 @@ class ErnieBotModel:
|
|
body['system'] = system_message
|
|
body['system'] = system_message
|
|
|
|
|
|
return body
|
|
return body
|
|
-
|
|
|
|
|
|
+
|
|
def _handle_chat_generate_response(self, response: Response) -> ErnieMessage:
|
|
def _handle_chat_generate_response(self, response: Response) -> ErnieMessage:
|
|
data = response.json()
|
|
data = response.json()
|
|
if 'error_code' in data:
|
|
if 'error_code' in data:
|
|
@@ -349,7 +350,7 @@ class ErnieBotModel:
|
|
self._handle_error(code, msg)
|
|
self._handle_error(code, msg)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
raise InternalServerError(f'Failed to parse response: {e}')
|
|
raise InternalServerError(f'Failed to parse response: {e}')
|
|
-
|
|
|
|
|
|
+
|
|
if line.startswith('data:'):
|
|
if line.startswith('data:'):
|
|
line = line[5:].strip()
|
|
line = line[5:].strip()
|
|
else:
|
|
else:
|
|
@@ -361,7 +362,7 @@ class ErnieBotModel:
|
|
data = loads(line)
|
|
data = loads(line)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
raise InternalServerError(f'Failed to parse response: {e}')
|
|
raise InternalServerError(f'Failed to parse response: {e}')
|
|
-
|
|
|
|
|
|
+
|
|
result = data['result']
|
|
result = data['result']
|
|
is_end = data['is_end']
|
|
is_end = data['is_end']
|
|
|
|
|
|
@@ -379,4 +380,4 @@ class ErnieBotModel:
|
|
yield message
|
|
yield message
|
|
else:
|
|
else:
|
|
message = ErnieMessage(content=result, role='assistant')
|
|
message = ErnieMessage(content=result, role='assistant')
|
|
- yield message
|
|
|
|
|
|
+ yield message
|