test_llm.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. PromptMessageTool,
  8. SystemPromptMessage,
  9. UserPromptMessage,
  10. )
  11. from core.model_runtime.entities.model_entities import AIModelEntity
  12. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  13. from core.model_runtime.model_providers.stepfun.llm.llm import StepfunLargeLanguageModel
  14. def test_validate_credentials():
  15. model = StepfunLargeLanguageModel()
  16. with pytest.raises(CredentialsValidateFailedError):
  17. model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"})
  18. model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")})
  19. def test_invoke_model():
  20. model = StepfunLargeLanguageModel()
  21. response = model.invoke(
  22. model="step-1-8k",
  23. credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
  24. prompt_messages=[UserPromptMessage(content="Hello World!")],
  25. model_parameters={"temperature": 0.9, "top_p": 0.7},
  26. stop=["Hi"],
  27. stream=False,
  28. user="abc-123",
  29. )
  30. assert isinstance(response, LLMResult)
  31. assert len(response.message.content) > 0
  32. def test_invoke_stream_model():
  33. model = StepfunLargeLanguageModel()
  34. response = model.invoke(
  35. model="step-1-8k",
  36. credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
  37. prompt_messages=[
  38. SystemPromptMessage(
  39. content="You are a helpful AI assistant.",
  40. ),
  41. UserPromptMessage(content="Hello World!"),
  42. ],
  43. model_parameters={"temperature": 0.9, "top_p": 0.7},
  44. stream=True,
  45. user="abc-123",
  46. )
  47. assert isinstance(response, Generator)
  48. for chunk in response:
  49. assert isinstance(chunk, LLMResultChunk)
  50. assert isinstance(chunk.delta, LLMResultChunkDelta)
  51. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  52. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  53. def test_get_customizable_model_schema():
  54. model = StepfunLargeLanguageModel()
  55. schema = model.get_customizable_model_schema(
  56. model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}
  57. )
  58. assert isinstance(schema, AIModelEntity)
  59. def test_invoke_chat_model_with_tools():
  60. model = StepfunLargeLanguageModel()
  61. result = model.invoke(
  62. model="step-1-8k",
  63. credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
  64. prompt_messages=[
  65. SystemPromptMessage(
  66. content="You are a helpful AI assistant.",
  67. ),
  68. UserPromptMessage(
  69. content="what's the weather today in Shanghai?",
  70. ),
  71. ],
  72. model_parameters={"temperature": 0.9, "max_tokens": 100},
  73. tools=[
  74. PromptMessageTool(
  75. name="get_weather",
  76. description="Determine weather in my location",
  77. parameters={
  78. "type": "object",
  79. "properties": {
  80. "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
  81. "unit": {"type": "string", "enum": ["c", "f"]},
  82. },
  83. "required": ["location"],
  84. },
  85. ),
  86. PromptMessageTool(
  87. name="get_stock_price",
  88. description="Get the current stock price",
  89. parameters={
  90. "type": "object",
  91. "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
  92. "required": ["symbol"],
  93. },
  94. ),
  95. ],
  96. stream=False,
  97. user="abc-123",
  98. )
  99. assert isinstance(result, LLMResult)
  100. assert isinstance(result.message, AssistantPromptMessage)
  101. assert len(result.message.tool_calls) > 0