test_llm.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import os
  2. from typing import Generator
  3. import pytest
  4. from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
  5. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
  6. LLMResultChunkDelta
  7. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  8. from core.model_runtime.model_providers.spark.llm.llm import SparkLargeLanguageModel
  9. def test_validate_credentials():
  10. model = SparkLargeLanguageModel()
  11. with pytest.raises(CredentialsValidateFailedError):
  12. model.validate_credentials(
  13. model='spark-1.5',
  14. credentials={
  15. 'app_id': 'invalid_key'
  16. }
  17. )
  18. model.validate_credentials(
  19. model='spark-1.5',
  20. credentials={
  21. 'app_id': os.environ.get('SPARK_APP_ID'),
  22. 'api_secret': os.environ.get('SPARK_API_SECRET'),
  23. 'api_key': os.environ.get('SPARK_API_KEY')
  24. }
  25. )
  26. def test_invoke_model():
  27. model = SparkLargeLanguageModel()
  28. response = model.invoke(
  29. model='spark-1.5',
  30. credentials={
  31. 'app_id': os.environ.get('SPARK_APP_ID'),
  32. 'api_secret': os.environ.get('SPARK_API_SECRET'),
  33. 'api_key': os.environ.get('SPARK_API_KEY')
  34. },
  35. prompt_messages=[
  36. UserPromptMessage(
  37. content='Who are you?'
  38. )
  39. ],
  40. model_parameters={
  41. 'temperature': 0.5,
  42. 'max_tokens': 10
  43. },
  44. stop=['How'],
  45. stream=False,
  46. user="abc-123"
  47. )
  48. assert isinstance(response, LLMResult)
  49. assert len(response.message.content) > 0
  50. def test_invoke_stream_model():
  51. model = SparkLargeLanguageModel()
  52. response = model.invoke(
  53. model='spark-1.5',
  54. credentials={
  55. 'app_id': os.environ.get('SPARK_APP_ID'),
  56. 'api_secret': os.environ.get('SPARK_API_SECRET'),
  57. 'api_key': os.environ.get('SPARK_API_KEY')
  58. },
  59. prompt_messages=[
  60. UserPromptMessage(
  61. content='Hello World!'
  62. )
  63. ],
  64. model_parameters={
  65. 'temperature': 0.5,
  66. 'max_tokens': 100
  67. },
  68. stream=True,
  69. user="abc-123"
  70. )
  71. assert isinstance(response, Generator)
  72. for chunk in response:
  73. assert isinstance(chunk, LLMResultChunk)
  74. assert isinstance(chunk.delta, LLMResultChunkDelta)
  75. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  76. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  77. def test_get_num_tokens():
  78. model = SparkLargeLanguageModel()
  79. num_tokens = model.get_num_tokens(
  80. model='spark-1.5',
  81. credentials={
  82. 'app_id': os.environ.get('SPARK_APP_ID'),
  83. 'api_secret': os.environ.get('SPARK_API_SECRET'),
  84. 'api_key': os.environ.get('SPARK_API_KEY')
  85. },
  86. prompt_messages=[
  87. SystemPromptMessage(
  88. content='You are a helpful AI assistant.',
  89. ),
  90. UserPromptMessage(
  91. content='Hello World!'
  92. )
  93. ]
  94. )
  95. assert num_tokens == 14