test_llm.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import os
  2. import pytest
  3. from typing import Generator
  4. from time import sleep
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
  6. from core.model_runtime.entities.model_entities import AIModelEntity
  7. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
  8. LLMResultChunk
  9. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  10. from core.model_runtime.model_providers.minimax.llm.llm import MinimaxLargeLanguageModel
  11. def test_predefined_models():
  12. model = MinimaxLargeLanguageModel()
  13. model_schemas = model.predefined_models()
  14. assert len(model_schemas) >= 1
  15. assert isinstance(model_schemas[0], AIModelEntity)
  16. def test_validate_credentials_for_chat_model():
  17. sleep(3)
  18. model = MinimaxLargeLanguageModel()
  19. with pytest.raises(CredentialsValidateFailedError):
  20. model.validate_credentials(
  21. model='abab5.5-chat',
  22. credentials={
  23. 'minimax_api_key': 'invalid_key',
  24. 'minimax_group_id': 'invalid_key'
  25. }
  26. )
  27. model.validate_credentials(
  28. model='abab5.5-chat',
  29. credentials={
  30. 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
  31. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  32. }
  33. )
  34. def test_invoke_model():
  35. sleep(3)
  36. model = MinimaxLargeLanguageModel()
  37. response = model.invoke(
  38. model='abab5-chat',
  39. credentials={
  40. 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
  41. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  42. },
  43. prompt_messages=[
  44. UserPromptMessage(
  45. content='Hello World!'
  46. )
  47. ],
  48. model_parameters={
  49. 'temperature': 0.7,
  50. 'top_p': 1.0,
  51. 'top_k': 1,
  52. },
  53. stop=['you'],
  54. user="abc-123",
  55. stream=False
  56. )
  57. assert isinstance(response, LLMResult)
  58. assert len(response.message.content) > 0
  59. assert response.usage.total_tokens > 0
  60. def test_invoke_stream_model():
  61. sleep(3)
  62. model = MinimaxLargeLanguageModel()
  63. response = model.invoke(
  64. model='abab5.5-chat',
  65. credentials={
  66. 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
  67. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  68. },
  69. prompt_messages=[
  70. UserPromptMessage(
  71. content='Hello World!'
  72. )
  73. ],
  74. model_parameters={
  75. 'temperature': 0.7,
  76. 'top_p': 1.0,
  77. 'top_k': 1,
  78. },
  79. stop=['you'],
  80. stream=True,
  81. user="abc-123"
  82. )
  83. assert isinstance(response, Generator)
  84. for chunk in response:
  85. assert isinstance(chunk, LLMResultChunk)
  86. assert isinstance(chunk.delta, LLMResultChunkDelta)
  87. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  88. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  89. def test_invoke_with_search():
  90. sleep(3)
  91. model = MinimaxLargeLanguageModel()
  92. response = model.invoke(
  93. model='abab5.5-chat',
  94. credentials={
  95. 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
  96. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  97. },
  98. prompt_messages=[
  99. UserPromptMessage(
  100. content='北京今天的天气怎么样'
  101. )
  102. ],
  103. model_parameters={
  104. 'temperature': 0.7,
  105. 'top_p': 1.0,
  106. 'top_k': 1,
  107. 'plugin_web_search': True,
  108. },
  109. stop=['you'],
  110. stream=True,
  111. user="abc-123"
  112. )
  113. assert isinstance(response, Generator)
  114. total_message = ''
  115. for chunk in response:
  116. assert isinstance(chunk, LLMResultChunk)
  117. assert isinstance(chunk.delta, LLMResultChunkDelta)
  118. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  119. total_message += chunk.delta.message.content
  120. assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
  121. assert '参考资料' in total_message
  122. def test_get_num_tokens():
  123. sleep(3)
  124. model = MinimaxLargeLanguageModel()
  125. response = model.get_num_tokens(
  126. model='abab5.5-chat',
  127. credentials={
  128. 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
  129. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  130. },
  131. prompt_messages=[
  132. UserPromptMessage(
  133. content='Hello World!'
  134. )
  135. ],
  136. tools=[]
  137. )
  138. assert isinstance(response, int)
  139. assert response == 30