import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.xinference.llm.llm import XinferenceAILargeLanguageModel

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock
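
# These tests drive the Xinference provider through the mocked
# OpenAI-compatible endpoints: the `setup_openai_mock` / `setup_xinference_mock`
# fixtures are parametrized indirectly per test to patch chat or completion
# calls. The target server and model UIDs come from environment variables,
# e.g. (hypothetical values):
#
#   export XINFERENCE_SERVER_URL=http://localhost:9997
#   export XINFERENCE_CHAT_MODEL_UID=chatglm3
#   export XINFERENCE_GENERATION_MODEL_UID=alpaca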


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock):
    model = XinferenceAILargeLanguageModel()

    # a malformed model_uid must be rejected
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='ChatGLM3',
            credentials={
                'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
                'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID')
            }
        )

    # empty credentials must be rejected
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='aaaaa',
            credentials={
                'server_url': '',
                'model_uid': ''
            }
        )

    # valid credentials must pass without raising
    model.validate_credentials(
        model='ChatGLM3',
        credentials={
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
        }
    )


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock):
    model = XinferenceAILargeLanguageModel()

    response = model.invoke(
        model='ChatGLM3',
        credentials={
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        user='abc-123',
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
    model = XinferenceAILargeLanguageModel()

    response = model.invoke(
        model='ChatGLM3',
        credentials={
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        stream=True,
        user='abc-123'
    )

    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # every chunk except the final one must carry content
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
  103. """
  104. Funtion calling of xinference does not support stream mode currently
  105. """
# def test_invoke_stream_chat_model_with_functions():
#     model = XinferenceAILargeLanguageModel()
#
#     response = model.invoke(
#         model='ChatGLM3-6b',
#         credentials={
#             'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
#             'model_type': 'text-generation',
#             'model_name': 'ChatGLM3',
#             'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
#         },
#         prompt_messages=[
#             SystemPromptMessage(
#                 content='You are a weather bot; you can look up weather information by calling functions',
#             ),
#             UserPromptMessage(
#                 content='What is the weather like in Boston?'
#             )
#         ],
#         model_parameters={
#             'temperature': 0,
#             'top_p': 1.0,
#         },
#         stop=['you'],
#         user='abc-123',
#         stream=True,
#         tools=[
#             PromptMessageTool(
#                 name='get_current_weather',
#                 description='Get the current weather in a given location',
#                 parameters={
#                     "type": "object",
#                     "properties": {
#                         "location": {
#                             "type": "string",
#                             "description": "The city and state e.g. San Francisco, CA"
#                         },
#                         "unit": {
#                             "type": "string",
#                             "enum": ["celsius", "fahrenheit"]
#                         }
#                     },
#                     "required": [
#                         "location"
#                     ]
#                 }
#             )
#         ]
#     )
#
#     assert isinstance(response, Generator)
#
#     call: LLMResultChunk = None
#     chunks = []
#
#     for chunk in response:
#         chunks.append(chunk)
#         assert isinstance(chunk, LLMResultChunk)
#         assert isinstance(chunk.delta, LLMResultChunkDelta)
#         assert isinstance(chunk.delta.message, AssistantPromptMessage)
#         if chunk.delta.finish_reason is None:
#             assert len(chunk.delta.message.content) > 0
#
#         if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0:
#             call = chunk
#             break
#
#     assert call is not None
#     assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'


# def test_invoke_chat_model_with_functions():
#     model = XinferenceAILargeLanguageModel()
#
#     response = model.invoke(
#         model='ChatGLM3-6b',
#         credentials={
#             'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
#             'model_type': 'text-generation',
#             'model_name': 'ChatGLM3',
#             'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
#         },
#         prompt_messages=[
#             UserPromptMessage(
#                 content='What is the weather like in San Francisco?'
#             )
#         ],
#         model_parameters={
#             'temperature': 0.7,
#             'top_p': 1.0,
#         },
#         stop=['you'],
#         user='abc-123',
#         stream=False,
#         tools=[
#             PromptMessageTool(
#                 name='get_current_weather',
#                 description='Get the current weather in a given location',
#                 parameters={
#                     "type": "object",
#                     "properties": {
#                         "location": {
#                             "type": "string",
#                             "description": "The city and state e.g. San Francisco, CA"
#                         },
#                         "unit": {
#                             "type": "string",
#                             "enum": ["c", "f"]
#                         }
#                     },
#                     "required": [
#                         "location"
#                     ]
#                 }
#             )
#         ]
#     )
#
#     assert isinstance(response, LLMResult)
#     assert len(response.message.content) > 0
#     assert response.usage.total_tokens > 0
#     assert response.message.tool_calls[0].function.name == 'get_current_weather'


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock):
    model = XinferenceAILargeLanguageModel()

    # a malformed model_uid must be rejected
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='alapaca',
            credentials={
                'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
                'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
            }
        )

    # empty credentials must be rejected
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='alapaca',
            credentials={
                'server_url': '',
                'model_uid': ''
            }
        )

    # valid credentials must pass without raising
    model.validate_credentials(
        model='alapaca',
        credentials={
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
        }
    )


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock):
    model = XinferenceAILargeLanguageModel()

    response = model.invoke(
        model='alapaca',
        credentials={
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
        },
        prompt_messages=[
            UserPromptMessage(
                content='the United States is'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        user='abc-123',
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock):
    model = XinferenceAILargeLanguageModel()

    response = model.invoke(
        model='alapaca',
        credentials={
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
        },
        prompt_messages=[
            UserPromptMessage(
                content='the United States is'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
        },
        stop=['you'],
        stream=True,
        user='abc-123'
    )

    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # every chunk except the final one must carry content
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0


def test_get_num_tokens():
    model = XinferenceAILargeLanguageModel()

    # with a tool attached, its definition contributes to the token count
    num_tokens = model.get_num_tokens(
        model='ChatGLM3',
        credentials={
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["c", "f"]
                        }
                    },
                    "required": [
                        "location"
                    ]
                }
            )
        ]
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 77

    # the same prompt without tools yields a smaller count
    num_tokens = model.get_num_tokens(
        model='ChatGLM3',
        credentials={
            'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
            'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 21
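
# To run this file directly (assuming the environment variables above are set
# and the repository root is on PYTHONPATH; the path is inferred from the mock
# imports and may differ in your checkout):
#
#   pytest tests/integration_tests/model_runtime/xinference/test_llm.py -v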