# test_llm.py — Bedrock LLM provider integration tests
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel
  8. def test_validate_credentials():
  9. model = BedrockLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='meta.llama2-13b-chat-v1',
  13. credentials={
  14. 'anthropic_api_key': 'invalid_key'
  15. }
  16. )
  17. model.validate_credentials(
  18. model='meta.llama2-13b-chat-v1',
  19. credentials={
  20. "aws_region": os.getenv("AWS_REGION"),
  21. "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
  22. "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
  23. }
  24. )
  25. def test_invoke_model():
  26. model = BedrockLargeLanguageModel()
  27. response = model.invoke(
  28. model='meta.llama2-13b-chat-v1',
  29. credentials={
  30. "aws_region": os.getenv("AWS_REGION"),
  31. "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
  32. "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
  33. },
  34. prompt_messages=[
  35. SystemPromptMessage(
  36. content='You are a helpful AI assistant.',
  37. ),
  38. UserPromptMessage(
  39. content='Hello World!'
  40. )
  41. ],
  42. model_parameters={
  43. 'temperature': 0.0,
  44. 'top_p': 1.0,
  45. 'max_tokens_to_sample': 10
  46. },
  47. stop=['How'],
  48. stream=False,
  49. user="abc-123"
  50. )
  51. assert isinstance(response, LLMResult)
  52. assert len(response.message.content) > 0
  53. def test_invoke_stream_model():
  54. model = BedrockLargeLanguageModel()
  55. response = model.invoke(
  56. model='meta.llama2-13b-chat-v1',
  57. credentials={
  58. "aws_region": os.getenv("AWS_REGION"),
  59. "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
  60. "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
  61. },
  62. prompt_messages=[
  63. SystemPromptMessage(
  64. content='You are a helpful AI assistant.',
  65. ),
  66. UserPromptMessage(
  67. content='Hello World!'
  68. )
  69. ],
  70. model_parameters={
  71. 'temperature': 0.0,
  72. 'max_tokens_to_sample': 100
  73. },
  74. stream=True,
  75. user="abc-123"
  76. )
  77. assert isinstance(response, Generator)
  78. for chunk in response:
  79. print(chunk)
  80. assert isinstance(chunk, LLMResultChunk)
  81. assert isinstance(chunk.delta, LLMResultChunkDelta)
  82. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  83. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  84. def test_get_num_tokens():
  85. model = BedrockLargeLanguageModel()
  86. num_tokens = model.get_num_tokens(
  87. model='meta.llama2-13b-chat-v1',
  88. credentials = {
  89. "aws_region": os.getenv("AWS_REGION"),
  90. "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
  91. "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
  92. },
  93. messages=[
  94. SystemPromptMessage(
  95. content='You are a helpful AI assistant.',
  96. ),
  97. UserPromptMessage(
  98. content='Hello World!'
  99. )
  100. ]
  101. )
  102. assert num_tokens == 18