test_llm.py

import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
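

# Integration tests for the HuggingFace Hub LLM provider, covering the hosted
# inference API and dedicated inference endpoints (text-generation and
# text2text-generation). They read HUGGINGFACE_API_KEY,
# HUGGINGFACE_TEXT_GEN_ENDPOINT_URL and HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL
# from the environment; an invalid token must raise CredentialsValidateFailedError,
# while valid credentials must validate without error.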
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="HuggingFaceH4/zephyr-7b-beta",
            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
        )

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="fake-model",
            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
        )

    model.validate_credentials(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
    )
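

# Blocking (stream=False) invocation against the hosted inference API should
# return a single LLMResult with non-empty assistant content.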
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
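

# Streaming (stream=True) invocation should return a Generator of LLMResultChunk;
# every chunk before the final one (finish_reason is None) must carry content.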
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
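

# Same validation flow against a dedicated text-generation inference endpoint:
# a bad token must fail, valid env credentials must pass.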
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="openchat/openchat_3.5",
            credentials={
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
                "task_type": "text-generation",
            },
        )

    model.validate_credentials(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
    )
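

# Blocking invocation against the text-generation endpoint returns one LLMResult.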
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
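

# Streaming invocation against the text-generation endpoint yields chunks.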
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
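

# Same validation flow for a text2text-generation (seq2seq) endpoint.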
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="google/mt5-base",
            credentials={
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
                "task_type": "text2text-generation",
            },
        )

    model.validate_credentials(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
    )
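

# Blocking invocation against the text2text-generation endpoint.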
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
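

# Streaming invocation against the text2text-generation endpoint.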
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
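

# Token counting: "Hello World!" should be counted as exactly 7 tokens.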
def test_get_num_tokens():
    model = HuggingfaceHubLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 7