# Integration tests for the TogetherAI large-language-model provider.
  1. import os
  2. from typing import Generator
  3. import pytest
  4. from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
  5. SystemPromptMessage, PromptMessageTool
  6. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
  7. LLMResultChunk
  8. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  9. from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel
  10. def test_validate_credentials():
  11. model = TogetherAILargeLanguageModel()
  12. with pytest.raises(CredentialsValidateFailedError):
  13. model.validate_credentials(
  14. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  15. credentials={
  16. 'api_key': 'invalid_key',
  17. 'mode': 'chat'
  18. }
  19. )
  20. model.validate_credentials(
  21. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  22. credentials={
  23. 'api_key': os.environ.get('TOGETHER_API_KEY'),
  24. 'mode': 'chat'
  25. }
  26. )
  27. def test_invoke_model():
  28. model = TogetherAILargeLanguageModel()
  29. response = model.invoke(
  30. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  31. credentials={
  32. 'api_key': os.environ.get('TOGETHER_API_KEY'),
  33. 'mode': 'completion'
  34. },
  35. prompt_messages=[
  36. SystemPromptMessage(
  37. content='You are a helpful AI assistant.',
  38. ),
  39. UserPromptMessage(
  40. content='Who are you?'
  41. )
  42. ],
  43. model_parameters={
  44. 'temperature': 1.0,
  45. 'top_k': 2,
  46. 'top_p': 0.5,
  47. },
  48. stop=['How'],
  49. stream=False,
  50. user="abc-123"
  51. )
  52. assert isinstance(response, LLMResult)
  53. assert len(response.message.content) > 0
  54. def test_invoke_stream_model():
  55. model = TogetherAILargeLanguageModel()
  56. response = model.invoke(
  57. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  58. credentials={
  59. 'api_key': os.environ.get('TOGETHER_API_KEY'),
  60. 'mode': 'chat'
  61. },
  62. prompt_messages=[
  63. SystemPromptMessage(
  64. content='You are a helpful AI assistant.',
  65. ),
  66. UserPromptMessage(
  67. content='Who are you?'
  68. )
  69. ],
  70. model_parameters={
  71. 'temperature': 1.0,
  72. 'top_k': 2,
  73. 'top_p': 0.5,
  74. },
  75. stop=['How'],
  76. stream=True,
  77. user="abc-123"
  78. )
  79. assert isinstance(response, Generator)
  80. for chunk in response:
  81. assert isinstance(chunk, LLMResultChunk)
  82. assert isinstance(chunk.delta, LLMResultChunkDelta)
  83. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  84. def test_get_num_tokens():
  85. model = TogetherAILargeLanguageModel()
  86. num_tokens = model.get_num_tokens(
  87. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  88. credentials={
  89. 'api_key': os.environ.get('TOGETHER_API_KEY'),
  90. },
  91. prompt_messages=[
  92. SystemPromptMessage(
  93. content='You are a helpful AI assistant.',
  94. ),
  95. UserPromptMessage(
  96. content='Hello World!'
  97. )
  98. ]
  99. )
  100. assert isinstance(num_tokens, int)
  101. assert num_tokens == 21