# test_llm.py — integration tests for the Cohere LLM provider.
# These tests hit the live Cohere API and require COHERE_API_KEY in the environment.
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
  8. def test_validate_credentials_for_chat_model():
  9. model = CohereLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='command-light-chat',
  13. credentials={
  14. 'api_key': 'invalid_key'
  15. }
  16. )
  17. model.validate_credentials(
  18. model='command-light-chat',
  19. credentials={
  20. 'api_key': os.environ.get('COHERE_API_KEY')
  21. }
  22. )
  23. def test_validate_credentials_for_completion_model():
  24. model = CohereLargeLanguageModel()
  25. with pytest.raises(CredentialsValidateFailedError):
  26. model.validate_credentials(
  27. model='command-light',
  28. credentials={
  29. 'api_key': 'invalid_key'
  30. }
  31. )
  32. model.validate_credentials(
  33. model='command-light',
  34. credentials={
  35. 'api_key': os.environ.get('COHERE_API_KEY')
  36. }
  37. )
  38. def test_invoke_completion_model():
  39. model = CohereLargeLanguageModel()
  40. credentials = {
  41. 'api_key': os.environ.get('COHERE_API_KEY')
  42. }
  43. result = model.invoke(
  44. model='command-light',
  45. credentials=credentials,
  46. prompt_messages=[
  47. UserPromptMessage(
  48. content='Hello World!'
  49. )
  50. ],
  51. model_parameters={
  52. 'temperature': 0.0,
  53. 'max_tokens': 1
  54. },
  55. stream=False,
  56. user="abc-123"
  57. )
  58. assert isinstance(result, LLMResult)
  59. assert len(result.message.content) > 0
  60. assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
  61. def test_invoke_stream_completion_model():
  62. model = CohereLargeLanguageModel()
  63. result = model.invoke(
  64. model='command-light',
  65. credentials={
  66. 'api_key': os.environ.get('COHERE_API_KEY')
  67. },
  68. prompt_messages=[
  69. UserPromptMessage(
  70. content='Hello World!'
  71. )
  72. ],
  73. model_parameters={
  74. 'temperature': 0.0,
  75. 'max_tokens': 100
  76. },
  77. stream=True,
  78. user="abc-123"
  79. )
  80. assert isinstance(result, Generator)
  81. for chunk in result:
  82. assert isinstance(chunk, LLMResultChunk)
  83. assert isinstance(chunk.delta, LLMResultChunkDelta)
  84. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  85. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  86. def test_invoke_chat_model():
  87. model = CohereLargeLanguageModel()
  88. result = model.invoke(
  89. model='command-light-chat',
  90. credentials={
  91. 'api_key': os.environ.get('COHERE_API_KEY')
  92. },
  93. prompt_messages=[
  94. SystemPromptMessage(
  95. content='You are a helpful AI assistant.',
  96. ),
  97. UserPromptMessage(
  98. content='Hello World!'
  99. )
  100. ],
  101. model_parameters={
  102. 'temperature': 0.0,
  103. 'p': 0.99,
  104. 'presence_penalty': 0.0,
  105. 'frequency_penalty': 0.0,
  106. 'max_tokens': 10
  107. },
  108. stop=['How'],
  109. stream=False,
  110. user="abc-123"
  111. )
  112. assert isinstance(result, LLMResult)
  113. assert len(result.message.content) > 0
  114. def test_invoke_stream_chat_model():
  115. model = CohereLargeLanguageModel()
  116. result = model.invoke(
  117. model='command-light-chat',
  118. credentials={
  119. 'api_key': os.environ.get('COHERE_API_KEY')
  120. },
  121. prompt_messages=[
  122. SystemPromptMessage(
  123. content='You are a helpful AI assistant.',
  124. ),
  125. UserPromptMessage(
  126. content='Hello World!'
  127. )
  128. ],
  129. model_parameters={
  130. 'temperature': 0.0,
  131. 'max_tokens': 100
  132. },
  133. stream=True,
  134. user="abc-123"
  135. )
  136. assert isinstance(result, Generator)
  137. for chunk in result:
  138. assert isinstance(chunk, LLMResultChunk)
  139. assert isinstance(chunk.delta, LLMResultChunkDelta)
  140. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  141. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  142. if chunk.delta.finish_reason is not None:
  143. assert chunk.delta.usage is not None
  144. assert chunk.delta.usage.completion_tokens > 0
  145. def test_get_num_tokens():
  146. model = CohereLargeLanguageModel()
  147. num_tokens = model.get_num_tokens(
  148. model='command-light',
  149. credentials={
  150. 'api_key': os.environ.get('COHERE_API_KEY')
  151. },
  152. prompt_messages=[
  153. UserPromptMessage(
  154. content='Hello World!'
  155. )
  156. ]
  157. )
  158. assert num_tokens == 3
  159. num_tokens = model.get_num_tokens(
  160. model='command-light-chat',
  161. credentials={
  162. 'api_key': os.environ.get('COHERE_API_KEY')
  163. },
  164. prompt_messages=[
  165. SystemPromptMessage(
  166. content='You are a helpful AI assistant.',
  167. ),
  168. UserPromptMessage(
  169. content='Hello World!'
  170. )
  171. ]
  172. )
  173. assert num_tokens == 15
  174. def test_fine_tuned_model():
  175. model = CohereLargeLanguageModel()
  176. # test invoke
  177. result = model.invoke(
  178. model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
  179. credentials={
  180. 'api_key': os.environ.get('COHERE_API_KEY'),
  181. 'mode': 'completion'
  182. },
  183. prompt_messages=[
  184. SystemPromptMessage(
  185. content='You are a helpful AI assistant.',
  186. ),
  187. UserPromptMessage(
  188. content='Hello World!'
  189. )
  190. ],
  191. model_parameters={
  192. 'temperature': 0.0,
  193. 'max_tokens': 100
  194. },
  195. stream=False,
  196. user="abc-123"
  197. )
  198. assert isinstance(result, LLMResult)
  199. def test_fine_tuned_chat_model():
  200. model = CohereLargeLanguageModel()
  201. # test invoke
  202. result = model.invoke(
  203. model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
  204. credentials={
  205. 'api_key': os.environ.get('COHERE_API_KEY'),
  206. 'mode': 'chat'
  207. },
  208. prompt_messages=[
  209. SystemPromptMessage(
  210. content='You are a helpful AI assistant.',
  211. ),
  212. UserPromptMessage(
  213. content='Hello World!'
  214. )
  215. ],
  216. model_parameters={
  217. 'temperature': 0.0,
  218. 'max_tokens': 100
  219. },
  220. stream=False,
  221. user="abc-123"
  222. )
  223. assert isinstance(result, LLMResult)