test_llm.py

import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
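
# Integration tests for HuggingfaceHubLargeLanguageModel. They cover the
# hosted Inference API as well as dedicated Inference Endpoints with the
# "text-generation" and "text2text-generation" task types. Credentials and
# endpoint URLs are read from the HUGGINGFACE_API_KEY,
# HUGGINGFACE_TEXT_GEN_ENDPOINT_URL, and HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL
# environment variables, so these must be set when running the tests against
# the live services.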


@pytest.mark.skip
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="HuggingFaceH4/zephyr-7b-beta",
            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
        )

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="fake-model",
            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
        )

    model.validate_credentials(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
    )
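

# With stream=False, invoke() performs a blocking call and returns a single
# LLMResult whose message content is the full completion.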
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
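

# With stream=True, invoke() returns a generator of LLMResultChunk objects
# instead of a single LLMResult. Every chunk carries an assistant message
# delta; only a chunk with a finish_reason set may have empty content.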
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
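

# --- Dedicated Inference Endpoints, task_type "text-generation" ---
# These tests point the provider at a custom endpoint URL instead of the
# hosted Inference API; the URL comes from HUGGINGFACE_TEXT_GEN_ENDPOINT_URL.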
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="openchat/openchat_3.5",
            credentials={
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
                "task_type": "text-generation",
            },
        )

    model.validate_credentials(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
    )


@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="openchat/openchat_3.5",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
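

# --- Dedicated Inference Endpoints, task_type "text2text-generation" ---
# Same flow as above, but for an encoder-decoder (seq2seq) model such as
# google/mt5-base, served from HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL.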
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="google/mt5-base",
            credentials={
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
                "task_type": "text2text-generation",
            },
        )

    model.validate_credentials(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
    )


@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
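

# Token counting: the provider's tokenizer is expected to count 7 tokens for
# the prompt "Hello World!".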
def test_get_num_tokens():
    model = HuggingfaceHubLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="google/mt5-base",
        credentials={
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 7