# test_llm.py — integration tests for the Replicate LLM provider.
# Requires the REPLICATE_API_KEY environment variable; these tests hit the live Replicate API.
  1. import os
  2. from typing import Generator
  3. import pytest
  4. from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
  5. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
  6. LLMResultChunkDelta
  7. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  8. from core.model_runtime.model_providers.replicate.llm.llm import ReplicateLargeLanguageModel
  9. def test_validate_credentials():
  10. model = ReplicateLargeLanguageModel()
  11. with pytest.raises(CredentialsValidateFailedError):
  12. model.validate_credentials(
  13. model='meta/llama-2-13b-chat',
  14. credentials={
  15. 'replicate_api_token': 'invalid_key',
  16. 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
  17. }
  18. )
  19. model.validate_credentials(
  20. model='meta/llama-2-13b-chat',
  21. credentials={
  22. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  23. 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
  24. }
  25. )
  26. def test_invoke_model():
  27. model = ReplicateLargeLanguageModel()
  28. response = model.invoke(
  29. model='meta/llama-2-13b-chat',
  30. credentials={
  31. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  32. 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
  33. },
  34. prompt_messages=[
  35. SystemPromptMessage(
  36. content='You are a helpful AI assistant.',
  37. ),
  38. UserPromptMessage(
  39. content='Who are you?'
  40. )
  41. ],
  42. model_parameters={
  43. 'temperature': 1.0,
  44. 'top_k': 2,
  45. 'top_p': 0.5,
  46. },
  47. stop=['How'],
  48. stream=False,
  49. user="abc-123"
  50. )
  51. assert isinstance(response, LLMResult)
  52. assert len(response.message.content) > 0
  53. def test_invoke_stream_model():
  54. model = ReplicateLargeLanguageModel()
  55. response = model.invoke(
  56. model='mistralai/mixtral-8x7b-instruct-v0.1',
  57. credentials={
  58. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  59. 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
  60. },
  61. prompt_messages=[
  62. SystemPromptMessage(
  63. content='You are a helpful AI assistant.',
  64. ),
  65. UserPromptMessage(
  66. content='Who are you?'
  67. )
  68. ],
  69. model_parameters={
  70. 'temperature': 1.0,
  71. 'top_k': 2,
  72. 'top_p': 0.5,
  73. },
  74. stop=['How'],
  75. stream=True,
  76. user="abc-123"
  77. )
  78. assert isinstance(response, Generator)
  79. for chunk in response:
  80. assert isinstance(chunk, LLMResultChunk)
  81. assert isinstance(chunk.delta, LLMResultChunkDelta)
  82. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  83. def test_get_num_tokens():
  84. model = ReplicateLargeLanguageModel()
  85. num_tokens = model.get_num_tokens(
  86. model='',
  87. credentials={
  88. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  89. 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
  90. },
  91. prompt_messages=[
  92. SystemPromptMessage(
  93. content='You are a helpful AI assistant.',
  94. ),
  95. UserPromptMessage(
  96. content='Hello World!'
  97. )
  98. ]
  99. )
  100. assert num_tokens == 14