@@ -1,8 +1,10 @@
 import logging
 from collections.abc import Generator
 
+from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
+
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
-from core.model_runtime.model_providers.volcengine_maas.errors import (
+from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
+from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
     AuthErrors,
     BadRequestErrors,
     ConnectionErrors,
+    MaasException,
     RateLimitErrors,
     ServerUnavailableErrors,
 )
 from core.model_runtime.model_providers.volcengine_maas.llm.models import (
     get_model_config,
     get_v2_req_params,
+    get_v3_req_params,
 )
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 logger = logging.getLogger(__name__)
 
@@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
-        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
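+        # Credentials detected as legacy route to the MaaS v2 client; everything else uses the new Ark v3 client.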
+        if ArkClientV3.is_legacy(credentials):
+            return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate credentials
         """
-        # ping
+        if ArkClientV3.is_legacy(credentials):
+            return self._validate_credentials_v2(credentials)
+        return self._validate_credentials_v3(credentials)
+
+    @staticmethod
+    def _validate_credentials_v2(credentials: dict) -> None:
         client = MaaSClient.from_credential(credentials)
         try:
             client.chat(
@@ -70,21 +81,40 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         except MaasException as e:
             raise CredentialsValidateFailedError(e.message)
 
+    @staticmethod
+    def _validate_credentials_v3(credentials: dict) -> None:
+        client = ArkClientV3.from_credentials(credentials)
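+        # A minimal "ping" completion; any SDK error is surfaced as CredentialsValidateFailedError below.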
+        try:
+            client.chat(max_tokens=16, temperature=0.7, top_p=0.9,
+                        messages=[UserPromptMessage(content='ping\nAnswer: ')], )
+        except Exception as e:
+            raise CredentialsValidateFailedError(e)
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: list[PromptMessageTool] | None = None) -> int:
-        if len(prompt_messages) == 0:
+        if ArkClientV3.is_legacy(credentials):
+            return self._get_num_tokens_v2(prompt_messages)
+        return self._get_num_tokens_v3(prompt_messages)
+
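+    # Both counters approximate token usage locally with the GPT-2 tokenizer over the converted message dicts.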
+    def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int:
+        if len(messages) == 0:
             return 0
-        return self._num_tokens_from_messages(prompt_messages)
+        num_tokens = 0
+        messages_dict = [
+            MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
+        for message in messages_dict:
+            for key, value in message.items():
+                num_tokens += self._get_num_tokens_by_gpt2(str(key))
+                num_tokens += self._get_num_tokens_by_gpt2(str(value))
 
-    def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
-        """
-        Calculate num tokens.
+        return num_tokens
 
-        :param messages: messages
-        """
+    def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int:
+        if len(messages) == 0:
+            return 0
         num_tokens = 0
         messages_dict = [
-            MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
+            ArkClientV3.convert_prompt_message(m) for m in messages]
         for message in messages_dict:
             for key, value in message.items():
                 num_tokens += self._get_num_tokens_by_gpt2(str(key))
@@ -92,9 +122,9 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+    def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                     model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                     stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
 
         client = MaaSClient.from_credential(credentials)
@@ -106,77 +136,151 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         ]
         resp = MaaSClient.wrap_exception(
             lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
-        if not stream:
-            return self._handle_chat_response(model, credentials, prompt_messages, resp)
-        return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
 
-    def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
-        for index, r in enumerate(resp):
-            choices = r['choices']
+        def _handle_stream_chat_response() -> Generator:
+            for index, r in enumerate(resp):
+                choices = r['choices']
+                if not choices:
+                    continue
+                choice = choices[0]
+                message = choice['message']
+                usage = None
+                if r.get('usage'):
+                    usage = self._calc_response_usage(model=model, credentials=credentials,
+                                                      prompt_tokens=r['usage']['prompt_tokens'],
+                                                      completion_tokens=r['usage']['completion_tokens']
+                                                      )
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=AssistantPromptMessage(
+                            content=message['content'] if message['content'] else '',
+                            tool_calls=[]
+                        ),
+                        usage=usage,
+                        finish_reason=choice.get('finish_reason'),
+                    ),
+                )
+
+        def _handle_chat_response() -> LLMResult:
+            choices = resp['choices']
             if not choices:
-                continue
+                raise ValueError("No choices found")
+
             choice = choices[0]
             message = choice['message']
-            usage = None
-            if r.get('usage'):
-                usage = self._calc_usage(model, credentials, r['usage'])
-            yield LLMResultChunk(
+
+            # parse tool calls
+            tool_calls = []
+            if message['tool_calls']:
+                for call in message['tool_calls']:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=call['function']['name'],
+                        type=call['type'],
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=call['function']['name'],
+                            arguments=call['function']['arguments']
+                        )
+                    )
+                    tool_calls.append(tool_call)
+
+            usage = resp['usage']
+            return LLMResult(
                 model=model,
                 prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=AssistantPromptMessage(
-                        content=message['content'] if message['content'] else '',
-                        tool_calls=[]
-                    ),
-                    usage=usage,
-                    finish_reason=choice.get('finish_reason'),
+                message=AssistantPromptMessage(
+                    content=message['content'] if message['content'] else '',
+                    tool_calls=tool_calls,
                 ),
+                usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                prompt_tokens=usage['prompt_tokens'],
+                                                completion_tokens=usage['completion_tokens']
+                                                ),
             )
 
-    def _handle_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
-        choices = resp['choices']
-        if not choices:
-            return
-        choice = choices[0]
-        message = choice['message']
-
-        # parse tool calls
-        tool_calls = []
-        if message['tool_calls']:
-            for call in message['tool_calls']:
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=call['function']['name'],
-                    type=call['type'],
-                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                        name=call['function']['name'],
-                        arguments=call['function']['arguments']
-                    )
+        if not stream:
+            return _handle_chat_response()
+        return _handle_stream_chat_response()
+
+    def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                     model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                     stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+
+        client = ArkClientV3.from_credentials(credentials)
+        req_params = get_v3_req_params(credentials, model_parameters, stop)
+        if tools:
+            req_params['tools'] = tools
+
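+        # Map each streamed Ark chunk onto an LLMResultChunk; usage is attached only when the chunk reports it.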
+        def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator:
+            for chunk in chunks:
+                if not chunk.choices:
+                    continue
+                choice = chunk.choices[0]
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=choice.index,
+                        message=AssistantPromptMessage(
+                            content=choice.delta.content,
+                            tool_calls=[]
+                        ),
+                        usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                        prompt_tokens=chunk.usage.prompt_tokens,
+                                                        completion_tokens=chunk.usage.completion_tokens
+                                                        ) if chunk.usage else None,
+                        finish_reason=choice.finish_reason,
+                    ),
                 )
-                tool_calls.append(tool_call)
 
-        return LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=message['content'] if message['content'] else '',
-                tool_calls=tool_calls,
-            ),
-            usage=self._calc_usage(model, credentials, resp['usage']),
-        )
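+        # Convert a blocking ChatCompletion into an LLMResult, carrying over tool calls and token usage.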
+        def _handle_chat_response(resp: ChatCompletion) -> LLMResult:
+            choice = resp.choices[0]
+            message = choice.message
+            # parse tool calls
+            tool_calls = []
+            if message.tool_calls:
+                for call in message.tool_calls:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=call.id,
+                        type=call.type,
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=call.function.name,
+                            arguments=call.function.arguments
+                        )
+                    )
+                    tool_calls.append(tool_call)
+
+            usage = resp.usage
+            return LLMResult(
+                model=model,
+                prompt_messages=prompt_messages,
+                message=AssistantPromptMessage(
+                    content=message.content if message.content else "",
+                    tool_calls=tool_calls,
+                ),
+                usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                prompt_tokens=usage.prompt_tokens,
+                                                completion_tokens=usage.completion_tokens
+                                                ),
+            )
+
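+        # Non-streaming requests call client.chat(); streaming requests consume client.stream_chat().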
+        if not stream:
+            resp = client.chat(prompt_messages, **req_params)
+            return _handle_chat_response(resp)
 
-    def _calc_usage(self, model: str, credentials: dict, usage: dict) -> LLMUsage:
-        return self._calc_response_usage(model=model, credentials=credentials,
-                                         prompt_tokens=usage['prompt_tokens'],
-                                         completion_tokens=usage['completion_tokens']
-                                         )
+        chunks = client.stream_chat(prompt_messages, **req_params)
+        return _handle_stream_chat_response(chunks)
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         """
             used to define customizable model schema
        """
         model_config = get_model_config(credentials)
-        
+
         rules = [
             ParameterRule(
                 name='temperature',
@@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 use_template='presence_penalty',
                 label=I18nObject(
                     en_US='Presence Penalty',
-                    zh_Hans= '存在惩罚',
+                    zh_Hans='存在惩罚',
                 ),
                 min=-2.0,
                 max=2.0,
@@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 type=ParameterType.FLOAT,
                 use_template='frequency_penalty',
                 label=I18nObject(
-                    en_US= 'Frequency Penalty',
-                    zh_Hans= '频率惩罚',
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
                 ),
                 min=-2.0,
                 max=2.0,
@@ -245,7 +349,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         model_properties = {}
         model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
         model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
-        
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(