# test_llm.py — integration tests for the TogetherAI large-language-model provider.
# Requires a valid TOGETHER_API_KEY environment variable and network access.
import os
from typing import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel
  9. def test_validate_credentials():
  10. model = TogetherAILargeLanguageModel()
  11. with pytest.raises(CredentialsValidateFailedError):
  12. model.validate_credentials(
  13. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  14. credentials={
  15. 'api_key': 'invalid_key',
  16. 'mode': 'chat'
  17. }
  18. )
  19. model.validate_credentials(
  20. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  21. credentials={
  22. 'api_key': os.environ.get('TOGETHER_API_KEY'),
  23. 'mode': 'chat'
  24. }
  25. )
  26. def test_invoke_model():
  27. model = TogetherAILargeLanguageModel()
  28. response = model.invoke(
  29. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  30. credentials={
  31. 'api_key': os.environ.get('TOGETHER_API_KEY'),
  32. 'mode': 'completion'
  33. },
  34. prompt_messages=[
  35. SystemPromptMessage(
  36. content='You are a helpful AI assistant.',
  37. ),
  38. UserPromptMessage(
  39. content='Who are you?'
  40. )
  41. ],
  42. model_parameters={
  43. 'temperature': 1.0,
  44. 'top_k': 2,
  45. 'top_p': 0.5,
  46. },
  47. stop=['How'],
  48. stream=False,
  49. user="abc-123"
  50. )
  51. assert isinstance(response, LLMResult)
  52. assert len(response.message.content) > 0
  53. def test_invoke_stream_model():
  54. model = TogetherAILargeLanguageModel()
  55. response = model.invoke(
  56. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  57. credentials={
  58. 'api_key': os.environ.get('TOGETHER_API_KEY'),
  59. 'mode': 'chat'
  60. },
  61. prompt_messages=[
  62. SystemPromptMessage(
  63. content='You are a helpful AI assistant.',
  64. ),
  65. UserPromptMessage(
  66. content='Who are you?'
  67. )
  68. ],
  69. model_parameters={
  70. 'temperature': 1.0,
  71. 'top_k': 2,
  72. 'top_p': 0.5,
  73. },
  74. stop=['How'],
  75. stream=True,
  76. user="abc-123"
  77. )
  78. assert isinstance(response, Generator)
  79. for chunk in response:
  80. assert isinstance(chunk, LLMResultChunk)
  81. assert isinstance(chunk.delta, LLMResultChunkDelta)
  82. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  83. def test_get_num_tokens():
  84. model = TogetherAILargeLanguageModel()
  85. num_tokens = model.get_num_tokens(
  86. model='mistralai/Mixtral-8x7B-Instruct-v0.1',
  87. credentials={
  88. 'api_key': os.environ.get('TOGETHER_API_KEY'),
  89. },
  90. prompt_messages=[
  91. SystemPromptMessage(
  92. content='You are a helpful AI assistant.',
  93. ),
  94. UserPromptMessage(
  95. content='Hello World!'
  96. )
  97. ]
  98. )
  99. assert isinstance(num_tokens, int)
  100. assert num_tokens == 21