소스 검색

Add Stepfun LLM Support (#6346)

forrestlinfeng 9 달 전
부모
커밋
3b5b548af3

+ 0 - 0
api/core/model_runtime/model_providers/stepfun/__init__.py


BIN
api/core/model_runtime/model_providers/stepfun/_assets/icon_l_en.png


BIN
api/core/model_runtime/model_providers/stepfun/_assets/icon_s_en.png


+ 0 - 0
api/core/model_runtime/model_providers/stepfun/llm/__init__.py


+ 6 - 0
api/core/model_runtime/model_providers/stepfun/llm/_position.yaml

@@ -0,0 +1,6 @@
+- step-1-8k
+- step-1-32k
+- step-1-128k
+- step-1-256k
+- step-1v-8k
+- step-1v-32k

+ 328 - 0
api/core/model_runtime/model_providers/stepfun/llm/llm.py

@@ -0,0 +1,328 @@
+import json
+from collections.abc import Generator
+from typing import Optional, Union, cast
+
+import requests
+
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageContent,
+    PromptMessageContentType,
+    PromptMessageTool,
+    SystemPromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import (
+    AIModelEntity,
+    FetchFrom,
+    ModelFeature,
+    ModelPropertyKey,
+    ModelType,
+    ParameterRule,
+    ParameterType,
+)
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel):
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+        self._add_custom_parameters(credentials)
+        self._add_function_call(model, credentials)
+        user = user[:32] if user else None
+        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        self._add_custom_parameters(credentials)
+        super().validate_credentials(model, credentials)
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
+        return AIModelEntity(
+            model=model,
+            label=I18nObject(en_US=model, zh_Hans=model),
+            model_type=ModelType.LLM,
+            features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] 
+                if credentials.get('function_calling_type') == 'tool_call' 
+                else [],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 8000)),
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name='temperature',
+                    use_template='temperature',
+                    label=I18nObject(en_US='Temperature', zh_Hans='温度'),
+                    type=ParameterType.FLOAT,
+                ),
+                ParameterRule(
+                    name='max_tokens',
+                    use_template='max_tokens',
+                    default=512,
+                    min=1,
+                    max=int(credentials.get('max_tokens', 1024)),
+                    label=I18nObject(en_US='Max Tokens', zh_Hans='最大标记'),
+                    type=ParameterType.INT,
+                ),
+                ParameterRule(
+                    name='top_p',
+                    use_template='top_p',
+                    label=I18nObject(en_US='Top P', zh_Hans='Top P'),
+                    type=ParameterType.FLOAT,
+                ),
+            ]
+        )
+
+    def _add_custom_parameters(self, credentials: dict) -> None:
+        credentials['mode'] = 'chat'
+        credentials['endpoint_url'] = 'https://api.stepfun.com/v1'
+
+    def _add_function_call(self, model: str, credentials: dict) -> None:
+        model_schema = self.get_model_schema(model, credentials)
+        if model_schema and {
+            ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL
+        }.intersection(model_schema.features or []):
+            credentials['function_calling_type'] = 'tool_call'
+
+    def _convert_prompt_message_to_dict(self, message: PromptMessage,credentials: Optional[dict] = None) -> dict:
+        """
+        Convert PromptMessage to dict for OpenAI API format
+        """
+        if isinstance(message, UserPromptMessage):
+            message = cast(UserPromptMessage, message)
+            if isinstance(message.content, str):
+                message_dict = {"role": "user", "content": message.content}
+            else:
+                sub_messages = []
+                for message_content in message.content:
+                    if message_content.type == PromptMessageContentType.TEXT:
+                        message_content = cast(PromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "text",
+                            "text": message_content.data
+                        }
+                        sub_messages.append(sub_message_dict)
+                    elif message_content.type == PromptMessageContentType.IMAGE:
+                        message_content = cast(ImagePromptMessageContent, message_content)
+                        sub_message_dict = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": message_content.data,
+                            }
+                        }
+                        sub_messages.append(sub_message_dict)
+                message_dict = {"role": "user", "content": sub_messages}
+        elif isinstance(message, AssistantPromptMessage):
+            message = cast(AssistantPromptMessage, message)
+            message_dict = {"role": "assistant", "content": message.content}
+            if message.tool_calls:
+                message_dict["tool_calls"] = []
+                for function_call in message.tool_calls:
+                    message_dict["tool_calls"].append({
+                        "id": function_call.id,
+                        "type": function_call.type,
+                        "function": {
+                            "name": function_call.function.name,
+                            "arguments": function_call.function.arguments
+                        }
+                    })
+        elif isinstance(message, ToolPromptMessage):
+            message = cast(ToolPromptMessage, message)
+            message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
+        elif isinstance(message, SystemPromptMessage):
+            message = cast(SystemPromptMessage, message)
+            message_dict = {"role": "system", "content": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        if message.name:
+            message_dict["name"] = message.name
+
+        return message_dict
+
+    def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
+        """
+        Extract tool calls from response
+
+        :param response_tool_calls: response tool calls
+        :return: list of tool calls
+        """
+        tool_calls = []
+        if response_tool_calls:
+            for response_tool_call in response_tool_calls:
+                function = AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=response_tool_call["function"]["name"] if response_tool_call.get("function", {}).get("name") else "",
+                    arguments=response_tool_call["function"]["arguments"] if response_tool_call.get("function", {}).get("arguments") else ""
+                )
+
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=response_tool_call["id"] if response_tool_call.get("id") else "",
+                    type=response_tool_call["type"] if response_tool_call.get("type") else "",
+                    function=function
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls
+
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
+        """
+        Handle llm stream response
+
+        :param model: model name
+        :param credentials: model credentials
+        :param response: streamed response
+        :param prompt_messages: prompt messages
+        :return: llm response chunk generator
+        """
+        full_assistant_content = ''
+        chunk_index = 0
+
+        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
+                -> LLMResultChunk:
+            # calculate num tokens
+            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+            completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
+
+            # transform usage
+            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+            return LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=message,
+                    finish_reason=finish_reason,
+                    usage=usage
+                )
+            )
+
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = "Unknown"
+
+        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
+            def get_tool_call(tool_name: str):
+                if not tool_name:
+                    return tools_calls[-1]
+
+                tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
+                if tool_call is None:
+                    tool_call = AssistantPromptMessage.ToolCall(
+                        id='',
+                        type='',
+                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments="")
+                    )
+                    tools_calls.append(tool_call)
+
+                return tool_call
+
+            for new_tool_call in new_tool_calls:
+                # get tool call
+                tool_call = get_tool_call(new_tool_call.function.name)
+                # update tool call
+                if new_tool_call.id:
+                    tool_call.id = new_tool_call.id
+                if new_tool_call.type:
+                    tool_call.type = new_tool_call.type
+                if new_tool_call.function.name:
+                    tool_call.function.name = new_tool_call.function.name
+                if new_tool_call.function.arguments:
+                    tool_call.function.arguments += new_tool_call.function.arguments
+
+        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
+            if chunk:
+                # ignore sse comments
+                if chunk.startswith(':'):
+                    continue
+                decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
+                chunk_json = None
+                try:
+                    chunk_json = json.loads(decoded_chunk)
+                # stream ended
+                except json.JSONDecodeError as e:
+                    yield create_final_llm_result_chunk(
+                        index=chunk_index + 1,
+                        message=AssistantPromptMessage(content=""),
+                        finish_reason="Non-JSON encountered."
+                    )
+                    break
+                if not chunk_json or len(chunk_json['choices']) == 0:
+                    continue
+
+                choice = chunk_json['choices'][0]
+                finish_reason = chunk_json['choices'][0].get('finish_reason')
+                chunk_index += 1
+
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    delta_content = delta.get('content')
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        increase_tool_call(tool_calls)
+
+                    if delta_content is None or delta_content == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta_content,
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta_content
+                elif 'text' in choice:
+                    choice_text = choice.get('text', '')
+                    if choice_text == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
+                else:
+                    continue
+
+                # check payload indicator for completion
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=chunk_index,
+                        message=assistant_prompt_message,
+                    )
+                )
+
+            chunk_index += 1
+        
+        if tools_calls:
+            yield LLMResultChunk(
+                model=model,
+                prompt_messages=prompt_messages,
+                delta=LLMResultChunkDelta(
+                    index=chunk_index,
+                    message=AssistantPromptMessage(
+                        tool_calls=tools_calls,
+                        content=""
+                    ),
+                )
+            )
+
+        yield create_final_llm_result_chunk(
+            index=chunk_index,
+            message=AssistantPromptMessage(content=""),
+            finish_reason=finish_reason
+        )

