
feat: add volcengine maas model provider (#4142)

sino 1 year ago
parent commit 4aa21242b6
25 changed files with 1834 additions and 1 deletion
  1. api/core/model_runtime/model_providers/_position.yaml (+1 -0)
  2. api/core/model_runtime/model_providers/volcengine_maas/__init__.py (+0 -0)
  3. api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg (+23 -0)
  4. api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg (+39 -0)
  5. api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg (+8 -0)
  6. api/core/model_runtime/model_providers/volcengine_maas/client.py (+108 -0)
  7. api/core/model_runtime/model_providers/volcengine_maas/errors.py (+156 -0)
  8. api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py (+0 -0)
  9. api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py (+284 -0)
  10. api/core/model_runtime/model_providers/volcengine_maas/llm/models.py (+12 -0)
  11. api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py (+0 -0)
  12. api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py (+132 -0)
  13. api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py (+4 -0)
  14. api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py (+1 -0)
  15. api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py (+144 -0)
  16. api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py (+207 -0)
  17. api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py (+43 -0)
  18. api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py (+79 -0)
  19. api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py (+213 -0)
  20. api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py (+10 -0)
  21. api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml (+151 -0)
  22. api/tests/integration_tests/.env.example (+7 -1)
  23. api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py (+0 -0)
  24. api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py (+81 -0)
  25. api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py (+131 -0)

+ 1 - 0
api/core/model_runtime/model_providers/_position.yaml

@@ -26,5 +26,6 @@
 - yi
 - openllm
 - localai
+- volcengine_maas
 - openai_api_compatible
 - deepseek

+ 0 - 0
api/core/model_runtime/model_providers/volcengine_maas/__init__.py


+ 23 - 0
api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_en.svg

File diff suppressed because it is too large

+ 39 - 0
api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_l_zh.svg

File diff suppressed because it is too large

+ 8 - 0
api/core/model_runtime/model_providers/volcengine_maas/_assets/icon_s_en.svg

File diff suppressed because it is too large


+ 108 - 0
api/core/model_runtime/model_providers/volcengine_maas/client.py

@@ -0,0 +1,108 @@
+import re
+from collections.abc import Callable, Generator
+from typing import cast
+
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContentType,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.model_providers.volcengine_maas.errors import wrap_error
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService
+
+
+class MaaSClient(MaasService):
+    def __init__(self, host: str, region: str):
+        self.endpoint_id = None
+        super().__init__(host, region)
+
+    def set_endpoint_id(self, endpoint_id: str):
+        self.endpoint_id = endpoint_id
+
+    @classmethod
+    def from_credential(cls, credentials: dict) -> 'MaaSClient':
+        host = credentials['api_endpoint_host']
+        region = credentials['volc_region']
+        ak = credentials['volc_access_key_id']
+        sk = credentials['volc_secret_access_key']
+        endpoint_id = credentials['endpoint_id']
+
+        client = cls(host, region)
+        client.set_endpoint_id(endpoint_id)
+        client.set_ak(ak)
+        client.set_sk(sk)
+        return client
+
+    def chat(self, params: dict, messages: list[PromptMessage], stream=False) -> Generator | dict:
+        req = {
+            'parameters': params,
+            'messages': [self.convert_prompt_message_to_maas_message(prompt) for prompt in messages]
+        }
+        if not stream:
+            return super().chat(
+                self.endpoint_id,
+                req,
+            )
+        return super().stream_chat(
+            self.endpoint_id,
+            req,
+        )
+
+    def embeddings(self, texts: list[str]) -> dict:
+        req = {
+            'input': texts
+        }
+        return super().embeddings(self.endpoint_id, req)
+
+    @staticmethod
+    def convert_prompt_message_to_maas_message(message: PromptMessage) -> dict:
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": ChatRole.USER,
+                                "content": message.content}
+            else:
+                content = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        raise ValueError(
+                            'Content object type only supports image_url')
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(
+                            ImagePromptMessageContent, message_content)
+                        image_data = re.sub(
+                            r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
+                        content.append({
+                            'type': 'image_url',
+                            'image_url': {
+                                'url': '',
+                                'image_bytes': image_data,
+                                'detail': message_content.detail,
+                            }
+                        })
+
+                message_dict = {'role': ChatRole.USER, 'content': content}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {'role': ChatRole.ASSISTANT,
+                            'content': message.content}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {'role': ChatRole.SYSTEM,
+                            'content': message.content}
+        else:
+            raise ValueError(f"Got unknown PromptMessage type {message}")
+
+        return message_dict
+
+    @staticmethod
+    def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
+        try:
+            resp = fn()
+        except MaasException as e:
+            raise wrap_error(e)
+
+        return resp
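
A minimal usage sketch of the client added above, for reviewers; the credential values and the chat parameter dict are placeholders, not part of this commit:

    from core.model_runtime.entities.message_entities import UserPromptMessage
    from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient

    # Placeholder credentials; the keys match the schema in volcengine_maas.yaml below.
    credentials = {
        'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
        'volc_region': 'cn-beijing',
        'volc_access_key_id': 'YOUR_AK',
        'volc_secret_access_key': 'YOUR_SK',
        'endpoint_id': 'YOUR_ENDPOINT_ID',
    }

    client = MaaSClient.from_credential(credentials)
    # wrap_exception re-raises MaasException as one of the classes defined in errors.py.
    resp = MaaSClient.wrap_exception(
        lambda: client.chat({'max_new_tokens': 16, 'temperature': 0.7},
                            [UserPromptMessage(content='ping')]))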

+ 156 - 0
api/core/model_runtime/model_providers/volcengine_maas/errors.py

@@ -0,0 +1,156 @@
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+
+class ClientSDKRequestError(MaasException):
+    pass
+
+
+class SignatureDoesNotMatch(MaasException):
+    pass
+
+
+class RequestTimeout(MaasException):
+    pass
+
+
+class ServiceConnectionTimeout(MaasException):
+    pass
+
+
+class MissingAuthenticationHeader(MaasException):
+    pass
+
+
+class AuthenticationHeaderIsInvalid(MaasException):
+    pass
+
+
+class InternalServiceError(MaasException):
+    pass
+
+
+class MissingParameter(MaasException):
+    pass
+
+
+class InvalidParameter(MaasException):
+    pass
+
+
+class AuthenticationExpire(MaasException):
+    pass
+
+
+class EndpointIsInvalid(MaasException):
+    pass
+
+
+class EndpointIsNotEnable(MaasException):
+    pass
+
+
+class ModelNotSupportStreamMode(MaasException):
+    pass
+
+
+class ReqTextExistRisk(MaasException):
+    pass
+
+
+class RespTextExistRisk(MaasException):
+    pass
+
+
+class EndpointRateLimitExceeded(MaasException):
+    pass
+
+
+class ServiceConnectionRefused(MaasException):
+    pass
+
+
+class ServiceConnectionClosed(MaasException):
+    pass
+
+
+class UnauthorizedUserForEndpoint(MaasException):
+    pass
+
+
+class InvalidEndpointWithNoURL(MaasException):
+    pass
+
+
+class EndpointAccountRpmRateLimitExceeded(MaasException):
+    pass
+
+
+class EndpointAccountTpmRateLimitExceeded(MaasException):
+    pass
+
+
+class ServiceResourceWaitQueueFull(MaasException):
+    pass
+
+
+class EndpointIsPending(MaasException):
+    pass
+
+
+class ServiceNotOpen(MaasException):
+    pass
+
+
+AuthErrors = {
+    'SignatureDoesNotMatch': SignatureDoesNotMatch,
+    'MissingAuthenticationHeader': MissingAuthenticationHeader,
+    'AuthenticationHeaderIsInvalid': AuthenticationHeaderIsInvalid,
+    'AuthenticationExpire': AuthenticationExpire,
+    'UnauthorizedUserForEndpoint': UnauthorizedUserForEndpoint,
+}
+
+BadRequestErrors = {
+    'MissingParameter': MissingParameter,
+    'InvalidParameter': InvalidParameter,
+    'EndpointIsInvalid': EndpointIsInvalid,
+    'EndpointIsNotEnable': EndpointIsNotEnable,
+    'ModelNotSupportStreamMode': ModelNotSupportStreamMode,
+    'ReqTextExistRisk': ReqTextExistRisk,
+    'RespTextExistRisk': RespTextExistRisk,
+    'InvalidEndpointWithNoURL': InvalidEndpointWithNoURL,
+    'ServiceNotOpen': ServiceNotOpen,
+}
+
+RateLimitErrors = {
+    'EndpointRateLimitExceeded': EndpointRateLimitExceeded,
+    'EndpointAccountRpmRateLimitExceeded': EndpointAccountRpmRateLimitExceeded,
+    'EndpointAccountTpmRateLimitExceeded': EndpointAccountTpmRateLimitExceeded,
+}
+
+ServerUnavailableErrors = {
+    'InternalServiceError': InternalServiceError,
+    'EndpointIsPending': EndpointIsPending,
+    'ServiceResourceWaitQueueFull': ServiceResourceWaitQueueFull,
+}
+
+ConnectionErrors = {
+    'ClientSDKRequestError': ClientSDKRequestError,
+    'RequestTimeout': RequestTimeout,
+    'ServiceConnectionTimeout': ServiceConnectionTimeout,
+    'ServiceConnectionRefused': ServiceConnectionRefused,
+    'ServiceConnectionClosed': ServiceConnectionClosed,
+}
+
+ErrorCodeMap = {
+    **AuthErrors,
+    **BadRequestErrors,
+    **RateLimitErrors,
+    **ServerUnavailableErrors,
+    **ConnectionErrors,
+}
+
+
+def wrap_error(e: MaasException) -> Exception:
+    if ErrorCodeMap.get(e.code):
+        return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
+    return e
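
A sketch of how this mapping is consumed; the numeric code and req_id below are fabricated, but the argument order follows MaasException in volc_sdk/maas.py:

    from core.model_runtime.model_providers.volcengine_maas.errors import RateLimitErrors, wrap_error
    from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException

    # A hypothetical rate-limit error as the SDK would raise it: (code_n, code, message, req_id).
    e = MaasException(1709801, 'EndpointRateLimitExceeded', 'rate limit exceeded', 'req-demo')
    mapped = wrap_error(e)
    # wrap_error looks the code string up in ErrorCodeMap, so the subclass lands in RateLimitErrors,
    # which _invoke_error_mapping in llm.py then translates into InvokeRateLimitError.
    assert type(mapped) in RateLimitErrors.values()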

+ 0 - 0
api/core/model_runtime/model_providers/volcengine_maas/llm/__init__.py


+ 284 - 0
api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py

@@ -0,0 +1,284 @@
+import logging
+from collections.abc import Generator
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+    AuthErrors,
+    BadRequestErrors,
+    ConnectionErrors,
+    RateLimitErrors,
+    ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+logger = logging.getLogger(__name__)
+
+
+class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate credentials
+        """
+        # ping
+        client = MaaSClient.from_credential(credentials)
+        try:
+            client.chat(
+                {
+                    'max_new_tokens': 16,
+                    'temperature': 0.7,
+                    'top_p': 0.9,
+                    'top_k': 15,
+                },
+                [UserPromptMessage(content='ping\nAnswer: ')],
+            )
+        except MaasException as e:
+            raise CredentialsValidateFailedError(e.message)
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: list[PromptMessageTool] | None = None) -> int:
+        if len(prompt_messages) == 0:
+            return 0
+        return self._num_tokens_from_messages(prompt_messages)
+
+    def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
+        """
+        Calculate num tokens.
+
+        :param messages: messages
+        """
+        num_tokens = 0
+        messages_dict = [
+            MaaSClient.convert_prompt_message_to_maas_message(m) for m in messages]
+        for message in messages_dict:
+            for key, value in message.items():
+                num_tokens += self._get_num_tokens_by_gpt2(str(key))
+                num_tokens += self._get_num_tokens_by_gpt2(str(value))
+
+        return num_tokens
+
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+
+        client = MaaSClient.from_credential(credentials)
+
+        req_params = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('req_params', {}).copy()
+        if credentials.get('context_size'):
+            req_params['max_prompt_tokens'] = credentials.get('context_size')
+        if credentials.get('max_tokens'):
+            req_params['max_new_tokens'] = credentials.get('max_tokens')
+        if model_parameters.get('max_tokens'):
+            req_params['max_new_tokens'] = model_parameters.get('max_tokens')
+        if model_parameters.get('temperature'):
+            req_params['temperature'] = model_parameters.get('temperature')
+        if model_parameters.get('top_p'):
+            req_params['top_p'] = model_parameters.get('top_p')
+        if model_parameters.get('top_k'):
+            req_params['top_k'] = model_parameters.get('top_k')
+        if model_parameters.get('presence_penalty'):
+            req_params['presence_penalty'] = model_parameters.get(
+                'presence_penalty')
+        if model_parameters.get('frequency_penalty'):
+            req_params['frequency_penalty'] = model_parameters.get(
+                'frequency_penalty')
+        if stop:
+            req_params['stop'] = stop
+
+        resp = MaaSClient.wrap_exception(
+            lambda: client.chat(req_params, prompt_messages, stream))
+        if not stream:
+            return self._handle_chat_response(model, credentials, prompt_messages, resp)
+        return self._handle_stream_chat_response(model, credentials, prompt_messages, resp)
+
+    def _handle_stream_chat_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: Generator) -> Generator:
+        for index, r in enumerate(resp):
+            choices = r['choices']
+            if not choices:
+                continue
+            choice = choices[0]
+            message = choice['message']
+            usage = None
+            if r.get('usage'):
+                usage = self._calc_usage(model, credentials, r['usage'])
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=AssistantPromptMessage(
+                        content=message['content'] if message['content'] else '',
+                        tool_calls=[]
+                    ),
+                    usage=usage,
+                    finish_reason=choice.get('finish_reason'),
+                ),
+            )
+
+    def _handle_chat_response(self,  model: str, credentials: dict, prompt_messages: list[PromptMessage], resp: dict) -> LLMResult:
+        choices = resp['choices']
+        if not choices:
+            return
+        choice = choices[0]
+        message = choice['message']
+
+        return LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=AssistantPromptMessage(
+                content=message['content'] if message['content'] else '',
+                tool_calls=[],
+            ),
+            usage=self._calc_usage(model, credentials, resp['usage']),
+        )
+
+    def _calc_usage(self,  model: str, credentials: dict, usage: dict) -> LLMUsage:
+        return self._calc_response_usage(model=model, credentials=credentials,
+                                         prompt_tokens=usage['prompt_tokens'],
+                                         completion_tokens=usage['completion_tokens']
+                                         )
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+        max_tokens = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('req_params', {}).get('max_new_tokens')
+        if credentials.get('max_tokens'):
+            max_tokens = int(credentials.get('max_tokens'))
+        rules = [
+            ParameterRule(
+                name='temperature',
+                type=ParameterType.FLOAT,
+                use_template='temperature',
+                label=I18nObject(
+                    zh_Hans='温度',
+                    en_US='Temperature'
+                )
+            ),
+            ParameterRule(
+                name='top_p',
+                type=ParameterType.FLOAT,
+                use_template='top_p',
+                label=I18nObject(
+                    zh_Hans='Top P',
+                    en_US='Top P'
+                )
+            ),
+            ParameterRule(
+                name='top_k',
+                type=ParameterType.INT,
+                min=1,
+                default=1,
+                label=I18nObject(
+                    zh_Hans='Top K',
+                    en_US='Top K'
+                )
+            ),
+            ParameterRule(
+                name='presence_penalty',
+                type=ParameterType.FLOAT,
+                use_template='presence_penalty',
+                label={
+                    'en_US': 'Presence Penalty',
+                    'zh_Hans': '存在惩罚',
+                },
+                min=-2.0,
+                max=2.0,
+            ),
+            ParameterRule(
+                name='frequency_penalty',
+                type=ParameterType.FLOAT,
+                use_template='frequency_penalty',
+                label={
+                    'en_US': 'Frequency Penalty',
+                    'zh_Hans': '频率惩罚',
+                },
+                min=-2.0,
+                max=2.0,
+            ),
+            ParameterRule(
+                name='max_tokens',
+                type=ParameterType.INT,
+                use_template='max_tokens',
+                min=1,
+                max=max_tokens,
+                default=512,
+                label=I18nObject(
+                    zh_Hans='最大生成长度',
+                    en_US='Max Tokens'
+                )
+            ),
+        ]
+
+        model_properties = ModelConfigs.get(
+            credentials['base_model_name'], {}).get('model_properties', {}).copy()
+        if credentials.get('mode'):
+            model_properties[ModelPropertyKey.MODE] = credentials.get('mode')
+        if credentials.get('context_size'):
+            model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(
+                credentials.get('context_size', 4096))
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.LLM,
+            model_properties=model_properties,
+            parameter_rules=rules
+        )
+
+        return entity
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: ConnectionErrors.values(),
+            InvokeServerUnavailableError: ServerUnavailableErrors.values(),
+            InvokeRateLimitError: RateLimitErrors.values(),
+            InvokeAuthorizationError: AuthErrors.values(),
+            InvokeBadRequestError: BadRequestErrors.values(),
+        }

+ 12 - 0
api/core/model_runtime/model_providers/volcengine_maas/llm/models.py

@@ -0,0 +1,12 @@
+ModelConfigs = {
+    'Skylark2-pro-4k': {
+        'req_params': {
+            'max_prompt_tokens': 4096,
+            'max_new_tokens': 4000,
+        },
+        'model_properties': {
+            'context_size': 4096,
+            'mode': 'chat',
+        }
+    }
+}
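
A short sketch of how llm.py consumes this dict; base_model_name comes from the provider credentials:

    from core.model_runtime.model_providers.volcengine_maas.llm.models import ModelConfigs

    # Same lookup pattern as _generate() and get_customizable_model_schema() in llm.py above.
    req_params = ModelConfigs.get('Skylark2-pro-4k', {}).get('req_params', {}).copy()
    assert req_params == {'max_prompt_tokens': 4096, 'max_new_tokens': 4000}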

+ 0 - 0
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/__init__.py


+ 132 - 0
api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py

@@ -0,0 +1,132 @@
+import time
+from typing import Optional
+
+from core.model_runtime.entities.model_entities import PriceType
+from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
+from core.model_runtime.model_providers.volcengine_maas.client import MaaSClient
+from core.model_runtime.model_providers.volcengine_maas.errors import (
+    AuthErrors,
+    BadRequestErrors,
+    ConnectionErrors,
+    RateLimitErrors,
+    ServerUnavailableErrors,
+)
+from core.model_runtime.model_providers.volcengine_maas.volc_sdk import MaasException
+
+
+class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
+    """
+    Model class for VolcengineMaaS text embedding model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                texts: list[str], user: Optional[str] = None) \
+            -> TextEmbeddingResult:
+        """
+        Invoke text embedding model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :param user: unique user id
+        :return: embeddings result
+        """
+        client = MaaSClient.from_credential(credentials)
+        resp = MaaSClient.wrap_exception(lambda: client.embeddings(texts))
+
+        usage = self._calc_response_usage(
+            model=model, credentials=credentials, tokens=resp['total_tokens'])
+
+        result = TextEmbeddingResult(
+            model=model,
+            embeddings=[v['embedding'] for v in resp['data']],
+            usage=usage
+        )
+
+        return result
+
+    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
+        """
+        Get number of tokens for given prompt messages
+
+        :param model: model name
+        :param credentials: model credentials
+        :param texts: texts to embed
+        :return:
+        """
+        num_tokens = 0
+        for text in texts:
+            # use GPT2Tokenizer to get num tokens
+            num_tokens += self._get_num_tokens_by_gpt2(text)
+        return num_tokens
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            self._invoke(model=model, credentials=credentials, texts=['ping'])
+        except MaasException as e:
+            raise CredentialsValidateFailedError(e.message)
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: ConnectionErrors.values(),
+            InvokeServerUnavailableError: ServerUnavailableErrors.values(),
+            InvokeRateLimitError: RateLimitErrors.values(),
+            InvokeAuthorizationError: AuthErrors.values(),
+            InvokeBadRequestError: BadRequestErrors.values(),
+        }
+
+    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
+        """
+        Calculate response usage
+
+        :param model: model name
+        :param credentials: model credentials
+        :param tokens: input tokens
+        :return: usage
+        """
+        # get input price info
+        input_price_info = self.get_price(
+            model=model,
+            credentials=credentials,
+            price_type=PriceType.INPUT,
+            tokens=tokens
+        )
+
+        # transform usage
+        usage = EmbeddingUsage(
+            tokens=tokens,
+            total_tokens=tokens,
+            unit_price=input_price_info.unit_price,
+            price_unit=input_price_info.unit,
+            total_price=input_price_info.total_amount,
+            currency=input_price_info.currency,
+            latency=time.perf_counter() - self.started_at
+        )
+
+        return usage

+ 4 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/__init__.py

@@ -0,0 +1,4 @@
+from .common import ChatRole
+from .maas import MaasException, MaasService
+
+__all__ = ['MaasService', 'ChatRole', 'MaasException']

+ 1 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/__init__.py

@@ -0,0 +1 @@
+

+ 144 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/auth.py

@@ -0,0 +1,144 @@
+# coding : utf-8
+import datetime
+
+import pytz
+
+from .util import Util
+
+
+class MetaData:
+    def __init__(self):
+        self.algorithm = ''
+        self.credential_scope = ''
+        self.signed_headers = ''
+        self.date = ''
+        self.region = ''
+        self.service = ''
+
+    def set_date(self, date):
+        self.date = date
+
+    def set_service(self, service):
+        self.service = service
+
+    def set_region(self, region):
+        self.region = region
+
+    def set_algorithm(self, algorithm):
+        self.algorithm = algorithm
+
+    def set_credential_scope(self, credential_scope):
+        self.credential_scope = credential_scope
+
+    def set_signed_headers(self, signed_headers):
+        self.signed_headers = signed_headers
+
+
+class SignResult:
+    def __init__(self):
+        self.xdate = ''
+        self.xCredential = ''
+        self.xAlgorithm = ''
+        self.xSignedHeaders = ''
+        self.xSignedQueries = ''
+        self.xSignature = ''
+        self.xContextSha256 = ''
+        self.xSecurityToken = ''
+
+        self.authorization = ''
+
+    def __str__(self):
+        return '\n'.join(['{}:{}'.format(*item) for item in self.__dict__.items()])
+
+
+class Credentials:
+    def __init__(self, ak, sk, service, region, session_token=''):
+        self.ak = ak
+        self.sk = sk
+        self.service = service
+        self.region = region
+        self.session_token = session_token
+
+    def set_ak(self, ak):
+        self.ak = ak
+
+    def set_sk(self, sk):
+        self.sk = sk
+
+    def set_session_token(self, session_token):
+        self.session_token = session_token
+
+
+class Signer:
+    @staticmethod
+    def sign(request, credentials):
+        if request.path == '':
+            request.path = '/'
+        if request.method != 'GET' and not ('Content-Type' in request.headers):
+            request.headers['Content-Type'] = 'application/x-www-form-urlencoded; charset=utf-8'
+
+        format_date = Signer.get_current_format_date()
+        request.headers['X-Date'] = format_date
+        if credentials.session_token != '':
+            request.headers['X-Security-Token'] = credentials.session_token
+
+        md = MetaData()
+        md.set_algorithm('HMAC-SHA256')
+        md.set_service(credentials.service)
+        md.set_region(credentials.region)
+        md.set_date(format_date[:8])
+
+        hashed_canon_req = Signer.hashed_canonical_request_v4(request, md)
+        md.set_credential_scope('/'.join([md.date, md.region, md.service, 'request']))
+
+        signing_str = '\n'.join([md.algorithm, format_date, md.credential_scope, hashed_canon_req])
+        signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service)
+        sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str))
+        request.headers['Authorization'] = Signer.build_auth_header_v4(sign, md, credentials)
+        return
+
+    @staticmethod
+    def hashed_canonical_request_v4(request, meta):
+        body_hash = Util.sha256(request.body)
+        request.headers['X-Content-Sha256'] = body_hash
+
+        signed_headers = dict()
+        for key in request.headers:
+            if key in ['Content-Type', 'Content-Md5', 'Host'] or key.startswith('X-'):
+                signed_headers[key.lower()] = request.headers[key]
+
+        if 'host' in signed_headers:
+            v = signed_headers['host']
+            if v.find(':') != -1:
+                split = v.split(':')
+                port = split[1]
+                if str(port) == '80' or str(port) == '443':
+                    signed_headers['host'] = split[0]
+
+        signed_str = ''
+        for key in sorted(signed_headers.keys()):
+            signed_str += key + ':' + signed_headers[key] + '\n'
+
+        meta.set_signed_headers(';'.join(sorted(signed_headers.keys())))
+
+        canonical_request = '\n'.join(
+            [request.method, Util.norm_uri(request.path), Util.norm_query(request.query), signed_str,
+             meta.signed_headers, body_hash])
+
+        return Util.sha256(canonical_request)
+
+    @staticmethod
+    def get_signing_secret_key_v4(sk, date, region, service):
+        date = Util.hmac_sha256(bytes(sk, encoding='utf-8'), date)
+        region = Util.hmac_sha256(date, region)
+        service = Util.hmac_sha256(region, service)
+        return Util.hmac_sha256(service, 'request')
+
+    @staticmethod
+    def build_auth_header_v4(signature, meta, credentials):
+        credential = credentials.ak + '/' + meta.credential_scope
+        return meta.algorithm + ' Credential=' + credential + ', SignedHeaders=' + meta.signed_headers + ', Signature=' + signature
+
+    @staticmethod
+    def get_current_format_date():
+        return datetime.datetime.now(tz=pytz.timezone('UTC')).strftime("%Y%m%dT%H%M%SZ")
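
A sketch of the signing flow in isolation; the host, path, and keys are placeholders, and Signer.sign mutates the request headers in place:

    from core.model_runtime.model_providers.volcengine_maas.volc_sdk.base.auth import Credentials, Signer
    from core.model_runtime.model_providers.volcengine_maas.volc_sdk.base.service import Request

    r = Request()
    r.set_method('POST')
    r.set_host('maas-api.ml-platform-cn-beijing.volces.com')
    r.set_path('/api/v2/endpoint/YOUR_ENDPOINT_ID/chat')
    r.set_headers({'Host': 'maas-api.ml-platform-cn-beijing.volces.com', 'Accept': 'application/json'})
    r.set_body('{}')

    # Placeholder AK/SK; 'ml_maas' and the region mirror MaasService.get_service_info in maas.py.
    Signer.sign(r, Credentials('YOUR_AK', 'YOUR_SK', 'ml_maas', 'cn-beijing'))
    # r.headers now carries X-Date, X-Content-Sha256 and an Authorization header of the form
    # 'HMAC-SHA256 Credential=YOUR_AK/<date>/cn-beijing/ml_maas/request, SignedHeaders=..., Signature=...'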

+ 207 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/service.py

@@ -0,0 +1,207 @@
+import json
+from collections import OrderedDict
+from urllib.parse import urlencode
+
+import requests
+
+from .auth import Signer
+
+VERSION = 'v1.0.137'
+
+
+class Service:
+    def __init__(self, service_info, api_info):
+        self.service_info = service_info
+        self.api_info = api_info
+        self.session = requests.session()
+
+    def set_ak(self, ak):
+        self.service_info.credentials.set_ak(ak)
+
+    def set_sk(self, sk):
+        self.service_info.credentials.set_sk(sk)
+
+    def set_session_token(self, session_token):
+        self.service_info.credentials.set_session_token(session_token)
+
+    def set_host(self, host):
+        self.service_info.host = host
+
+    def set_scheme(self, scheme):
+        self.service_info.scheme = scheme
+
+    def get(self, api, params, doseq=0):
+        if not (api in self.api_info):
+            raise Exception("no such api")
+        api_info = self.api_info[api]
+
+        r = self.prepare_request(api_info, params, doseq)
+
+        Signer.sign(r, self.service_info.credentials)
+
+        url = r.build(doseq)
+        resp = self.session.get(url, headers=r.headers,
+                                timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
+        if resp.status_code == 200:
+            return resp.text
+        else:
+            raise Exception(resp.text)
+
+    def post(self, api, params, form):
+        if not (api in self.api_info):
+            raise Exception("no such api")
+        api_info = self.api_info[api]
+        r = self.prepare_request(api_info, params)
+        r.headers['Content-Type'] = 'application/x-www-form-urlencoded'
+        r.form = self.merge(api_info.form, form)
+        r.body = urlencode(r.form, True)
+        Signer.sign(r, self.service_info.credentials)
+
+        url = r.build()
+
+        resp = self.session.post(url, headers=r.headers, data=r.form,
+                                 timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
+        if resp.status_code == 200:
+            return resp.text
+        else:
+            raise Exception(resp.text)
+
+    def json(self, api, params, body):
+        if not (api in self.api_info):
+            raise Exception("no such api")
+        api_info = self.api_info[api]
+        r = self.prepare_request(api_info, params)
+        r.headers['Content-Type'] = 'application/json'
+        r.body = body
+
+        Signer.sign(r, self.service_info.credentials)
+
+        url = r.build()
+        resp = self.session.post(url, headers=r.headers, data=r.body,
+                                 timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
+        if resp.status_code == 200:
+            return json.dumps(resp.json())
+        else:
+            raise Exception(resp.text.encode("utf-8"))
+
+    def put(self, url, file_path, headers):
+        with open(file_path, 'rb') as f:
+            resp = self.session.put(url, headers=headers, data=f)
+            if resp.status_code == 200:
+                return True, resp.text.encode("utf-8")
+            else:
+                return False, resp.text.encode("utf-8")
+
+    def put_data(self, url, data, headers):
+        resp = self.session.put(url, headers=headers, data=data)
+        if resp.status_code == 200:
+            return True, resp.text.encode("utf-8")
+        else:
+            return False, resp.text.encode("utf-8")
+
+    def prepare_request(self, api_info, params, doseq=0):
+        for key in params:
+            if type(params[key]) == int or type(params[key]) == float or type(params[key]) == bool:
+                params[key] = str(params[key])
+            elif type(params[key]) == list:
+                if not doseq:
+                    params[key] = ','.join(params[key])
+
+        connection_timeout = self.service_info.connection_timeout
+        socket_timeout = self.service_info.socket_timeout
+
+        r = Request()
+        r.set_schema(self.service_info.scheme)
+        r.set_method(api_info.method)
+        r.set_connection_timeout(connection_timeout)
+        r.set_socket_timeout(socket_timeout)
+
+        headers = self.merge(api_info.header, self.service_info.header)
+        headers['Host'] = self.service_info.host
+        headers['User-Agent'] = 'volc-sdk-python/' + VERSION
+        r.set_headers(headers)
+
+        query = self.merge(api_info.query, params)
+        r.set_query(query)
+
+        r.set_host(self.service_info.host)
+        r.set_path(api_info.path)
+
+        return r
+
+    @staticmethod
+    def merge(param1, param2):
+        od = OrderedDict()
+        for key in param1:
+            od[key] = param1[key]
+
+        for key in param2:
+            od[key] = param2[key]
+
+        return od
+
+
+class Request:
+    def __init__(self):
+        self.schema = ''
+        self.method = ''
+        self.host = ''
+        self.path = ''
+        self.headers = OrderedDict()
+        self.query = OrderedDict()
+        self.body = ''
+        self.form = dict()
+        self.connection_timeout = 0
+        self.socket_timeout = 0
+
+    def set_schema(self, schema):
+        self.schema = schema
+
+    def set_method(self, method):
+        self.method = method
+
+    def set_host(self, host):
+        self.host = host
+
+    def set_path(self, path):
+        self.path = path
+
+    def set_headers(self, headers):
+        self.headers = headers
+
+    def set_query(self, query):
+        self.query = query
+
+    def set_body(self, body):
+        self.body = body
+
+    def set_connection_timeout(self, connection_timeout):
+        self.connection_timeout = connection_timeout
+
+    def set_socket_timeout(self, socket_timeout):
+        self.socket_timeout = socket_timeout
+
+    def build(self, doseq=0):
+        return self.schema + '://' + self.host + self.path + '?' + urlencode(self.query, doseq)
+
+
+class ServiceInfo:
+    def __init__(self, host, header, credentials, connection_timeout, socket_timeout, scheme='http'):
+        self.host = host
+        self.header = header
+        self.credentials = credentials
+        self.connection_timeout = connection_timeout
+        self.socket_timeout = socket_timeout
+        self.scheme = scheme
+
+
+class ApiInfo:
+    def __init__(self, method, path, query, form, header):
+        self.method = method
+        self.path = path
+        self.query = query
+        self.form = form
+        self.header = header
+
+    def __str__(self):
+        return 'method: ' + self.method + ', path: ' + self.path

+ 43 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/base/util.py

@@ -0,0 +1,43 @@
+import hashlib
+import hmac
+from functools import reduce
+from urllib.parse import quote
+
+
+class Util:
+    @staticmethod
+    def norm_uri(path):
+        return quote(path).replace('%2F', '/').replace('+', '%20')
+
+    @staticmethod
+    def norm_query(params):
+        query = ''
+        for key in sorted(params.keys()):
+            if type(params[key]) == list:
+                for k in params[key]:
+                    query = query + quote(key, safe='-_.~') + '=' + quote(k, safe='-_.~') + '&'
+            else:
+                query = query + quote(key, safe='-_.~') + '=' + quote(params[key], safe='-_.~') + '&'
+        query = query[:-1]
+        return query.replace('+', '%20')
+
+    @staticmethod
+    def hmac_sha256(key, content):
+        return hmac.new(key, bytes(content, encoding='utf-8'), hashlib.sha256).digest()
+
+    @staticmethod
+    def sha256(content):
+        if isinstance(content, str) is True:
+            return hashlib.sha256(content.encode('utf-8')).hexdigest()
+        else:
+            return hashlib.sha256(content).hexdigest()
+
+    @staticmethod
+    def to_hex(content):
+        lst = []
+        for ch in content:
+            hv = hex(ch).replace('0x', '')
+            if len(hv) == 1:
+                hv = '0' + hv
+            lst.append(hv)
+        return reduce(lambda x, y: x + y, lst)
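
Two small worked examples of the helpers above; the inputs are illustrative only:

    from core.model_runtime.model_providers.volcengine_maas.volc_sdk.base.util import Util

    # Query parameters are sorted and percent-encoded before signing.
    assert Util.norm_query({'b': '1', 'a': 'x y'}) == 'a=x%20y&b=1'
    # to_hex renders an HMAC-SHA256 digest as lowercase hex for the Signature field.
    digest = Util.to_hex(Util.hmac_sha256(b'secret', 'payload'))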

+ 79 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/common.py

@@ -0,0 +1,79 @@
+import json
+import random
+from datetime import datetime
+
+
+class ChatRole:
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+    FUNCTION = "function"
+
+
+class _Dict(dict):
+    __setattr__ = dict.__setitem__
+    __getattr__ = dict.__getitem__
+
+    def __missing__(self, key):
+        return None
+
+
+def dict_to_object(dict_obj):
+    # Supports nested types (lists and dicts are converted recursively)
+    if isinstance(dict_obj, list):
+        insts = []
+        for i in dict_obj:
+            insts.append(dict_to_object(i))
+        return insts
+
+    if isinstance(dict_obj, dict):
+        inst = _Dict()
+        for k, v in dict_obj.items():
+            inst[k] = dict_to_object(v)
+        return inst
+
+    return dict_obj
+
+
+def json_to_object(json_str, req_id=None):
+    obj = dict_to_object(json.loads(json_str))
+    if obj and isinstance(obj, dict) and req_id:
+        obj["req_id"] = req_id
+    return obj
+
+
+def gen_req_id():
+    return datetime.now().strftime("%Y%m%d%H%M%S") + format(
+        random.randint(0, 2 ** 64 - 1), "020X"
+    )
+
+
+class SSEDecoder:
+    def __init__(self, source):
+        self.source = source
+
+    def _read(self):
+        data = b''
+        for chunk in self.source:
+            for line in chunk.splitlines(True):
+                data += line
+                if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')):
+                    yield data
+                    data = b''
+        if data:
+            yield data
+
+    def next(self):
+        for chunk in self._read():
+            for line in chunk.splitlines():
+                # skip comment
+                if line.startswith(b':'):
+                    continue
+
+                if b':' in line:
+                    field, value = line.split(b':', 1)
+                else:
+                    field, value = line, b''
+
+                if field == b'data' and len(value) > 0:
+                    yield value
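
A minimal sketch of the SSE decoding path used by stream_chat in maas.py; the byte frames below are fabricated for illustration:

    from core.model_runtime.model_providers.volcengine_maas.volc_sdk.common import SSEDecoder, json_to_object

    frames = [b'data: {"choices": []}\n\n', b'data:[DONE]\n\n']  # fabricated SSE frames
    for data in SSEDecoder(iter(frames)).next():
        if data == b'[DONE]':
            break
        event = json_to_object(str(data, encoding='utf-8'), req_id='req-demo')
        print(event.choices)  # attribute access works because dict_to_object returns _Dict instances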

+ 213 - 0
api/core/model_runtime/model_providers/volcengine_maas/volc_sdk/maas.py

@@ -0,0 +1,213 @@
+import copy
+import json
+from collections.abc import Iterator
+
+from .base.auth import Credentials, Signer
+from .base.service import ApiInfo, Service, ServiceInfo
+from .common import SSEDecoder, dict_to_object, gen_req_id, json_to_object
+
+
+class MaasService(Service):
+    def __init__(self, host, region, connection_timeout=60, socket_timeout=60):
+        service_info = self.get_service_info(
+            host, region, connection_timeout, socket_timeout
+        )
+        self._apikey = None
+        api_info = self.get_api_info()
+        super().__init__(service_info, api_info)
+
+    def set_apikey(self, apikey):
+        self._apikey = apikey
+
+    @staticmethod
+    def get_service_info(host, region, connection_timeout, socket_timeout):
+        service_info = ServiceInfo(
+            host,
+            {"Accept": "application/json"},
+            Credentials("", "", "ml_maas", region),
+            connection_timeout,
+            socket_timeout,
+            "https",
+        )
+        return service_info
+
+    @staticmethod
+    def get_api_info():
+        api_info = {
+            "chat": ApiInfo("POST", "/api/v2/endpoint/{endpoint_id}/chat", {}, {}, {}),
+            "embeddings": ApiInfo(
+                "POST", "/api/v2/endpoint/{endpoint_id}/embeddings", {}, {}, {}
+            ),
+        }
+        return api_info
+
+    def chat(self, endpoint_id, req):
+        req["stream"] = False
+        return self._request(endpoint_id, "chat", req)
+
+    def stream_chat(self, endpoint_id, req):
+        req_id = gen_req_id()
+        self._validate("chat", req_id)
+        apikey = self._apikey
+
+        try:
+            req["stream"] = True
+            res = self._call(
+                endpoint_id, "chat", req_id, {}, json.dumps(req).encode("utf-8"), apikey, stream=True
+            )
+
+            decoder = SSEDecoder(res)
+
+            def iter_fn():
+                for data in decoder.next():
+                    if data == b"[DONE]":
+                        return
+
+                    try:
+                        res = json_to_object(
+                            str(data, encoding="utf-8"), req_id=req_id)
+                    except Exception:
+                        raise
+
+                    if res.error is not None and res.error.code_n != 0:
+                        raise MaasException(
+                            res.error.code_n,
+                            res.error.code,
+                            res.error.message,
+                            req_id,
+                        )
+                    yield res
+
+            return iter_fn()
+        except MaasException:
+            raise
+        except Exception as e:
+            raise new_client_sdk_request_error(str(e))
+
+    def embeddings(self, endpoint_id, req):
+        return self._request(endpoint_id, "embeddings", req)
+
+    def _request(self, endpoint_id, api, req, params={}):
+        req_id = gen_req_id()
+
+        self._validate(api, req_id)
+
+        apikey = self._apikey
+
+        try:
+            res = self._call(endpoint_id, api, req_id, params,
+                             json.dumps(req).encode("utf-8"), apikey)
+            resp = dict_to_object(res.json())
+            if resp and isinstance(resp, dict):
+                resp["req_id"] = req_id
+            return resp
+
+        except MaasException as e:
+            raise e
+        except Exception as e:
+            raise new_client_sdk_request_error(str(e), req_id)
+
+    def _validate(self, api, req_id):
+        credentials_exist = (
+            self.service_info.credentials is not None and
+            self.service_info.credentials.sk is not None and
+            self.service_info.credentials.ak is not None
+        )
+
+        if not self._apikey and not credentials_exist:
+            raise new_client_sdk_request_error("no valid credential", req_id)
+
+        if not (api in self.api_info):
+            raise new_client_sdk_request_error("no such api", req_id)
+
+    def _call(self, endpoint_id, api, req_id, params, body, apikey=None, stream=False):
+        api_info = copy.deepcopy(self.api_info[api])
+        api_info.path = api_info.path.format(endpoint_id=endpoint_id)
+
+        r = self.prepare_request(api_info, params)
+        r.headers["x-tt-logid"] = req_id
+        r.headers["Content-Type"] = "application/json"
+        r.body = body
+
+        if apikey is None:
+            Signer.sign(r, self.service_info.credentials)
+        elif apikey is not None:
+            r.headers["Authorization"] = "Bearer " + apikey
+
+        url = r.build()
+        res = self.session.post(
+            url,
+            headers=r.headers,
+            data=r.body,
+            timeout=(
+                self.service_info.connection_timeout,
+                self.service_info.socket_timeout,
+            ),
+            stream=stream,
+        )
+
+        if res.status_code != 200:
+            raw = res.text.encode()
+            res.close()
+            try:
+                resp = json_to_object(
+                    str(raw, encoding="utf-8"), req_id=req_id)
+            except Exception:
+                raise new_client_sdk_request_error(raw, req_id)
+
+            if resp.error:
+                raise MaasException(
+                    resp.error.code_n, resp.error.code, resp.error.message, req_id
+                )
+            else:
+                raise new_client_sdk_request_error(resp, req_id)
+
+        return res
+
+
+class MaasException(Exception):
+    def __init__(self, code_n, code, message, req_id):
+        self.code_n = code_n
+        self.code = code
+        self.message = message
+        self.req_id = req_id
+
+    def __str__(self):
+        return ("Detailed exception information is listed below.\n" +
+                "req_id: {}\n" +
+                "code_n: {}\n" +
+                "code: {}\n" +
+                "message: {}").format(self.req_id, self.code_n, self.code, self.message)
+
+
+def new_client_sdk_request_error(raw, req_id=""):
+    return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
+
+
+class BinaryResponseContent:
+    def __init__(self, response, request_id) -> None:
+        self.response = response
+        self.request_id = request_id
+
+    def stream_to_file(
+            self,
+            file: str
+    ) -> None:
+        is_first = True
+        error_bytes = b''
+        with open(file, mode="wb") as f:
+            for data in self.response:
+                if len(error_bytes) > 0 or (is_first and "\"error\":" in str(data)):
+                    error_bytes += data
+                else:
+                    f.write(data)
+
+        if len(error_bytes) > 0:
+            resp = json_to_object(
+                str(error_bytes, encoding="utf-8"), req_id=self.request_id)
+            raise MaasException(
+                resp.error.code_n, resp.error.code, resp.error.message, self.request_id
+            )
+
+    def iter_bytes(self) -> Iterator[bytes]:
+        yield from self.response
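
For completeness, the raw SDK call that MaaSClient wraps looks roughly like this; keys and the endpoint id are placeholders, and the calls hit the live MaaS API:

    from core.model_runtime.model_providers.volcengine_maas.volc_sdk import ChatRole, MaasException, MaasService

    svc = MaasService('maas-api.ml-platform-cn-beijing.volces.com', 'cn-beijing')
    svc.set_ak('YOUR_AK')  # placeholder access key
    svc.set_sk('YOUR_SK')  # placeholder secret key

    req = {
        'parameters': {'max_new_tokens': 16},
        'messages': [{'role': ChatRole.USER, 'content': 'ping'}],
    }
    try:
        resp = svc.chat('YOUR_ENDPOINT_ID', req)                  # blocking call
        for chunk in svc.stream_chat('YOUR_ENDPOINT_ID', req):    # SSE streaming call
            print(chunk.choices)
    except MaasException as e:
        print(e)  # carries code_n, code, message and req_id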

+ 10 - 0
api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.py

@@ -0,0 +1,10 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class VolcengineMaaSProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass

+ 151 - 0
api/core/model_runtime/model_providers/volcengine_maas/volcengine_maas.yaml

@@ -0,0 +1,151 @@
+provider: volcengine_maas
+label:
+  en_US: Volcengine
+description:
+  en_US: Volcengine MaaS models.
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.svg
+  zh_Hans: icon_l_zh.svg
+background: "#F9FAFB"
+help:
+  title:
+    en_US: Get your Access Key and Secret Access Key from Volcengine Console
+  url:
+    en_US: https://console.volcengine.com/iam/keymanage/
+supported_model_types:
+  - llm
+  - text-embedding
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your Model Name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: volc_access_key_id
+      required: true
+      label:
+        en_US: Access Key
+        zh_Hans: Access Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your Access Key
+        zh_Hans: 输入您的 Access Key
+    - variable: volc_secret_access_key
+      required: true
+      label:
+        en_US: Secret Access Key
+        zh_Hans: Secret Access Key
+      type: secret-input
+      placeholder:
+        en_US: Enter your Secret Access Key
+        zh_Hans: 输入您的 Secret Access Key
+    - variable: volc_region
+      required: true
+      label:
+        en_US: Volcengine Region
+        zh_Hans: 火山引擎地区
+      type: text-input
+      default: cn-beijing
+      placeholder:
+        en_US: Enter Volcengine Region
+        zh_Hans: 输入火山引擎地域
+    - variable: api_endpoint_host
+      required: true
+      label:
+        en_US: API Endpoint Host
+        zh_Hans: API Endpoint Host
+      type: text-input
+      default: maas-api.ml-platform-cn-beijing.volces.com
+      placeholder:
+        en_US: Enter your API Endpoint Host
+        zh_Hans: 输入 API Endpoint Host
+    - variable: endpoint_id
+      required: true
+      label:
+        en_US: Endpoint ID
+        zh_Hans: Endpoint ID
+      type: text-input
+      placeholder:
+        en_US: Enter your Endpoint ID
+        zh_Hans: 输入您的 Endpoint ID
+    - variable: base_model_name
+      show_on:
+        - variable: __model_type
+          value: llm
+      label:
+        en_US: Base Model
+        zh_Hans: 基础模型
+      type: select
+      required: true
+      options:
+        - label:
+            en_US: Skylark2-pro-4k
+          value: Skylark2-pro-4k
+          show_on:
+            - variable: __model_type
+              value: llm
+        - label:
+            en_US: Custom
+            zh_Hans: 自定义
+          value: Custom
+    - variable: mode
+      required: true
+      show_on:
+        - variable: __model_type
+          value: llm
+        - variable: base_model_name
+          value: Custom
+      label:
+        zh_Hans: 模型类型
+        en_US: Completion Mode
+      type: select
+      default: chat
+      placeholder:
+        zh_Hans: 选择对话类型
+        en_US: Select Completion Mode
+      options:
+        - value: completion
+          label:
+            en_US: Completion
+            zh_Hans: 补全
+        - value: chat
+          label:
+            en_US: Chat
+            zh_Hans: 对话
+    - variable: context_size
+      required: true
+      show_on:
+        - variable: __model_type
+          value: llm
+        - variable: base_model_name
+          value: Custom
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model Context Size
+      type: text-input
+      default: '4096'
+      placeholder:
+        zh_Hans: 输入您的模型上下文长度
+        en_US: Enter your Model Context Size
+    - variable: max_tokens
+      required: true
+      show_on:
+        - variable: __model_type
+          value: llm
+        - variable: base_model_name
+          value: Custom
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper Bound for Max Tokens
+      default: '4096'
+      type: text-input
+      placeholder:
+        zh_Hans: 输入您的模型最大 token 上限
+        en_US: Enter your model Upper Bound for Max Tokens

+ 7 - 1
api/tests/integration_tests/.env.example

@@ -73,4 +73,10 @@ MOCK_SWITCH=false
 
 # CODE EXECUTION CONFIGURATION
 CODE_EXECUTION_ENDPOINT=
-CODE_EXECUTION_API_KEY=
+CODE_EXECUTION_API_KEY=
+
+# Volcengine MaaS Credentials
+VOLC_API_KEY=
+VOLC_SECRET_KEY=
+VOLC_MODEL_ENDPOINT_ID=
+VOLC_EMBEDDING_ENDPOINT_ID=

+ 0 - 0
api/tests/integration_tests/model_runtime/volcengine_maas/__init__.py


+ 81 - 0
api/tests/integration_tests/model_runtime/volcengine_maas/test_embedding.py

@@ -0,0 +1,81 @@
+import os
+
+import pytest
+
+from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import (
+    VolcengineMaaSTextEmbeddingModel,
+)
+
+
+def test_validate_credentials():
+    model = VolcengineMaaSTextEmbeddingModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='NOT IMPORTANT',
+            credentials={
+                'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+                'volc_region': 'cn-beijing',
+                'volc_access_key_id': 'INVALID',
+                'volc_secret_access_key': 'INVALID',
+                'endpoint_id': 'INVALID',
+            }
+        )
+
+    model.validate_credentials(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+        },
+    )
+
+
+def test_invoke_model():
+    model = VolcengineMaaSTextEmbeddingModel()
+
+    result = model.invoke(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+        },
+        texts=[
+            "hello",
+            "world"
+        ],
+        user="abc-123"
+    )
+
+    assert isinstance(result, TextEmbeddingResult)
+    assert len(result.embeddings) == 2
+    assert result.usage.total_tokens > 0
+
+
+def test_get_num_tokens():
+    model = VolcengineMaaSTextEmbeddingModel()
+
+    num_tokens = model.get_num_tokens(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
+        },
+        texts=[
+            "hello",
+            "world"
+        ]
+    )
+
+    assert num_tokens == 2

+ 131 - 0
api/tests/integration_tests/model_runtime/volcengine_maas/test_llm.py

@@ -0,0 +1,131 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel
+
+
+def test_validate_credentials_for_chat_model():
+    model = VolcengineMaaSLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='NOT IMPORTANT',
+            credentials={
+                'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+                'volc_region': 'cn-beijing',
+                'volc_access_key_id': 'INVALID',
+                'volc_secret_access_key': 'INVALID',
+                'endpoint_id': 'INVALID',
+            }
+        )
+
+    model.validate_credentials(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+        }
+    )
+
+
+def test_invoke_model():
+    model = VolcengineMaaSLargeLanguageModel()
+
+    response = model.invoke(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+            'base_model_name': 'Skylark2-pro-4k',
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.7,
+            'top_p': 1.0,
+            'top_k': 1,
+        },
+        stop=['you'],
+        user="abc-123",
+        stream=False
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+    assert response.usage.total_tokens > 0
+
+
+def test_invoke_stream_model():
+    model = VolcengineMaaSLargeLanguageModel()
+
+    response = model.invoke(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+            'base_model_name': 'Skylark2-pro-4k',
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.7,
+            'top_p': 1.0,
+            'top_k': 1,
+        },
+        stop=['you'],
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(
+            chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_num_tokens():
+    model = VolcengineMaaSLargeLanguageModel()
+
+    response = model.get_num_tokens(
+        model='NOT IMPORTANT',
+        credentials={
+            'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
+            'volc_region': 'cn-beijing',
+            'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
+            'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
+            'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
+            'base_model_name': 'Skylark2-pro-4k',
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        tools=[]
+    )
+
+    assert isinstance(response, int)
+    assert response == 6