test_llm.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel
  8. def test_validate_credentials_for_chat_model():
  9. model = VolcengineMaaSLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model="NOT IMPORTANT",
  13. credentials={
  14. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  15. "volc_region": "cn-beijing",
  16. "volc_access_key_id": "INVALID",
  17. "volc_secret_access_key": "INVALID",
  18. "endpoint_id": "INVALID",
  19. },
  20. )
  21. model.validate_credentials(
  22. model="NOT IMPORTANT",
  23. credentials={
  24. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  25. "volc_region": "cn-beijing",
  26. "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
  27. "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
  28. "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
  29. },
  30. )
  31. def test_invoke_model():
  32. model = VolcengineMaaSLargeLanguageModel()
  33. response = model.invoke(
  34. model="NOT IMPORTANT",
  35. credentials={
  36. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  37. "volc_region": "cn-beijing",
  38. "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
  39. "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
  40. "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
  41. "base_model_name": "Skylark2-pro-4k",
  42. },
  43. prompt_messages=[UserPromptMessage(content="Hello World!")],
  44. model_parameters={
  45. "temperature": 0.7,
  46. "top_p": 1.0,
  47. "top_k": 1,
  48. },
  49. stop=["you"],
  50. user="abc-123",
  51. stream=False,
  52. )
  53. assert isinstance(response, LLMResult)
  54. assert len(response.message.content) > 0
  55. assert response.usage.total_tokens > 0
  56. def test_invoke_stream_model():
  57. model = VolcengineMaaSLargeLanguageModel()
  58. response = model.invoke(
  59. model="NOT IMPORTANT",
  60. credentials={
  61. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  62. "volc_region": "cn-beijing",
  63. "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
  64. "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
  65. "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
  66. "base_model_name": "Skylark2-pro-4k",
  67. },
  68. prompt_messages=[UserPromptMessage(content="Hello World!")],
  69. model_parameters={
  70. "temperature": 0.7,
  71. "top_p": 1.0,
  72. "top_k": 1,
  73. },
  74. stop=["you"],
  75. stream=True,
  76. user="abc-123",
  77. )
  78. assert isinstance(response, Generator)
  79. for chunk in response:
  80. assert isinstance(chunk, LLMResultChunk)
  81. assert isinstance(chunk.delta, LLMResultChunkDelta)
  82. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  83. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  84. def test_get_num_tokens():
  85. model = VolcengineMaaSLargeLanguageModel()
  86. response = model.get_num_tokens(
  87. model="NOT IMPORTANT",
  88. credentials={
  89. "api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
  90. "volc_region": "cn-beijing",
  91. "volc_access_key_id": os.environ.get("VOLC_API_KEY"),
  92. "volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
  93. "endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
  94. "base_model_name": "Skylark2-pro-4k",
  95. },
  96. prompt_messages=[UserPromptMessage(content="Hello World!")],
  97. tools=[],
  98. )
  99. assert isinstance(response, int)
  100. assert response == 6