test_llm.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. ImagePromptMessageContent,
  8. PromptMessageTool,
  9. SystemPromptMessage,
  10. TextPromptMessageContent,
  11. UserPromptMessage,
  12. )
  13. from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
  14. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  15. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  16. from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
  17. """FOR MOCK FIXTURES, DO NOT REMOVE"""
  18. from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  19. def test_predefined_models():
  20. model = OpenAILargeLanguageModel()
  21. model_schemas = model.predefined_models()
  22. assert len(model_schemas) >= 1
  23. assert isinstance(model_schemas[0], AIModelEntity)
  24. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  25. def test_validate_credentials_for_chat_model(setup_openai_mock):
  26. model = OpenAILargeLanguageModel()
  27. with pytest.raises(CredentialsValidateFailedError):
  28. model.validate_credentials(
  29. model='gpt-3.5-turbo',
  30. credentials={
  31. 'openai_api_key': 'invalid_key'
  32. }
  33. )
  34. model.validate_credentials(
  35. model='gpt-3.5-turbo',
  36. credentials={
  37. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  38. }
  39. )
  40. @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
  41. def test_validate_credentials_for_completion_model(setup_openai_mock):
  42. model = OpenAILargeLanguageModel()
  43. with pytest.raises(CredentialsValidateFailedError):
  44. model.validate_credentials(
  45. model='text-davinci-003',
  46. credentials={
  47. 'openai_api_key': 'invalid_key'
  48. }
  49. )
  50. model.validate_credentials(
  51. model='text-davinci-003',
  52. credentials={
  53. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  54. }
  55. )
  56. @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
  57. def test_invoke_completion_model(setup_openai_mock):
  58. model = OpenAILargeLanguageModel()
  59. result = model.invoke(
  60. model='gpt-3.5-turbo-instruct',
  61. credentials={
  62. 'openai_api_key': os.environ.get('OPENAI_API_KEY'),
  63. 'openai_api_base': 'https://api.openai.com'
  64. },
  65. prompt_messages=[
  66. UserPromptMessage(
  67. content='Hello World!'
  68. )
  69. ],
  70. model_parameters={
  71. 'temperature': 0.0,
  72. 'max_tokens': 1
  73. },
  74. stream=False,
  75. user="abc-123"
  76. )
  77. assert isinstance(result, LLMResult)
  78. assert len(result.message.content) > 0
  79. assert model._num_tokens_from_string('gpt-3.5-turbo-instruct', result.message.content) == 1
  80. @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
  81. def test_invoke_stream_completion_model(setup_openai_mock):
  82. model = OpenAILargeLanguageModel()
  83. result = model.invoke(
  84. model='gpt-3.5-turbo-instruct',
  85. credentials={
  86. 'openai_api_key': os.environ.get('OPENAI_API_KEY'),
  87. 'openai_organization': os.environ.get('OPENAI_ORGANIZATION'),
  88. },
  89. prompt_messages=[
  90. UserPromptMessage(
  91. content='Hello World!'
  92. )
  93. ],
  94. model_parameters={
  95. 'temperature': 0.0,
  96. 'max_tokens': 100
  97. },
  98. stream=True,
  99. user="abc-123"
  100. )
  101. assert isinstance(result, Generator)
  102. for chunk in result:
  103. assert isinstance(chunk, LLMResultChunk)
  104. assert isinstance(chunk.delta, LLMResultChunkDelta)
  105. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  106. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  107. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  108. def test_invoke_chat_model(setup_openai_mock):
  109. model = OpenAILargeLanguageModel()
  110. result = model.invoke(
  111. model='gpt-3.5-turbo',
  112. credentials={
  113. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  114. },
  115. prompt_messages=[
  116. SystemPromptMessage(
  117. content='You are a helpful AI assistant.',
  118. ),
  119. UserPromptMessage(
  120. content='Hello World!'
  121. )
  122. ],
  123. model_parameters={
  124. 'temperature': 0.0,
  125. 'top_p': 1.0,
  126. 'presence_penalty': 0.0,
  127. 'frequency_penalty': 0.0,
  128. 'max_tokens': 10
  129. },
  130. stop=['How'],
  131. stream=False,
  132. user="abc-123"
  133. )
  134. assert isinstance(result, LLMResult)
  135. assert len(result.message.content) > 0
  136. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  137. def test_invoke_chat_model_with_vision(setup_openai_mock):
  138. model = OpenAILargeLanguageModel()
  139. result = model.invoke(
  140. model='gpt-4-vision-preview',
  141. credentials={
  142. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  143. },
  144. prompt_messages=[
  145. SystemPromptMessage(
  146. content='You are a helpful AI assistant.',
  147. ),
  148. UserPromptMessage(
  149. content=[
  150. TextPromptMessageContent(
  151. data='Hello World!',
  152. ),
  153. ImagePromptMessageContent(
  154. data=''
  155. )
  156. ]
  157. )
  158. ],
  159. model_parameters={
  160. 'temperature': 0.0,
  161. 'max_tokens': 100
  162. },
  163. stream=False,
  164. user="abc-123"
  165. )
  166. assert isinstance(result, LLMResult)
  167. assert len(result.message.content) > 0
  168. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  169. def test_invoke_chat_model_with_tools(setup_openai_mock):
  170. model = OpenAILargeLanguageModel()
  171. result = model.invoke(
  172. model='gpt-3.5-turbo',
  173. credentials={
  174. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  175. },
  176. prompt_messages=[
  177. SystemPromptMessage(
  178. content='You are a helpful AI assistant.',
  179. ),
  180. UserPromptMessage(
  181. content="what's the weather today in London?",
  182. )
  183. ],
  184. model_parameters={
  185. 'temperature': 0.0,
  186. 'max_tokens': 100
  187. },
  188. tools=[
  189. PromptMessageTool(
  190. name='get_weather',
  191. description='Determine weather in my location',
  192. parameters={
  193. "type": "object",
  194. "properties": {
  195. "location": {
  196. "type": "string",
  197. "description": "The city and state e.g. San Francisco, CA"
  198. },
  199. "unit": {
  200. "type": "string",
  201. "enum": [
  202. "c",
  203. "f"
  204. ]
  205. }
  206. },
  207. "required": [
  208. "location"
  209. ]
  210. }
  211. ),
  212. PromptMessageTool(
  213. name='get_stock_price',
  214. description='Get the current stock price',
  215. parameters={
  216. "type": "object",
  217. "properties": {
  218. "symbol": {
  219. "type": "string",
  220. "description": "The stock symbol"
  221. }
  222. },
  223. "required": [
  224. "symbol"
  225. ]
  226. }
  227. )
  228. ],
  229. stream=False,
  230. user="abc-123"
  231. )
  232. assert isinstance(result, LLMResult)
  233. assert isinstance(result.message, AssistantPromptMessage)
  234. assert len(result.message.tool_calls) > 0
  235. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  236. def test_invoke_stream_chat_model(setup_openai_mock):
  237. model = OpenAILargeLanguageModel()
  238. result = model.invoke(
  239. model='gpt-3.5-turbo',
  240. credentials={
  241. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  242. },
  243. prompt_messages=[
  244. SystemPromptMessage(
  245. content='You are a helpful AI assistant.',
  246. ),
  247. UserPromptMessage(
  248. content='Hello World!'
  249. )
  250. ],
  251. model_parameters={
  252. 'temperature': 0.0,
  253. 'max_tokens': 100
  254. },
  255. stream=True,
  256. user="abc-123"
  257. )
  258. assert isinstance(result, Generator)
  259. for chunk in result:
  260. assert isinstance(chunk, LLMResultChunk)
  261. assert isinstance(chunk.delta, LLMResultChunkDelta)
  262. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  263. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  264. if chunk.delta.finish_reason is not None:
  265. assert chunk.delta.usage is not None
  266. assert chunk.delta.usage.completion_tokens > 0
  267. def test_get_num_tokens():
  268. model = OpenAILargeLanguageModel()
  269. num_tokens = model.get_num_tokens(
  270. model='gpt-3.5-turbo-instruct',
  271. credentials={
  272. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  273. },
  274. prompt_messages=[
  275. UserPromptMessage(
  276. content='Hello World!'
  277. )
  278. ]
  279. )
  280. assert num_tokens == 3
  281. num_tokens = model.get_num_tokens(
  282. model='gpt-3.5-turbo',
  283. credentials={
  284. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  285. },
  286. prompt_messages=[
  287. SystemPromptMessage(
  288. content='You are a helpful AI assistant.',
  289. ),
  290. UserPromptMessage(
  291. content='Hello World!'
  292. )
  293. ],
  294. tools=[
  295. PromptMessageTool(
  296. name='get_weather',
  297. description='Determine weather in my location',
  298. parameters={
  299. "type": "object",
  300. "properties": {
  301. "location": {
  302. "type": "string",
  303. "description": "The city and state e.g. San Francisco, CA"
  304. },
  305. "unit": {
  306. "type": "string",
  307. "enum": [
  308. "c",
  309. "f"
  310. ]
  311. }
  312. },
  313. "required": [
  314. "location"
  315. ]
  316. }
  317. ),
  318. ]
  319. )
  320. assert num_tokens == 72
  321. @pytest.mark.parametrize('setup_openai_mock', [['chat', 'remote']], indirect=True)
  322. def test_fine_tuned_models(setup_openai_mock):
  323. model = OpenAILargeLanguageModel()
  324. remote_models = model.remote_models(credentials={
  325. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  326. })
  327. if not remote_models:
  328. assert isinstance(remote_models, list)
  329. else:
  330. assert isinstance(remote_models[0], AIModelEntity)
  331. for llm_model in remote_models:
  332. if llm_model.model_type == ModelType.LLM:
  333. break
  334. assert isinstance(llm_model, AIModelEntity)
  335. # test invoke
  336. result = model.invoke(
  337. model=llm_model.model,
  338. credentials={
  339. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  340. },
  341. prompt_messages=[
  342. SystemPromptMessage(
  343. content='You are a helpful AI assistant.',
  344. ),
  345. UserPromptMessage(
  346. content='Hello World!'
  347. )
  348. ],
  349. model_parameters={
  350. 'temperature': 0.0,
  351. 'max_tokens': 100
  352. },
  353. stream=False,
  354. user="abc-123"
  355. )
  356. assert isinstance(result, LLMResult)
  357. def test__get_num_tokens_by_gpt2():
  358. model = OpenAILargeLanguageModel()
  359. num_tokens = model._get_num_tokens_by_gpt2('Hello World!')
  360. assert num_tokens == 3