+ 25 - 0
api/core/model_runtime/model_providers/stepfun/llm/step-1-128k.yaml

@@ -0,0 +1,25 @@
+model: step-1-128k
+label:
+  zh_Hans: step-1-128k
+  en_US: step-1-128k
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 128000
+pricing:
+  input: '0.04'
+  output: '0.20'
+  unit: '0.001'
+  currency: RMB

+ 25 - 0
api/core/model_runtime/model_providers/stepfun/llm/step-1-256k.yaml

@@ -0,0 +1,25 @@
+model: step-1-256k
+label:
+  zh_Hans: step-1-256k
+  en_US: step-1-256k
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  mode: chat
+  context_size: 256000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 256000
+pricing:
+  input: '0.095'
+  output: '0.300'
+  unit: '0.001'
+  currency: RMB

+ 28 - 0
api/core/model_runtime/model_providers/stepfun/llm/step-1-32k.yaml

@@ -0,0 +1,28 @@
+model: step-1-32k
+label:
+  zh_Hans: step-1-32k
+  en_US: step-1-32k
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 32000
+pricing:
+  input: '0.015'
+  output: '0.070'
+  unit: '0.001'
+  currency: RMB

+ 28 - 0
api/core/model_runtime/model_providers/stepfun/llm/step-1-8k.yaml

