# Integration tests for the ChatGLM large-language-model provider.
import os
from typing import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  12. def test_predefined_models():
  13. model = ChatGLMLargeLanguageModel()
  14. model_schemas = model.predefined_models()
  15. assert len(model_schemas) >= 1
  16. assert isinstance(model_schemas[0], AIModelEntity)
  17. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  18. def test_validate_credentials_for_chat_model(setup_openai_mock):
  19. model = ChatGLMLargeLanguageModel()
  20. with pytest.raises(CredentialsValidateFailedError):
  21. model.validate_credentials(
  22. model='chatglm2-6b',
  23. credentials={
  24. 'api_base': 'invalid_key'
  25. }
  26. )
  27. model.validate_credentials(
  28. model='chatglm2-6b',
  29. credentials={
  30. 'api_base': os.environ.get('CHATGLM_API_BASE')
  31. }
  32. )
  33. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  34. def test_invoke_model(setup_openai_mock):
  35. model = ChatGLMLargeLanguageModel()
  36. response = model.invoke(
  37. model='chatglm2-6b',
  38. credentials={
  39. 'api_base': os.environ.get('CHATGLM_API_BASE')
  40. },
  41. prompt_messages=[
  42. SystemPromptMessage(
  43. content='You are a helpful AI assistant.',
  44. ),
  45. UserPromptMessage(
  46. content='Hello World!'
  47. )
  48. ],
  49. model_parameters={
  50. 'temperature': 0.7,
  51. 'top_p': 1.0,
  52. },
  53. stop=['you'],
  54. user="abc-123",
  55. stream=False
  56. )
  57. assert isinstance(response, LLMResult)
  58. assert len(response.message.content) > 0
  59. assert response.usage.total_tokens > 0
  60. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  61. def test_invoke_stream_model(setup_openai_mock):
  62. model = ChatGLMLargeLanguageModel()
  63. response = model.invoke(
  64. model='chatglm2-6b',
  65. credentials={
  66. 'api_base': os.environ.get('CHATGLM_API_BASE')
  67. },
  68. prompt_messages=[
  69. SystemPromptMessage(
  70. content='You are a helpful AI assistant.',
  71. ),
  72. UserPromptMessage(
  73. content='Hello World!'
  74. )
  75. ],
  76. model_parameters={
  77. 'temperature': 0.7,
  78. 'top_p': 1.0,
  79. },
  80. stop=['you'],
  81. stream=True,
  82. user="abc-123"
  83. )
  84. assert isinstance(response, Generator)
  85. for chunk in response:
  86. assert isinstance(chunk, LLMResultChunk)
  87. assert isinstance(chunk.delta, LLMResultChunkDelta)
  88. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  89. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  90. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  91. def test_invoke_stream_model_with_functions(setup_openai_mock):
  92. model = ChatGLMLargeLanguageModel()
  93. response = model.invoke(
  94. model='chatglm3-6b',
  95. credentials={
  96. 'api_base': os.environ.get('CHATGLM_API_BASE')
  97. },
  98. prompt_messages=[
  99. SystemPromptMessage(
  100. content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。'
  101. ),
  102. UserPromptMessage(
  103. content='波士顿天气如何?'
  104. )
  105. ],
  106. model_parameters={
  107. 'temperature': 0,
  108. 'top_p': 1.0,
  109. },
  110. stop=['you'],
  111. user='abc-123',
  112. stream=True,
  113. tools=[
  114. PromptMessageTool(
  115. name='get_current_weather',
  116. description='Get the current weather in a given location',
  117. parameters={
  118. "type": "object",
  119. "properties": {
  120. "location": {
  121. "type": "string",
  122. "description": "The city and state e.g. San Francisco, CA"
  123. },
  124. "unit": {
  125. "type": "string",
  126. "enum": ["celsius", "fahrenheit"]
  127. }
  128. },
  129. "required": [
  130. "location"
  131. ]
  132. }
  133. )
  134. ]
  135. )
  136. assert isinstance(response, Generator)
  137. call: LLMResultChunk = None
  138. chunks = []
  139. for chunk in response:
  140. chunks.append(chunk)
  141. assert isinstance(chunk, LLMResultChunk)
  142. assert isinstance(chunk.delta, LLMResultChunkDelta)
  143. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  144. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  145. if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0:
  146. call = chunk
  147. break
  148. assert call is not None
  149. assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
  150. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  151. def test_invoke_model_with_functions(setup_openai_mock):
  152. model = ChatGLMLargeLanguageModel()
  153. response = model.invoke(
  154. model='chatglm3-6b',
  155. credentials={
  156. 'api_base': os.environ.get('CHATGLM_API_BASE')
  157. },
  158. prompt_messages=[
  159. UserPromptMessage(
  160. content='What is the weather like in San Francisco?'
  161. )
  162. ],
  163. model_parameters={
  164. 'temperature': 0.7,
  165. 'top_p': 1.0,
  166. },
  167. stop=['you'],
  168. user='abc-123',
  169. stream=False,
  170. tools=[
  171. PromptMessageTool(
  172. name='get_current_weather',
  173. description='Get the current weather in a given location',
  174. parameters={
  175. "type": "object",
  176. "properties": {
  177. "location": {
  178. "type": "string",
  179. "description": "The city and state e.g. San Francisco, CA"
  180. },
  181. "unit": {
  182. "type": "string",
  183. "enum": [
  184. "c",
  185. "f"
  186. ]
  187. }
  188. },
  189. "required": [
  190. "location"
  191. ]
  192. }
  193. )
  194. ]
  195. )
  196. assert isinstance(response, LLMResult)
  197. assert len(response.message.content) > 0
  198. assert response.usage.total_tokens > 0
  199. assert response.message.tool_calls[0].function.name == 'get_current_weather'
  200. def test_get_num_tokens():
  201. model = ChatGLMLargeLanguageModel()
  202. num_tokens = model.get_num_tokens(
  203. model='chatglm2-6b',
  204. credentials={
  205. 'api_base': os.environ.get('CHATGLM_API_BASE')
  206. },
  207. prompt_messages=[
  208. SystemPromptMessage(
  209. content='You are a helpful AI assistant.',
  210. ),
  211. UserPromptMessage(
  212. content='Hello World!'
  213. )
  214. ],
  215. tools=[
  216. PromptMessageTool(
  217. name='get_current_weather',
  218. description='Get the current weather in a given location',
  219. parameters={
  220. "type": "object",
  221. "properties": {
  222. "location": {
  223. "type": "string",
  224. "description": "The city and state e.g. San Francisco, CA"
  225. },
  226. "unit": {
  227. "type": "string",
  228. "enum": [
  229. "c",
  230. "f"
  231. ]
  232. }
  233. },
  234. "required": [
  235. "location"
  236. ]
  237. }
  238. )
  239. ]
  240. )
  241. assert isinstance(num_tokens, int)
  242. assert num_tokens == 77
  243. num_tokens = model.get_num_tokens(
  244. model='chatglm2-6b',
  245. credentials={
  246. 'api_base': os.environ.get('CHATGLM_API_BASE')
  247. },
  248. prompt_messages=[
  249. SystemPromptMessage(
  250. content='You are a helpful AI assistant.',
  251. ),
  252. UserPromptMessage(
  253. content='Hello World!'
  254. )
  255. ],
  256. )
  257. assert isinstance(num_tokens, int)
  258. assert num_tokens == 21