
Feat/add triton inference server (#2928)

Yeuoly 1 year ago
parent commit
240a94182e

+ 2 - 1
api/core/model_runtime/model_providers/_position.yaml

@@ -11,6 +11,8 @@
 - groq
 - replicate
 - huggingface_hub
+- xinference
+- triton_inference_server
 - zhipuai
 - baichuan
 - spark
@@ -20,7 +22,6 @@
 - moonshot
 - jina
 - chatglm
-- xinference
 - yi
 - openllm
 - localai

+ 0 - 0
api/core/model_runtime/model_providers/triton_inference_server/__init__.py


BIN
api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_l_en.png


Diff is not shown because the file is too large
+ 3 - 0
api/core/model_runtime/model_providers/triton_inference_server/_assets/icon_s_en.svg


+ 0 - 0
api/core/model_runtime/model_providers/triton_inference_server/llm/__init__.py


+ 267 - 0
api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py

@@ -0,0 +1,267 @@
+from collections.abc import Generator
+
+from httpx import Response, Timeout, post
+from yarl import URL
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
+from core.model_runtime.errors.invoke import (
+    InvokeAuthorizationError,
+    InvokeBadRequestError,
+    InvokeConnectionError,
+    InvokeError,
+    InvokeRateLimitError,
+    InvokeServerUnavailableError,
+)
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+
+
+class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+        -> LLMResult | Generator:
+        """
+            invoke LLM
+
+            see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
+        """
+        return self._generate(
+            model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
+            tools=tools, stop=stop, stream=stream, user=user,
+        )
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+            validate credentials
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+        
+        try:
+            self._invoke(model=model, credentials=credentials, prompt_messages=[
+                UserPromptMessage(content='ping')
+            ], model_parameters={}, stream=False)
+        except InvokeError as ex:
+            raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}')
+
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                       tools: list[PromptMessageTool] | None = None) -> int:
+        """
+            get number of tokens
+
+            Because Triton Inference Server serves customized models, we cannot detect which tokenizer to use,
+            so we fall back to the GPT-2 tokenizer by default.
+        """
+        return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages))
+    
+    def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
+        """
+            convert prompt message to text
+        """
+        text = ''
+        for item in message:
+            if isinstance(item, UserPromptMessage):
+                text += f'User: {item.content}'
+            elif isinstance(item, SystemPromptMessage):
+                text += f'System: {item.content}'
+            elif isinstance(item, AssistantPromptMessage):
+                text += f'Assistant: {item.content}'
+            else:
+                raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
+        return text
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        """
+            used to define customizable model schema
+        """
+        rules = [
+            ParameterRule(
+                name='temperature',
+                type=ParameterType.FLOAT,
+                use_template='temperature',
+                label=I18nObject(
+                    zh_Hans='温度',
+                    en_US='Temperature'
+                ),
+            ),
+            ParameterRule(
+                name='top_p',
+                type=ParameterType.FLOAT,
+                use_template='top_p',
+                label=I18nObject(
+                    zh_Hans='Top P',
+                    en_US='Top P'
+                )
+            ),
+            ParameterRule(
+                name='max_tokens',
+                type=ParameterType.INT,
+                use_template='max_tokens',
+                min=1,
+                max=int(credentials.get('context_size', 2048)),
+                default=min(512, int(credentials.get('context_size', 2048))),
+                label=I18nObject(
+                    zh_Hans='最大生成长度',
+                    en_US='Max Tokens'
+                )
+            )
+        ]
+
+        completion_type = None
+
+        if 'completion_type' in credentials:
+            if credentials['completion_type'] == 'chat':
+                completion_type = LLMMode.CHAT.value
+            elif credentials['completion_type'] == 'completion':
+                completion_type = LLMMode.COMPLETION.value
+            else:
+                raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
+        
+        entity = AIModelEntity(
+            model=model,
+            label=I18nObject(
+                en_US=model
+            ),
+            parameter_rules=rules,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.LLM,
+            model_properties={
+                ModelPropertyKey.MODE: completion_type,
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 2048)),
+            },
+        )
+
+        return entity
+    
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
+                 model_parameters: dict,
+                 tools: list[PromptMessageTool] | None = None, 
+                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+            -> LLMResult | Generator:
+        """
+            generate text from LLM
+        """
+        if 'server_url' not in credentials:
+            raise CredentialsValidateFailedError('server_url is required in credentials')
+        
+        if 'stream' in credentials and not bool(credentials['stream']) and stream:
+            raise ValueError(f'stream is not supported by model {model}')
+
+        try:
+            parameters = {}
+            if 'temperature' in model_parameters:
+                parameters['temperature'] = model_parameters['temperature']
+            if 'top_p' in model_parameters:
+                parameters['top_p'] = model_parameters['top_p']
+            if 'top_k' in model_parameters:
+                parameters['top_k'] = model_parameters['top_k']
+            if 'presence_penalty' in model_parameters:
+                parameters['presence_penalty'] = model_parameters['presence_penalty']
+            if 'frequency_penalty' in model_parameters:
+                parameters['frequency_penalty'] = model_parameters['frequency_penalty']
+
+            response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={
+                'text_input': self._convert_prompt_message_to_text(prompt_messages),
+                'max_tokens': model_parameters.get('max_tokens', 512),
+                'parameters': {
+                    'stream': False,
+                    **parameters
+                },
+            }, timeout=Timeout(120.0, connect=10.0))  # 10s to connect, 120s for read/write/pool
+            # check the status explicitly so a non-200 reply is reported as InvokeBadRequestError
+            if response.status_code != 200:
+                raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}')
+            
+            if stream:
+                return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+                                                        tools=tools, resp=response)
+            return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
+                                                        tools=tools, resp=response)
+        except InvokeError:
+            # invoke errors (e.g. InvokeBadRequestError) already carry the right type, re-raise as-is
+            raise
+        except Exception as ex:
+            raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}')
+        
+    def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                        tools: list[PromptMessageTool],
+                                        resp: Response) -> LLMResult:
+        """
+            handle normal chat generate response
+        """
+        text = resp.json()['text_output']
+
+        usage = LLMUsage.empty_usage()
+        usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
+
+        return LLMResult(
+            model=model,
+            prompt_messages=prompt_messages,
+            message=AssistantPromptMessage(
+                content=text
+            ),
+            usage=usage
+        )
+
+    def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                        tools: list[PromptMessageTool],
+                                        resp: Response) -> Generator:
+        """
+            handle chat stream response
+        """
+        text = resp.json()['text_output']
+
+        usage = LLMUsage.empty_usage()
+        usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+        usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
+
+        yield LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(
+                    content=text
+                ),
+                usage=usage
+            )
+        )
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeConnectionError: [
+            ],
+            InvokeServerUnavailableError: [
+            ],
+            InvokeRateLimitError: [
+            ],
+            InvokeAuthorizationError: [
+            ],
+            InvokeBadRequestError: [
+                ValueError
+            ]
+        }
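
The _generate method above posts to Triton's HTTP generate endpoint (POST /v2/models/<model>/generate) with a text_input / max_tokens / parameters payload and reads text_output from the reply. A minimal standalone sketch of that exchange, assuming a placeholder server address and model name:

from httpx import Timeout, post
from yarl import URL

# placeholder values for illustration only
server_url = 'http://192.168.1.100:8000'
model = 'my-llm'

response = post(
    str(URL(server_url) / 'v2' / 'models' / model / 'generate'),
    json={
        'text_input': 'User: ping',                    # flattened prompt, as built by _convert_prompt_message_to_text
        'max_tokens': 64,
        'parameters': {'stream': False, 'temperature': 0.7},
    },
    timeout=Timeout(120.0, connect=10.0),
)
print(response.json()['text_output'])                  # the field _handle_chat_generate_response reads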

+ 9 - 0
api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py

@@ -0,0 +1,9 @@
+import logging
+
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+class TritonInferenceAIProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        pass  # models are customizable; credentials are validated per-model in llm/llm.py

+ 84 - 0
api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.yaml

@@ -0,0 +1,84 @@
+provider: triton_inference_server
+label:
+  en_US: Triton Inference Server
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.png
+background: "#EFFDFD"
+help:
+  title:
+    en_US: How to deploy Triton Inference Server
+    zh_Hans: 如何部署 Triton Inference Server
+  url:
+    en_US: https://github.com/triton-inference-server/server
+supported_model_types:
+  - llm
+configurate_methods:
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: server_url
+      label:
+        zh_Hans: 服务器URL
+        en_US: Server URL
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000
+        en_US: Enter the url of your Triton Inference Server, e.g. http://192.168.1.100:8000
+    - variable: context_size
+      label:
+        zh_Hans: 上下文大小
+        en_US: Context size
+      type: text-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的上下文大小
+        en_US: Enter the context size
+      default: 2048
+    - variable: completion_type
+      label:
+        zh_Hans: 补全类型
+        en_US: Completion type
+      type: select
+      required: true
+      default: chat
+      placeholder:
+        zh_Hans: 在此输入您的补全类型
+        en_US: Enter the completion type
+      options:
+        - label:
+            zh_Hans: 补全模型
+            en_US: Completion model
+          value: completion
+        - label:
+            zh_Hans: 对话模型
+            en_US: Chat model
+          value: chat
+    - variable: stream
+      label:
+        zh_Hans: 流式输出
+        en_US: Stream output
+      type: select
+      required: true
+      default: true
+      placeholder:
+        zh_Hans: 是否支持流式输出
+        en_US: Whether to support stream output
+      options:
+        - label:
+            zh_Hans: 是
+            en_US: "Yes"
+          value: true
+        - label:
+            zh_Hans: 否
+            en_US: "No"
+          value: false
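
Taken together, the credential form above produces a credentials dict whose keys are read by get_customizable_model_schema and _generate in llm/llm.py. A rough illustration of a filled-in configuration (all values are placeholders; whether the select and text fields arrive as strings or native types depends on how the form values are persisted):

credentials = {
    'server_url': 'http://192.168.1.100:8000',  # Triton HTTP endpoint (placeholder)
    'context_size': '2048',                     # text-input field; the code casts it with int(...)
    'completion_type': 'chat',                  # 'chat' or 'completion'
    'stream': 'true',                           # whether the deployment supports streaming responses
}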