test_llm.py
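# Integration tests for the Wenxin (ERNIE Bot) LLM provider.
#
# NOTE (not stated in the code itself, so treat as an assumption): these tests hit the
# live Wenxin API with credentials read from the WENXIN_API_KEY and WENXIN_SECRET_KEY
# environment variables, and each test starts with sleep(3), presumably to space out
# requests and avoid rate limiting.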

import os
from collections.abc import Generator
from time import sleep

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel


def test_predefined_models():
    model = ErnieBotLargeLanguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


def test_validate_credentials_for_chat_model():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='ernie-bot',
            credentials={
                'api_key': 'invalid_key',
                'secret_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='ernie-bot',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        }
    )


def test_invoke_model_ernie_bot():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    response = model.invoke(
        model='ernie-bot',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_model_ernie_bot_turbo():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    response = model.invoke(
        model='ernie-bot-turbo',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_model_ernie_8k():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    response = model.invoke(
        model='ernie-bot-8k',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_model_ernie_bot_4():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    response = model.invoke(
        model='ernie-bot-4',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_stream_model():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    response = model.invoke(
        model='ernie-3.5-8k',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_model_with_system():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    response = model.invoke(
        model='ernie-bot',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='你是Kasumi'  # "You are Kasumi"
            ),
            UserPromptMessage(
                content='你是谁?'  # "Who are you?"
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(response, LLMResult)
    assert 'kasumi' in response.message.content.lower()


def test_invoke_with_search():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    response = model.invoke(
        model='ernie-bot',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='北京今天的天气怎么样'  # "What's the weather like in Beijing today?"
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'disable_search': True,
        },
        stop=[],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    total_message = ''
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        total_message += chunk.delta.message.content
        print(chunk.delta.message.content)
        assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True

    # with search disabled, the reply should be a refusal containing something like
    # 对不起 ("sorry"), 我不能 ("I can't"), or 不支持 ("not supported")
    assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message)


def test_get_num_tokens():
    sleep(3)
    model = ErnieBotLargeLanguageModel()

    response = model.get_num_tokens(
        model='ernie-bot',
        credentials={
            'api_key': os.environ.get('WENXIN_API_KEY'),
            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        tools=[]
    )

    assert isinstance(response, int)
    assert response == 10
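

# A minimal sketch of how these tests might be run locally (the exact path and pytest
# invocation depend on your project layout, so treat this as an assumption):
#
#   export WENXIN_API_KEY='your-api-key'
#   export WENXIN_SECRET_KEY='your-secret-key'
#   pytest test_llm.py -v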