Browse Source

feat: Introduce Ark SDK v3 and ensure compatibility with models of SDK v2 (#7579)

Co-authored-by: crazywoola <427733928@qq.com>
sino 8 months ago
parent
commit
efc136cce5

+ 152 - 85
api/core/model_runtime/model_providers/volcengine_maas/client.py

@@ -1,6 +1,25 @@
 import re
-from collections.abc import Callable, Generator
-from typing import cast
+from collections.abc import Generator
+from typing import Optional, cast
+
+from volcenginesdkarkruntime import Ark
+from volcenginesdkarkruntime.types.chat import (
+    ChatCompletion,
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionChunk,
+    ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartTextParam,
+    ChatCompletionMessageParam,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionToolParam,
+    ChatCompletionUserMessageParam,
+)
+from volcenginesdkarkruntime.types.chat.chat_completion_content_part_image_param import ImageURL
+from volcenginesdkarkruntime.types.chat.chat_completion_message_tool_call_param import Function
+from volcenginesdkarkruntime.types.create_embedding_response import CreateEmbeddingResponse
+from volcenginesdkarkruntime.types.shared_params import FunctionDefinition
 
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -12,123 +31,171 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
-from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
 
 
-class MaaSClient(MaasService):
-    def __init__(self, host: str, region: str):
-        self.endpoint_id = None
-        super().__init__(host, region)
+class ArkClientV3:
+    endpoint_id: Optional[str] = None
+    ark: Optional[Ark] = None
 
-    def set_endpoint_id(self, endpoint_id: str):
-        self.endpoint_id = endpoint_id
+    def __init__(self, *args, **kwargs):
+        self.ark = Ark(*args, **kwargs)
+        self.endpoint_id = None
 
-    @classmethod
-    def from_credential(cls, credentials: dict) -> 'MaaSClient':
-        host = credentials['api_endpoint_host']
-        region = credentials['volc_region']
-        ak = credentials['volc_access_key_id']
-        sk = credentials['volc_secret_access_key']
-        endpoint_id = credentials['endpoint_id']
+    @staticmethod
+    def is_legacy(credentials: dict) -> bool:
+        if ArkClientV3.is_compatible_with_legacy(credentials):
+            return False
+        sdk_version = credentials.get("sdk_version", "v2")
+        return sdk_version != "v3"
 
-        client = cls(host, region)
-        client.set_endpoint_id(endpoint_id)
-        client.set_ak(ak)
-        client.set_sk(sk)
-        return client
+    @staticmethod
+    def is_compatible_with_legacy(credentials: dict) -> bool:
+        sdk_version = credentials.get("sdk_version")
+        endpoint = credentials.get("api_endpoint_host")
+        return sdk_version is None and endpoint == "maas-api.ml-platform-cn-beijing.volces.com"
 
-    def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
-        req = {
-            'parameters': params,
-            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
-            **extra_model_kwargs,
+    @classmethod
+    def from_credentials(cls, credentials):
+        """Initialize the client using the credentials provided."""
+        args = {
+            "base_url": credentials['api_endpoint_host'],
+            "region": credentials['volc_region'],
+            "ak": credentials['volc_access_key_id'],
+            "sk": credentials['volc_secret_access_key'],
         }
-        if not stream:
-            return super().chat(
-                self.endpoint_id,
-                req,
-            )
-        return super().stream_chat(
-            self.endpoint_id,
-            req,
-        )
+        if cls.is_compatible_with_legacy(credentials):
+            args["base_url"] = "https://ark.cn-beijing.volces.com/api/v3"
 
-    def embeddings(self, texts: list[str]) -> dict:
-        req = {
-            'input': texts
-        }
-        return super().embeddings(self.endpoint_id, req)
+        client = ArkClientV3(
+            **args
+        )
+        client.endpoint_id = credentials['endpoint_id']
+        return client
 
     @staticmethod
-    def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
+    def convert_prompt_message(message: PromptMessage) -> ChatCompletionMessageParam:
+        """Converts a PromptMessage to a ChatCompletionMessageParam"""
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": ChatRole.USER,
-                                "content": message.content}
+                content = message.content
             else:
                 content = []
                 for message_content in message.content:
                     if message_content.type == PromptMessageContentType.TEXT:
-                        raise ValueError(
-                            'Content object type only support image_url')
+                        content.append(ChatCompletionContentPartTextParam(
+                            text=message_content.text,
+                            type='text',
+                        ))
                     elif message_content.type == PromptMessageContentType.IMAGE:
                         message_content = cast(
                             ImagePromptMessageContent, message_content)
                         image_data = re.sub(
                             r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
-                        content.append({
-                            'type': 'image_url',
-                            'image_url': {
-                                'url': '',
-                                'image_bytes': image_data,
-                                'detail': message_content.detail,
-                            }
-                        })
-
-                message_dict = {'role': ChatRole.USER, 'content': content}
+                        content.append(ChatCompletionContentPartImageParam(
+                            image_url=ImageURL(
+                                url=image_data,
+                                detail=message_content.detail.value,
+                            ),
+                            type='image_url',
+                        ))
+            message_dict = ChatCompletionUserMessageParam(
+                role='user',
+                content=content
+            )
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {'role': ChatRole.ASSISTANT,
-                            'content': message.content}
-            if message.tool_calls:
-                message_dict['tool_calls'] = [
-                    {
-                        'name': call.function.name,
-                        'arguments': call.function.arguments
-                    } for call in message.tool_calls
+            message_dict = ChatCompletionAssistantMessageParam(
+                content=message.content,
+                role='assistant',
+                tool_calls=None if not message.tool_calls else [
+                    ChatCompletionMessageToolCallParam(
+                        id=call.id,
+                        function=Function(
+                            name=call.function.name,
+                            arguments=call.function.arguments
+                        ),
+                        type='function'
+                    ) for call in message.tool_calls
                 ]
+            )
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {'role': ChatRole.SYSTEM,
-                            'content': message.content}
+            message_dict = ChatCompletionSystemMessageParam(
+                content=message.content,
+                role='system'
+            )
         elif isinstance(message, ToolPromptMessage):
             message = cast(ToolPromptMessage, message)
-            message_dict = {'role': ChatRole.FUNCTION,
-                            'content': message.content,
-                            'name': message.tool_call_id}
+            message_dict = ChatCompletionToolMessageParam(
+                content=message.content,
+                role='tool',
+                tool_call_id=message.tool_call_id
+            )
         else:
             raise ValueError(f"Got unknown PromptMessage type {message}")
 
         return message_dict
 
     @staticmethod
-    def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
-        try:
-            resp = fn()
-        except MaasException as e:
-            raise wrap_error(e)
+    def _convert_tool_prompt(message: PromptMessageTool) -> ChatCompletionToolParam:
+        return ChatCompletionToolParam(
+            type='function',
+            function=FunctionDefinition(
+                name=message.name,
+                description=message.description,
+                parameters=message.parameters,
+            )
+        )
 
-        return resp
+    def chat(self, messages: list[PromptMessage],
+             tools: Optional[list[PromptMessageTool]] = None,
+             stop: Optional[list[str]] = None,
+             frequency_penalty: Optional[float] = None,
+             max_tokens: Optional[int] = None,
+             presence_penalty: Optional[float] = None,
+             top_p: Optional[float] = None,
+             temperature: Optional[float] = None,
+             ) -> ChatCompletion:
+        """Block chat"""
+        return self.ark.chat.completions.create(
+            model=self.endpoint_id,
+            messages=[self.convert_prompt_message(message) for message in messages],
+            tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
+            stop=stop,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            top_p=top_p,
+            temperature=temperature,
+        )
 
-    @staticmethod
-    def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
-        return {
-            "type": "function",
-            "function": {
-                "name": tool.name,
-                "description": tool.description,
-                "parameters": tool.parameters,
-            }
-        }
+    def stream_chat(self, messages: list[PromptMessage],
+                    tools: Optional[list[PromptMessageTool]] = None,
+                    stop: Optional[list[str]] = None,
+                    frequency_penalty: Optional[float] = None,
+                    max_tokens: Optional[int] = None,
+                    presence_penalty: Optional[float] = None,
+                    top_p: Optional[float] = None,
+                    temperature: Optional[float] = None,
+                    ) -> Generator[ChatCompletionChunk]:
+        """Stream chat"""
+        chunks = self.ark.chat.completions.create(
+            stream=True,
+            model=self.endpoint_id,
+            messages=[self.convert_prompt_message(message) for message in messages],
+            tools=[self._convert_tool_prompt(tool) for tool in tools] if tools else None,
+            stop=stop,
+            frequency_penalty=frequency_penalty,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            top_p=top_p,
+            temperature=temperature,
+        )
+        for chunk in chunks:
+            if not chunk.choices:
+                continue
+            yield chunk
+
+    def embeddings(self, texts: list[str]) -> CreateEmbeddingResponse:
+        return self.ark.embeddings.create(model=self.endpoint_id, input=texts)

+ 0 - 0
api/core/model_runtime/model_providers/volcengine_maas/legacy/__init__.py


+ 134 - 0
api/core/model_runtime/model_providers/volcengine_maas/legacy/client.py

@@ -0,0 +1,134 @@
+import re
+from collections.abc import Callable, Generator
+from typing import cast
+
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
+from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
+
+
+class MaaSClient(MaasService):
+    def __init__(self, host: str, region: str):
+        self.endpoint_id = None
+        super().__init__(host, region)
+
+    def set_endpoint_id(self, endpoint_id: str):
+        self.endpoint_id = endpoint_id
+
+    @classmethod
+    def from_credential(cls, credentials: dict) -> 'MaaSClient':
+        host = credentials['api_endpoint_host']
+        region = credentials['volc_region']
+        ak = credentials['volc_access_key_id']
+        sk = credentials['volc_secret_access_key']
+        endpoint_id = credentials['endpoint_id']
+
+        client = cls(host, region)
+        client.set_endpoint_id(endpoint_id)
+        client.set_ak(ak)
+        client.set_sk(sk)
+        return client
+
+    def chat(self, params: dict, messages: list[PromptMessage], stream=False, **extra_model_kwargs) -> Generator | dict:
+        req = {
+            'parameters': params,
+            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages],
+            **extra_model_kwargs,
+        }
+        if not stream:
+            return super().chat(
+                self.endpoint_id,
+                req,
+            )
+        return super().stream_chat(
+            self.endpoint_id,
+            req,
+        )
+
+    def embeddings(self, texts: list[str]) -> dict:
+        req = {
+            'input': texts
+        }
+        return super().embeddings(self.endpoint_id, req)
+
+    @staticmethod
+    def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": ChatRole.USER,
+                                "content": message.content}
+            else:
+                content = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        raise ValueError(
+                            'Content object type only support image_url')
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(
+                            ImagePromptMessageContent, message_content)
+                        image_data = re.sub(
+                            r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+                        content.append({
+                            'type': 'image_url',
+                            'image_url': {
+                                'url': '',
+                                'image_bytes': image_data,
+                                'detail': message_content.detail,
+                            }
+                        })
+
+                message_dict = {'role': ChatRole.USER, 'content': content}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {'role': ChatRole.ASSISTANT,
+                            'content': message.content}
+            if message.tool_calls:
+                message_dict['tool_calls'] = [
+                    {
+                        'name': call.function.name,
+                        'arguments': call.function.arguments
+                    } for call in message.tool_calls
+                ]
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {'role': ChatRole.SYSTEM,
+                            'content': message.content}
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {'role': ChatRole.FUNCTION,
+                            'content': message.content,
+                            'name': message.tool_call_id}
+        else:
+            raise ValueError(f"Got unknown PromptMessage type {message}")
+
+        return message_dict
+
+    @staticmethod
+    def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
+        try:
+            resp = fn()
+        except MaasException as e:
+            raise wrap_error(e)
+
+        return resp
+
+    @staticmethod
+    def transform_tool_prompt_to_maas_config(tool: PromptMessageTool):
+        return {
+            "type": "function",
+            "function": {
+                "name": tool.name,
+                "description": tool.description,
+                "parameters": tool.parameters,
+            }
+        }

+ 1 - 1
api/core/model_runtime/model_providers/volcengine_maas/errors.py

@@ -1,4 +1,4 @@
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
 
 
 class ClientSDKRequestError(MaasException):

api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py → api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py


api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py → api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/__init__.py


api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py → api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py


api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py → api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/service.py


api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py → api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py


api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py → api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/common.py


api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py → api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/maas.py


+ 178 - 74
api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py

@@ -1,8 +1,10 @@
 import logging
 from collections.abc import Generator
 
+from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
+
 from core.model_runtime.entities.common_entities import I18nObject
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
@@ -27,19 +29,21 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
-from core.model_runtime.model_providers.volcengine_maas.errors import (
+from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
+from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
     AuthErrors,
     BadRequestErrors,
     ConnectionErrors,
+    MaasException,
     RateLimitErrors,
     ServerUnavailableErrors,
 )
 from core.model_runtime.model_providers.volcengine_maas.llm.models import (
     get_model_config,
     get_v2_req_params,
+    get_v3_req_params,
 )
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 logger = logging.getLogger(__name__)
 
@@ -49,13 +53,20 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
-        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        if ArkClientV3.is_legacy(credentials):
+            return self._generate_v2(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return self._generate_v3(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
         Validate credentials
         """
-        # ping
+        if ArkClientV3.is_legacy(credentials):
+            return self._validate_credentials_v2(credentials)
+        return self._validate_credentials_v3(credentials)
+
+    @staticmethod
+    def _validate_credentials_v2(credentials: dict) -> None:
         client = MaaSClient.from_credential(credentials)
         try:
             client.chat(
@@ -70,21 +81,40 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         except MaasException as e:
             raise CredentialsValidateFailedError(e.message)
 
+    @staticmethod
+    def _validate_credentials_v3(credentials: dict) -> None:
+        client = ArkClientV3.from_credentials(credentials)
+        try:
+            client.chat(max_tokens=16, temperature=0.7, top_p=0.9,
+                        messages=[UserPromptMessage(content='ping\nAnswer: ')], )
+        except Exception as e:
+            raise CredentialsValidateFailedError(e)
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: list[PromptMessageTool] | None = None) -> int:
-        if len(prompt_messages) == 0:
+        if ArkClientV3.is_legacy(credentials):
+            return self._get_num_tokens_v2(prompt_messages)
+        return self._get_num_tokens_v3(prompt_messages)
+
+    def _get_num_tokens_v2(self, messages: list[PromptMessage]) -> int:
+        if len(messages) == 0:
             return 0
-        return self._num_tokens_from_messages(prompt_messages)
+        num_tokens = 0
+        messages_dict = [
+            MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
+        for message in messages_dict:
+            for key, value in message.items():
+                num_tokens += self._get_num_tokens_by_gpt2(str(key))
+                num_tokens += self._get_num_tokens_by_gpt2(str(value))
 
-    def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
-        """
-        Calculate num tokens.
+        return num_tokens
 
-        :param messages: messages
-        """
+    def _get_num_tokens_v3(self, messages: list[PromptMessage]) -> int:
+        if len(messages) == 0:
+            return 0
         num_tokens = 0
         messages_dict = [
-            MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
+            ArkClientV3.convert_prompt_message(m) for m in messages]
         for message in messages_dict:
             for key, value in message.items():
                 num_tokens += self._get_num_tokens_by_gpt2(str(key))
@@ -92,9 +122,9 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+    def _generate_v2(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                     model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                     stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
 
         client = MaaSClient.from_credential(credentials)
@@ -106,77 +136,151 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             ]
         resp = MaaSClient.wrap_exception(
             lambda: client.chat(req_params, prompt_messages, stream, **extra_model_kwargs))
-        if not stream:
-            return self._handle_chat_response(model, credentials, prompt_messages, resp)
-        return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
 
-    def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
-        for index, r in enumerate(resp):
-            choices = r['choices']
+        def _handle_stream_chat_response() -> Generator:
+            for index, r in enumerate(resp):
+                choices = r['choices']
+                if not choices:
+                    continue
+                choice = choices[0]
+                message = choice['message']
+                usage = None
+                if r.get('usage'):
+                    usage = self._calc_response_usage(model=model, credentials=credentials,
+                                                      prompt_tokens=r['usage']['prompt_tokens'],
+                                                      completion_tokens=r['usage']['completion_tokens']
+                                                      )
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=AssistantPromptMessage(
+                            content=message['content'] if message['content'] else '',
+                            tool_calls=[]
+                        ),
+                        usage=usage,
+                        finish_reason=choice.get('finish_reason'),
+                    ),
+                )
+
+        def _handle_chat_response() -> LLMResult:
+            choices = resp['choices']
             if not choices:
-                continue
+                raise ValueError("No choices found")
+
             choice = choices[0]
             message = choice['message']
-            usage = None
-            if r.get('usage'):
-                usage = self._calc_usage(model, credentials, r['usage'])
-            yield LLMResultChunk(
+
+            # parse tool calls
+            tool_calls = []
+            if message['tool_calls']:
+                for call in message['tool_calls']:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=call['function']['name'],
+                        type=call['type'],
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=call['function']['name'],
+                            arguments=call['function']['arguments']
+                        )
+                    )
+                    tool_calls.append(tool_call)
+
+            usage = resp['usage']
+            return LLMResult(
                 model=model,
                 prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=AssistantPromptMessage(
-                        content=message['content'] if message['content'] else '',
-                        tool_calls=[]
-                    ),
-                    usage=usage,
-                    finish_reason=choice.get('finish_reason'),
+                message=AssistantPromptMessage(
+                    content=message['content'] if message['content'] else '',
+                    tool_calls=tool_calls,
                 ),
+                usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                prompt_tokens=usage['prompt_tokens'],
+                                                completion_tokens=usage['completion_tokens']
+                                                ),
             )
 
-    def _handle_chat_response(self,  model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
-        choices = resp['choices']
-        if not choices:
-            return
-        choice = choices[0]
-        message = choice['message']
-
-        # parse tool calls
-        tool_calls = []
-        if message['tool_calls']:
-            for call in message['tool_calls']:
-                tool_call = AssistantPromptMessage.ToolCall(
-                    id=call['function']['name'],
-                    type=call['type'],
-                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                        name=call['function']['name'],
-                        arguments=call['function']['arguments']
-                    )
+        if not stream:
+            return _handle_chat_response()
+        return _handle_stream_chat_response()
+
+    def _generate_v3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                     model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                     stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+
+        client = ArkClientV3.from_credentials(credentials)
+        req_params = get_v3_req_params(credentials, model_parameters, stop)
+        if tools:
+            req_params['tools'] = tools
+
+        def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator:
+            for chunk in chunks:
+                if not chunk.choices:
+                    continue
+                choice = chunk.choices[0]
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=choice.index,
+                        message=AssistantPromptMessage(
+                            content=choice.delta.content,
+                            tool_calls=[]
+                        ),
+                        usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                        prompt_tokens=chunk.usage.prompt_tokens,
+                                                        completion_tokens=chunk.usage.completion_tokens
+                                                        ) if chunk.usage else None,
+                        finish_reason=choice.finish_reason,
+                    ),
                 )
-                tool_calls.append(tool_call)
 
-        return LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=message['content'] if message['content'] else '',
-                tool_calls=tool_calls,
-            ),
-            usage=self._calc_usage(model, credentials, resp['usage']),
-        )
+        def _handle_chat_response(resp: ChatCompletion) -> LLMResult:
+            choice = resp.choices[0]
+            message = choice.message
+            # parse tool calls
+            tool_calls = []
+            if message.tool_calls:
+                for call in message.tool_calls:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id=call.id,
+                        type=call.type,
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                            name=call.function.name,
+                            arguments=call.function.arguments
+                        )
+                    )
+                    tool_calls.append(tool_call)
+
+            usage = resp.usage
+            return LLMResult(
+                model=model,
+                prompt_messages=prompt_messages,
+                message=AssistantPromptMessage(
+                    content=message.content if message.content else "",
+                    tool_calls=tool_calls,
+                ),
+                usage=self._calc_response_usage(model=model, credentials=credentials,
+                                                prompt_tokens=usage.prompt_tokens,
+                                                completion_tokens=usage.completion_tokens
+                                                ),
+            )
+
+        if not stream:
+            resp = client.chat(prompt_messages, **req_params)
+            return _handle_chat_response(resp)
 
-    def _calc_usage(self,  model: str, credentials: dict, usage: dict) -> LLMUsage:
-        return self._calc_response_usage(model=model, credentials=credentials,
-                                         prompt_tokens=usage['prompt_tokens'],
-                                         completion_tokens=usage['completion_tokens']
-                                         )
+        chunks = client.stream_chat(prompt_messages, **req_params)
+        return _handle_stream_chat_response(chunks)
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         """
             used to define customizable model schema
         """
         model_config = get_model_config(credentials)
-    
+
         rules = [
             ParameterRule(
                 name='temperature',
@@ -212,7 +316,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 use_template='presence_penalty',
                 label=I18nObject(
                     en_US='Presence Penalty',
-                    zh_Hans= '存在惩罚',
+                    zh_Hans='存在惩罚',
                 ),
                 min=-2.0,
                 max=2.0,
@@ -222,8 +326,8 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                 type=ParameterType.FLOAT,
                 use_template='frequency_penalty',
                 label=I18nObject(
-                    en_US= 'Frequency Penalty',
-                    zh_Hans= '频率惩罚',
+                    en_US='Frequency Penalty',
+                    zh_Hans='频率惩罚',
                 ),
                 min=-2.0,
                 max=2.0,
@@ -245,7 +349,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
         model_properties = {}
         model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
         model_properties[ModelPropertyKey.MODE] = model_config.properties.mode.value
-       
+
         entity = AIModelEntity(
             model=model,
             label=I18nObject(

+ 44 - 14
api/core/model_runtime/model_providers/volcengine_maas/llm/models.py

@@ -5,10 +5,11 @@ from core.model_runtime.entities.model_entities import ModelFeature
 
 
 class ModelProperties(BaseModel):
-    context_size: int 
-    max_tokens: int 
+    context_size: int
+    max_tokens: int
     mode: LLMMode
 
+
 class ModelConfig(BaseModel):
     properties: ModelProperties
     features: list[ModelFeature]
@@ -24,23 +25,23 @@ configs: dict[str, ModelConfig] = {
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-pro-32k': ModelConfig(
-        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-lite-32k': ModelConfig(
-        properties=ModelProperties(context_size=32768, max_tokens=32768, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-pro-128k': ModelConfig(
-        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Doubao-lite-128k': ModelConfig(
-        properties=ModelProperties(context_size=131072, max_tokens=131072, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL]
     ),
     'Skylark2-pro-4k': ModelConfig(
-        properties=ModelProperties(context_size=4096, max_tokens=4000, mode=LLMMode.CHAT),
+        properties=ModelProperties(context_size=4096, max_tokens=4096, mode=LLMMode.CHAT),
         features=[]
     ),
     'Llama3-8B': ModelConfig(
@@ -77,23 +78,24 @@ configs: dict[str, ModelConfig] = {
     )
 }
 
-def get_model_config(credentials: dict)->ModelConfig:
+
+def get_model_config(credentials: dict) -> ModelConfig:
     base_model = credentials.get('base_model_name', '')
     model_configs = configs.get(base_model)
     if not model_configs:
         return ModelConfig(
-                properties=ModelProperties(
+            properties=ModelProperties(
                 context_size=int(credentials.get('context_size', 0)),
                 max_tokens=int(credentials.get('max_tokens', 0)),
-                mode= LLMMode.value_of(credentials.get('mode', 'chat')),
+                mode=LLMMode.value_of(credentials.get('mode', 'chat')),
             ),
             features=[]
         )
     return model_configs
 
 
-def get_v2_req_params(credentials: dict, model_parameters: dict, 
-                      stop: list[str] | None=None):
+def get_v2_req_params(credentials: dict, model_parameters: dict,
+                      stop: list[str] | None = None):
     req_params = {}
     # predefined properties
     model_configs = get_model_config(credentials)
@@ -116,8 +118,36 @@ def get_v2_req_params(credentials: dict, model_parameters: dict,
     if model_parameters.get('frequency_penalty'):
         req_params['frequency_penalty'] = model_parameters.get(
             'frequency_penalty')
-            
+
+    if stop:
+        req_params['stop'] = stop
+
+    return req_params
+
+
+def get_v3_req_params(credentials: dict, model_parameters: dict,
+                      stop: list[str] | None = None):
+    req_params = {}
+    # predefined properties
+    model_configs = get_model_config(credentials)
+    if model_configs:
+        req_params['max_tokens'] = model_configs.properties.max_tokens
+
+    # model parameters
+    if model_parameters.get('max_tokens'):
+        req_params['max_tokens'] = model_parameters.get('max_tokens')
+    if model_parameters.get('temperature'):
+        req_params['temperature'] = model_parameters.get('temperature')
+    if model_parameters.get('top_p'):
+        req_params['top_p'] = model_parameters.get('top_p')
+    if model_parameters.get('presence_penalty'):
+        req_params['presence_penalty'] = model_parameters.get(
+            'presence_penalty')
+    if model_parameters.get('frequency_penalty'):
+        req_params['frequency_penalty'] = model_parameters.get(
+            'frequency_penalty')
+
     if stop:
         req_params['stop'] = stop
 
-    return req_params
+    return req_params

+ 9 - 6
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py

@@ -2,26 +2,29 @@ from pydantic import BaseModel
 
 
 class ModelProperties(BaseModel):
-    context_size: int 
-    max_chunks: int 
+    context_size: int
+    max_chunks: int
+
 
 class ModelConfig(BaseModel):
     properties: ModelProperties
 
+
 ModelConfigs = {
     'Doubao-embedding': ModelConfig(
-        properties=ModelProperties(context_size=4096, max_chunks=1)
+        properties=ModelProperties(context_size=4096, max_chunks=32)
     ),
 }
 
-def get_model_config(credentials: dict)->ModelConfig:
+
+def get_model_config(credentials: dict) -> ModelConfig:
     base_model = credentials.get('base_model_name', '')
     model_configs = ModelConfigs.get(base_model)
     if not model_configs:
         return ModelConfig(
-                properties=ModelProperties(
+            properties=ModelProperties(
                 context_size=int(credentials.get('context_size', 0)),
                 max_chunks=int(credentials.get('max_chunks', 0)),
             )
         )
-    return model_configs
+    return model_configs

+ 44 - 6
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py

@@ -22,16 +22,17 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
-from core.model_runtime.model_providers.volcengine_maas.errors import (
+from core.model_runtime.model_providers.volcengine_maas.client import ArkClientV3
+from core.model_runtime.model_providers.volcengine_maas.legacy.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
     AuthErrors,
     BadRequestErrors,
     ConnectionErrors,
+    MaasException,
     RateLimitErrors,
     ServerUnavailableErrors,
 )
 from core.model_runtime.model_providers.volcengine_maas.text_embedding.models import get_model_config
-from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
 
 
 class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
@@ -51,6 +52,14 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         :param user: unique user id
         :return: embeddings result
         """
+        if ArkClientV3.is_legacy(credentials):
+            return self._generate_v2(model, credentials, texts, user)
+
+        return self._generate_v3(model, credentials, texts, user)
+
+    def _generate_v2(self, model: str, credentials: dict,
+                     texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
         client = MaaSClient.from_credential(credentials)
         resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
 
@@ -65,6 +74,23 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
 
         return result
 
+    def _generate_v3(self, model: str, credentials: dict,
+                     texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        client = ArkClientV3.from_credentials(credentials)
+        resp = client.embeddings(texts)
+
+        usage = self._calc_response_usage(
+            model=model, credentials=credentials, tokens=resp.usage.total_tokens)
+
+        result = TextEmbeddingResult(
+            model=model,
+            embeddings=[v.embedding for v in resp.data],
+            usage=usage
+        )
+
+        return result
+
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
         """
         Get number of tokens for given prompt messages
@@ -88,11 +114,22 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
         :param credentials: model credentials
         :return:
         """
+        if ArkClientV3.is_legacy(credentials):
+            return self._validate_credentials_v2(model, credentials)
+        return self._validate_credentials_v3(model, credentials)
+
+    def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
         try:
             self._invoke(model=model, credentials=credentials, texts=['ping'])
         except MaasException as e:
             raise CredentialsValidateFailedError(e.message)
 
+    def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
+        try:
+            self._invoke(model=model, credentials=credentials, texts=['ping'])
+        except Exception as e:
+            raise CredentialsValidateFailedError(e)
+
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
         """
@@ -116,9 +153,10 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
             generate custom model entities from credentials
         """
         model_config = get_model_config(credentials)
-        model_properties = {}
-        model_properties[ModelPropertyKey.CONTEXT_SIZE] = model_config.properties.context_size
-        model_properties[ModelPropertyKey.MAX_CHUNKS] = model_config.properties.max_chunks
+        model_properties = {
+            ModelPropertyKey.CONTEXT_SIZE: model_config.properties.context_size,
+            ModelPropertyKey.MAX_CHUNKS: model_config.properties.max_chunks
+        }
         entity = AIModelEntity(
             model=model,
             label=I18nObject(en_US=model),

+ 36 - 1
api/poetry.lock

@@ -6143,6 +6143,19 @@ files = [
     {file = "pyarrow-17.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:392bc9feabc647338e6c89267635e111d71edad5fcffba204425a7c8d13610d7"},
     {file = "pyarrow-17.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:af5ff82a04b2171415f1410cff7ebb79861afc5dae50be73ce06d6e870615204"},
     {file = "pyarrow-17.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:edca18eaca89cd6382dfbcff3dd2d87633433043650c07375d095cd3517561d8"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c7916bff914ac5d4a8fe25b7a25e432ff921e72f6f2b7547d1e325c1ad9d155"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f553ca691b9e94b202ff741bdd40f6ccb70cdd5fbf65c187af132f1317de6145"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:0cdb0e627c86c373205a2f94a510ac4376fdc523f8bb36beab2e7f204416163c"},
+    {file = "pyarrow-17.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d7d192305d9d8bc9082d10f361fc70a73590a4c65cf31c3e6926cd72b76bc35c"},
+    {file = "pyarrow-17.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:02dae06ce212d8b3244dd3e7d12d9c4d3046945a5933d28026598e9dbbda1fca"},
+    {file = "pyarrow-17.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:13d7a460b412f31e4c0efa1148e1d29bdf18ad1411eb6757d38f8fbdcc8645fb"},
+    {file = "pyarrow-17.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b564a51fbccfab5a04a80453e5ac6c9954a9c5ef2890d1bcf63741909c3f8df"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32503827abbc5aadedfa235f5ece8c4f8f8b0a3cf01066bc8d29de7539532687"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a155acc7f154b9ffcc85497509bcd0d43efb80d6f733b0dc3bb14e281f131c8b"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:dec8d129254d0188a49f8a1fc99e0560dc1b85f60af729f47de4046015f9b0a5"},
+    {file = "pyarrow-17.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:a48ddf5c3c6a6c505904545c25a4ae13646ae1f8ba703c4df4a1bfe4f4006bda"},
+    {file = "pyarrow-17.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:42bf93249a083aca230ba7e2786c5f673507fa97bbd9725a1e2754715151a204"},
+    {file = "pyarrow-17.0.0.tar.gz", hash = "sha256:4beca9521ed2c0921c1023e68d097d0299b62c362639ea315572a58f3f50fd28"},
 ]
 
 [package.dependencies]
@@ -8855,6 +8868,28 @@ files = [
 ]
 
 [[package]]
+name = "volcengine-python-sdk"
+version = "1.0.98"
+description = "Volcengine SDK for Python"
+optional = false
+python-versions = "*"
+files = [
+    {file = "volcengine-python-sdk-1.0.98.tar.gz", hash = "sha256:1515e8d46cdcda387f9b45abbcaf0b04b982f7be68068de83f1e388281441784"},
+]
+
+[package.dependencies]
+anyio = {version = ">=3.5.0,<5", optional = true, markers = "extra == \"ark\""}
+certifi = ">=2017.4.17"
+httpx = {version = ">=0.23.0,<1", optional = true, markers = "extra == \"ark\""}
+pydantic = {version = ">=1.9.0,<3", optional = true, markers = "extra == \"ark\""}
+python-dateutil = ">=2.1"
+six = ">=1.10"
+urllib3 = ">=1.23"
+
+[package.extras]
+ark = ["anyio (>=3.5.0,<5)", "cached-property", "httpx (>=0.23.0,<1)", "pydantic (>=1.9.0,<3)"]
+
+[[package]]
 name = "watchfiles"
 version = "0.23.0"
 description = "Simple, modern and high performance file watching and code reload in python."
@@ -9634,4 +9669,4 @@ cffi = ["cffi (>=1.11)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "d7336115709114c2a4ff09b392f717e9c3547ae82b6a111d0c885c7a44269f02"
+content-hash = "04f970820de691f40fc9fb30f5ff0618b0f1a04d3315b14467fb88e475fa1243"

+ 1 - 0
api/pyproject.toml

@@ -191,6 +191,7 @@ zhipuai = "1.0.7"
 # Related transparent dependencies with pinned verion
 # required by main implementations
 ############################################################
+volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
 [tool.poetry.group.indriect.dependencies]
 kaleido = "0.2.1"
 rank-bm25 = "~0.2.2"