# test_llm.py
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. PromptMessageTool,
  8. SystemPromptMessage,
  9. UserPromptMessage,
  10. )
  11. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  12. from core.model_runtime.model_providers.oci.llm.llm import OCILargeLanguageModel
  13. def test_validate_credentials():
  14. model = OCILargeLanguageModel()
  15. with pytest.raises(CredentialsValidateFailedError):
  16. model.validate_credentials(
  17. model="cohere.command-r-plus",
  18. credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
  19. )
  20. model.validate_credentials(
  21. model="cohere.command-r-plus",
  22. credentials={
  23. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  24. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  25. },
  26. )
  27. def test_invoke_model():
  28. model = OCILargeLanguageModel()
  29. response = model.invoke(
  30. model="cohere.command-r-plus",
  31. credentials={
  32. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  33. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  34. },
  35. prompt_messages=[UserPromptMessage(content="Hi")],
  36. model_parameters={"temperature": 0.5, "max_tokens": 10},
  37. stop=["How"],
  38. stream=False,
  39. user="abc-123",
  40. )
  41. assert isinstance(response, LLMResult)
  42. assert len(response.message.content) > 0
  43. def test_invoke_stream_model():
  44. model = OCILargeLanguageModel()
  45. response = model.invoke(
  46. model="meta.llama-3-70b-instruct",
  47. credentials={
  48. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  49. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  50. },
  51. prompt_messages=[UserPromptMessage(content="Hi")],
  52. model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
  53. stream=True,
  54. user="abc-123",
  55. )
  56. assert isinstance(response, Generator)
  57. for chunk in response:
  58. assert isinstance(chunk, LLMResultChunk)
  59. assert isinstance(chunk.delta, LLMResultChunkDelta)
  60. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  61. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  62. def test_invoke_model_with_function():
  63. model = OCILargeLanguageModel()
  64. response = model.invoke(
  65. model="cohere.command-r-plus",
  66. credentials={
  67. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  68. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  69. },
  70. prompt_messages=[UserPromptMessage(content="Hi")],
  71. model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
  72. stream=False,
  73. user="abc-123",
  74. tools=[
  75. PromptMessageTool(
  76. name="get_current_weather",
  77. description="Get the current weather in a given location",
  78. parameters={
  79. "type": "object",
  80. "properties": {
  81. "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
  82. "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
  83. },
  84. "required": ["location"],
  85. },
  86. )
  87. ],
  88. )
  89. assert isinstance(response, LLMResult)
  90. assert len(response.message.content) > 0
  91. def test_get_num_tokens():
  92. model = OCILargeLanguageModel()
  93. num_tokens = model.get_num_tokens(
  94. model="cohere.command-r-plus",
  95. credentials={
  96. "oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
  97. "oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
  98. },
  99. prompt_messages=[
  100. SystemPromptMessage(
  101. content="You are a helpful AI assistant.",
  102. ),
  103. UserPromptMessage(content="Hello World!"),
  104. ],
  105. )
  106. assert num_tokens == 18