import os
from typing import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock

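# These integration tests call the HuggingFace Hub hosted inference API and
# dedicated inference endpoints, so they expect the following environment
# variables to be set:
#   HUGGINGFACE_API_KEY                    - HuggingFace Hub API token
#   HUGGINGFACE_TEXT_GEN_ENDPOINT_URL      - inference endpoint serving a text-generation model
#   HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL - inference endpoint serving a text2text-generation model
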
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='HuggingFaceH4/zephyr-7b-beta',
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key'
            }
        )

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='fake-model',
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='HuggingFaceH4/zephyr-7b-beta',
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
        }
    )

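# Blocking (stream=False) completion against the hosted inference API:
# invoke() should return a single LLMResult with non-empty content.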
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='HuggingFaceH4/zephyr-7b-beta',
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0

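# Streaming (stream=True) completion: invoke() should instead return a
# generator that yields LLMResultChunk objects.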
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='HuggingFaceH4/zephyr-7b-beta',
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Every chunk except the final one (which carries a finish_reason)
        # must contain a non-empty content delta.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0

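# The tests below target dedicated inference endpoints rather than the hosted
# inference API: the endpoint URL comes from the environment, and the
# 'task_type' credential indicates which task the endpoint serves.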
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='openchat/openchat_3.5',
            credentials={
                'huggingfacehub_api_type': 'inference_endpoints',
                'huggingfacehub_api_token': 'invalid_key',
                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
                'task_type': 'text-generation'
            }
        )

    model.validate_credentials(
        model='openchat/openchat_3.5',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
        }
    )

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='openchat/openchat_3.5',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='openchat/openchat_3.5',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only the final chunk (which carries a finish_reason) may be empty.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0

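# text2text-generation covers encoder-decoder models such as google/mt5-base,
# as opposed to the decoder-only text-generation task exercised above.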
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='google/mt5-base',
            credentials={
                'huggingfacehub_api_type': 'inference_endpoints',
                'huggingfacehub_api_token': 'invalid_key',
                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
                'task_type': 'text2text-generation'
            }
        )

    model.validate_credentials(
        model='google/mt5-base',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        }
    )

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='google/mt5-base',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=False,
        user="abc-123"
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='google/mt5-base',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Only the final chunk (which carries a finish_reason) may be empty.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0

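# Unlike the tests above, token counting needs no mock fixture or parametrize
# marker; 'Hello World!' is expected to count as 7 tokens for this provider.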
def test_get_num_tokens():
    model = HuggingfaceHubLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='google/mt5-base',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert num_tokens == 7