@@ -0,0 +1,28 @@
+model: step-1-8k
+label:
+  zh_Hans: step-1-8k
+  en_US: step-1-8k
+model_type: llm
+features:
+  - agent-thought
+  - tool-call
+  - multi-tool-call
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 8000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 8000
+pricing:
+  input: '0.005'
+  output: '0.020'
+  unit: '0.001'
+  currency: RMB

+ 25 - 0
api/core/model_runtime/model_providers/stepfun/llm/step-1v-32k.yaml

@@ -0,0 +1,25 @@
+model: step-1v-32k
+label:
+  zh_Hans: step-1v-32k
+  en_US: step-1v-32k
+model_type: llm
+features:
+  - vision
+model_properties:
+  mode: chat
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 32000
+pricing:
+  input: '0.015'
+  output: '0.070'
+  unit: '0.001'
+  currency: RMB

+ 25 - 0
api/core/model_runtime/model_providers/stepfun/llm/step-1v-8k.yaml

@@ -0,0 +1,25 @@
+model: step-1v-8k
+label:
+  zh_Hans: step-1v-8k
+  en_US: step-1v-8k
+model_type: llm
+features:
+  - vision
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 8192
+pricing:
+  input: '0.005'
+  output: '0.020'
+  unit: '0.001'
+  currency: RMB

+ 30 - 0
api/core/model_runtime/model_providers/stepfun/stepfun.py

