# Integration tests for the Cohere large-language-model provider (requires COHERE_API_KEY).
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
  8. def test_validate_credentials_for_chat_model():
  9. model = CohereLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"})
  12. model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
  13. def test_validate_credentials_for_completion_model():
  14. model = CohereLargeLanguageModel()
  15. with pytest.raises(CredentialsValidateFailedError):
  16. model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"})
  17. model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
  18. def test_invoke_completion_model():
  19. model = CohereLargeLanguageModel()
  20. credentials = {"api_key": os.environ.get("COHERE_API_KEY")}
  21. result = model.invoke(
  22. model="command-light",
  23. credentials=credentials,
  24. prompt_messages=[UserPromptMessage(content="Hello World!")],
  25. model_parameters={"temperature": 0.0, "max_tokens": 1},
  26. stream=False,
  27. user="abc-123",
  28. )
  29. assert isinstance(result, LLMResult)
  30. assert len(result.message.content) > 0
  31. assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1
  32. def test_invoke_stream_completion_model():
  33. model = CohereLargeLanguageModel()
  34. result = model.invoke(
  35. model="command-light",
  36. credentials={"api_key": os.environ.get("COHERE_API_KEY")},
  37. prompt_messages=[UserPromptMessage(content="Hello World!")],
  38. model_parameters={"temperature": 0.0, "max_tokens": 100},
  39. stream=True,
  40. user="abc-123",
  41. )
  42. assert isinstance(result, Generator)
  43. for chunk in result:
  44. assert isinstance(chunk, LLMResultChunk)
  45. assert isinstance(chunk.delta, LLMResultChunkDelta)
  46. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  47. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  48. def test_invoke_chat_model():
  49. model = CohereLargeLanguageModel()
  50. result = model.invoke(
  51. model="command-light-chat",
  52. credentials={"api_key": os.environ.get("COHERE_API_KEY")},
  53. prompt_messages=[
  54. SystemPromptMessage(
  55. content="You are a helpful AI assistant.",
  56. ),
  57. UserPromptMessage(content="Hello World!"),
  58. ],
  59. model_parameters={
  60. "temperature": 0.0,
  61. "p": 0.99,
  62. "presence_penalty": 0.0,
  63. "frequency_penalty": 0.0,
  64. "max_tokens": 10,
  65. },
  66. stop=["How"],
  67. stream=False,
  68. user="abc-123",
  69. )
  70. assert isinstance(result, LLMResult)
  71. assert len(result.message.content) > 0
  72. def test_invoke_stream_chat_model():
  73. model = CohereLargeLanguageModel()
  74. result = model.invoke(
  75. model="command-light-chat",
  76. credentials={"api_key": os.environ.get("COHERE_API_KEY")},
  77. prompt_messages=[
  78. SystemPromptMessage(
  79. content="You are a helpful AI assistant.",
  80. ),
  81. UserPromptMessage(content="Hello World!"),
  82. ],
  83. model_parameters={"temperature": 0.0, "max_tokens": 100},
  84. stream=True,
  85. user="abc-123",
  86. )
  87. assert isinstance(result, Generator)
  88. for chunk in result:
  89. assert isinstance(chunk, LLMResultChunk)
  90. assert isinstance(chunk.delta, LLMResultChunkDelta)
  91. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  92. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  93. if chunk.delta.finish_reason is not None:
  94. assert chunk.delta.usage is not None
  95. assert chunk.delta.usage.completion_tokens > 0
  96. def test_get_num_tokens():
  97. model = CohereLargeLanguageModel()
  98. num_tokens = model.get_num_tokens(
  99. model="command-light",
  100. credentials={"api_key": os.environ.get("COHERE_API_KEY")},
  101. prompt_messages=[UserPromptMessage(content="Hello World!")],
  102. )
  103. assert num_tokens == 3
  104. num_tokens = model.get_num_tokens(
  105. model="command-light-chat",
  106. credentials={"api_key": os.environ.get("COHERE_API_KEY")},
  107. prompt_messages=[
  108. SystemPromptMessage(
  109. content="You are a helpful AI assistant.",
  110. ),
  111. UserPromptMessage(content="Hello World!"),
  112. ],
  113. )
  114. assert num_tokens == 15
  115. def test_fine_tuned_model():
  116. model = CohereLargeLanguageModel()
  117. # test invoke
  118. result = model.invoke(
  119. model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft",
  120. credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"},
  121. prompt_messages=[
  122. SystemPromptMessage(
  123. content="You are a helpful AI assistant.",
  124. ),
  125. UserPromptMessage(content="Hello World!"),
  126. ],
  127. model_parameters={"temperature": 0.0, "max_tokens": 100},
  128. stream=False,
  129. user="abc-123",
  130. )
  131. assert isinstance(result, LLMResult)
  132. def test_fine_tuned_chat_model():
  133. model = CohereLargeLanguageModel()
  134. # test invoke
  135. result = model.invoke(
  136. model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft",
  137. credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"},
  138. prompt_messages=[
  139. SystemPromptMessage(
  140. content="You are a helpful AI assistant.",
  141. ),
  142. UserPromptMessage(content="Hello World!"),
  143. ],
  144. model_parameters={"temperature": 0.0, "max_tokens": 100},
  145. stream=False,
  146. user="abc-123",
  147. )
  148. assert isinstance(result, LLMResult)