# test_llm.py — integration tests for the Novita large-language-model provider.
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. PromptMessageTool,
  8. SystemPromptMessage,
  9. UserPromptMessage,
  10. )
  11. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  12. from core.model_runtime.model_providers.novita.llm.llm import NovitaLargeLanguageModel
  13. def test_validate_credentials():
  14. model = NovitaLargeLanguageModel()
  15. with pytest.raises(CredentialsValidateFailedError):
  16. model.validate_credentials(
  17. model='meta-llama/llama-3-8b-instruct',
  18. credentials={
  19. 'api_key': 'invalid_key',
  20. 'mode': 'chat'
  21. }
  22. )
  23. model.validate_credentials(
  24. model='meta-llama/llama-3-8b-instruct',
  25. credentials={
  26. 'api_key': os.environ.get('NOVITA_API_KEY'),
  27. 'mode': 'chat'
  28. }
  29. )
  30. def test_invoke_model():
  31. model = NovitaLargeLanguageModel()
  32. response = model.invoke(
  33. model='meta-llama/llama-3-8b-instruct',
  34. credentials={
  35. 'api_key': os.environ.get('NOVITA_API_KEY'),
  36. 'mode': 'completion'
  37. },
  38. prompt_messages=[
  39. SystemPromptMessage(
  40. content='You are a helpful AI assistant.',
  41. ),
  42. UserPromptMessage(
  43. content='Who are you?'
  44. )
  45. ],
  46. model_parameters={
  47. 'temperature': 1.0,
  48. 'top_p': 0.5,
  49. 'max_tokens': 10,
  50. },
  51. stop=['How'],
  52. stream=False,
  53. user="novita"
  54. )
  55. assert isinstance(response, LLMResult)
  56. assert len(response.message.content) > 0
  57. def test_invoke_stream_model():
  58. model = NovitaLargeLanguageModel()
  59. response = model.invoke(
  60. model='meta-llama/llama-3-8b-instruct',
  61. credentials={
  62. 'api_key': os.environ.get('NOVITA_API_KEY'),
  63. 'mode': 'chat'
  64. },
  65. prompt_messages=[
  66. SystemPromptMessage(
  67. content='You are a helpful AI assistant.',
  68. ),
  69. UserPromptMessage(
  70. content='Who are you?'
  71. )
  72. ],
  73. model_parameters={
  74. 'temperature': 1.0,
  75. 'top_k': 2,
  76. 'top_p': 0.5,
  77. 'max_tokens': 100
  78. },
  79. stream=True,
  80. user="novita"
  81. )
  82. assert isinstance(response, Generator)
  83. for chunk in response:
  84. assert isinstance(chunk, LLMResultChunk)
  85. assert isinstance(chunk.delta, LLMResultChunkDelta)
  86. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  87. def test_get_num_tokens():
  88. model = NovitaLargeLanguageModel()
  89. num_tokens = model.get_num_tokens(
  90. model='meta-llama/llama-3-8b-instruct',
  91. credentials={
  92. 'api_key': os.environ.get('NOVITA_API_KEY'),
  93. },
  94. prompt_messages=[
  95. SystemPromptMessage(
  96. content='You are a helpful AI assistant.',
  97. ),
  98. UserPromptMessage(
  99. content='Hello World!'
  100. )
  101. ]
  102. )
  103. assert isinstance(num_tokens, int)
  104. assert num_tokens == 21