@@ -0,0 +1,30 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class StepfunProvider(ModelProvider):
+
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+        if validate failed, raise exception
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.LLM)
+
+            model_instance.validate_credentials(
+                model='step-1-8k',
+                credentials=credentials
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ex
+        except Exception as ex:
+            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
+            raise ex

+ 81 - 0
api/core/model_runtime/model_providers/stepfun/stepfun.yaml

@@ -0,0 +1,81 @@
+provider: stepfun
+label:
+  zh_Hans: 阶跃星辰
+  en_US: Stepfun
+description:
+  en_US: Models provided by stepfun, such as step-1-8k, step-1-32k、step-1v-8k、step-1v-32k, step-1-128k and step-1-256k
+  zh_Hans: 阶跃星辰提供的模型,例如 step-1-8k、step-1-32k、step-1v-8k、step-1v-32k、step-1-128k 和 step-1-256k。
+icon_small:
+  en_US: icon_s_en.png
+icon_large:
+  en_US: icon_l_en.png
+background: "#FFFFFF"
+help:
+  title:
+    en_US: Get your API Key from stepfun
+    zh_Hans: 从 stepfun 获取 API Key
+  url:
+    en_US: https://platform.stepfun.com/interface-key
+supported_model_types:
+  - llm
+configurate_methods:
+  - predefined-model
+  - customizable-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: context_size
+      label:
+        zh_Hans: 模型上下文长度
+        en_US: Model context size
+      required: true
+      type: text-input
+      default: '8192'
+      placeholder:
+        zh_Hans: 在此输入您的模型上下文长度
+        en_US: Enter your Model context size
+    - variable: max_tokens
+      label:
+        zh_Hans: 最大 token 上限
+        en_US: Upper bound for max tokens
+      default: '8192'
+      type: text-input
+    - variable: function_calling_type
+      label:
+        en_US: Function calling
+      type: select
+      required: false
+      default: no_call
+      options:
+        - value: no_call
+          label:
+            en_US: Not supported
+            zh_Hans: 不支持
+        - value: tool_call
+          label:
+            en_US: Tool Call
+            zh_Hans: Tool Call

+ 0 - 0
api/tests/integration_tests/model_runtime/stepfun/__init__.py


+ 176 - 0
api/tests/integration_tests/model_runtime/stepfun/test_llm.py

@@ -0,0 +1,176 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.stepfun.llm.llm import StepfunLargeLanguageModel
+
+
+def test_validate_credentials():
+    model = StepfunLargeLanguageModel()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        model.validate_credentials(
+            model='step-1-8k',
+            credentials={
+                'api_key': 'invalid_key'
+            }
+        )
+
+    model.validate_credentials(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        }
+    )
+
+def test_invoke_model():
+    model = StepfunLargeLanguageModel()
+
+    response = model.invoke(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        },
+        prompt_messages=[
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.9,
+            'top_p': 0.7
+        },
+        stop=['Hi'],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(response, LLMResult)
+    assert len(response.message.content) > 0
+
+
+def test_invoke_stream_model():
+    model = StepfunLargeLanguageModel()
+
+    response = model.invoke(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ],
+        model_parameters={
+            'temperature': 0.9,
+            'top_p': 0.7
+        },
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
+
+
+def test_get_customizable_model_schema():
+    model = StepfunLargeLanguageModel()
+
+    schema = model.get_customizable_model_schema(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        }
+    )
+    assert isinstance(schema, AIModelEntity)
+
+
+def test_invoke_chat_model_with_tools():
+    model = StepfunLargeLanguageModel()
+
+    result = model.invoke(
+        model='step-1-8k',
+        credentials={
+            'api_key': os.environ.get('STEPFUN_API_KEY')
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content="what's the weather today in Shanghai?",
+            )
+        ],
+        model_parameters={
+            'temperature': 0.9,
+            'max_tokens': 100
+        },
+        tools=[
+            PromptMessageTool(
+                name='get_weather',
+                description='Determine weather in my location',
+                parameters={
+                    "type": "object",
+                    "properties": {
+                      "location": {
+                        "type": "string",
+                        "description": "The city and state e.g. San Francisco, CA"
+                      },
+                      "unit": {
+                        "type": "string",
+                        "enum": [
+                          "c",
+                          "f"
+                        ]
+                      }
+                    },
+                    "required": [
+                      "location"
+                    ]
+                  }
+            ),
+            PromptMessageTool(
+                name='get_stock_price',
+                description='Get the current stock price',
+                parameters={
+                    "type": "object",
+                    "properties": {
+                      "symbol": {
+                        "type": "string",
+                        "description": "The stock symbol"
+                      }
+                    },
+                    "required": [
+                      "symbol"
+                    ]
+                  }
+            )
+        ],
+        stream=False,
+        user="abc-123"
+    )
+
+    assert isinstance(result, LLMResult)
+    assert isinstance(result.message, AssistantPromptMessage)
+    assert len(result.message.tool_calls) > 0