# test_llm.py — integration tests for the workflow LLM node (mocked OpenAI chat API)
  1. import json
  2. import os
  3. from unittest.mock import MagicMock
  4. import pytest
  5. from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
  6. from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
  7. from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
  8. from core.model_manager import ModelInstance
  9. from core.model_runtime.entities.model_entities import ModelType
  10. from core.model_runtime.model_providers import ModelProviderFactory
  11. from core.workflow.entities.node_entities import SystemVariable
  12. from core.workflow.entities.variable_pool import VariablePool
  13. from core.workflow.nodes.base_node import UserFrom
  14. from core.workflow.nodes.llm.llm_node import LLMNode
  15. from extensions.ext_database import db
  16. from models.provider import ProviderType
  17. from models.workflow import WorkflowNodeExecutionStatus
  18. """FOR MOCK FIXTURES, DO NOT REMOVE"""
  19. from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  20. from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
  21. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  22. def test_execute_llm(setup_openai_mock):
  23. node = LLMNode(
  24. tenant_id='1',
  25. app_id='1',
  26. workflow_id='1',
  27. user_id='1',
  28. invoke_from=InvokeFrom.WEB_APP,
  29. user_from=UserFrom.ACCOUNT,
  30. config={
  31. 'id': 'llm',
  32. 'data': {
  33. 'title': '123',
  34. 'type': 'llm',
  35. 'model': {
  36. 'provider': 'openai',
  37. 'name': 'gpt-3.5-turbo',
  38. 'mode': 'chat',
  39. 'completion_params': {}
  40. },
  41. 'prompt_template': [
  42. {
  43. 'role': 'system',
  44. 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.'
  45. },
  46. {
  47. 'role': 'user',
  48. 'text': '{{#sys.query#}}'
  49. }
  50. ],
  51. 'memory': None,
  52. 'context': {
  53. 'enabled': False
  54. },
  55. 'vision': {
  56. 'enabled': False
  57. }
  58. }
  59. }
  60. )
  61. # construct variable pool
  62. pool = VariablePool(system_variables={
  63. SystemVariable.QUERY: 'what\'s the weather today?',
  64. SystemVariable.FILES: [],
  65. SystemVariable.CONVERSATION_ID: 'abababa',
  66. SystemVariable.USER_ID: 'aaa'
  67. }, user_inputs={})
  68. pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
  69. credentials = {
  70. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  71. }
  72. provider_instance = ModelProviderFactory().get_provider_instance('openai')
  73. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  74. provider_model_bundle = ProviderModelBundle(
  75. configuration=ProviderConfiguration(
  76. tenant_id='1',
  77. provider=provider_instance.get_provider_schema(),
  78. preferred_provider_type=ProviderType.CUSTOM,
  79. using_provider_type=ProviderType.CUSTOM,
  80. system_configuration=SystemConfiguration(
  81. enabled=False
  82. ),
  83. custom_configuration=CustomConfiguration(
  84. provider=CustomProviderConfiguration(
  85. credentials=credentials
  86. )
  87. ),
  88. model_settings=[]
  89. ),
  90. provider_instance=provider_instance,
  91. model_type_instance=model_type_instance
  92. )
  93. model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
  94. model_config = ModelConfigWithCredentialsEntity(
  95. model='gpt-3.5-turbo',
  96. provider='openai',
  97. mode='chat',
  98. credentials=credentials,
  99. parameters={},
  100. model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
  101. provider_model_bundle=provider_model_bundle
  102. )
  103. # Mock db.session.close()
  104. db.session.close = MagicMock()
  105. node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
  106. # execute node
  107. result = node.run(pool)
  108. assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  109. assert result.outputs['text'] is not None
  110. assert result.outputs['usage']['total_tokens'] > 0
  111. @pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
  112. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  113. def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
  114. """
  115. Test execute LLM node with jinja2
  116. """
  117. node = LLMNode(
  118. tenant_id='1',
  119. app_id='1',
  120. workflow_id='1',
  121. user_id='1',
  122. invoke_from=InvokeFrom.WEB_APP,
  123. user_from=UserFrom.ACCOUNT,
  124. config={
  125. 'id': 'llm',
  126. 'data': {
  127. 'title': '123',
  128. 'type': 'llm',
  129. 'model': {
  130. 'provider': 'openai',
  131. 'name': 'gpt-3.5-turbo',
  132. 'mode': 'chat',
  133. 'completion_params': {}
  134. },
  135. 'prompt_config': {
  136. 'jinja2_variables': [{
  137. 'variable': 'sys_query',
  138. 'value_selector': ['sys', 'query']
  139. }, {
  140. 'variable': 'output',
  141. 'value_selector': ['abc', 'output']
  142. }]
  143. },
  144. 'prompt_template': [
  145. {
  146. 'role': 'system',
  147. 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
  148. 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
  149. 'edition_type': 'jinja2'
  150. },
  151. {
  152. 'role': 'user',
  153. 'text': '{{#sys.query#}}',
  154. 'jinja2_text': '{{sys_query}}',
  155. 'edition_type': 'basic'
  156. }
  157. ],
  158. 'memory': None,
  159. 'context': {
  160. 'enabled': False
  161. },
  162. 'vision': {
  163. 'enabled': False
  164. }
  165. }
  166. }
  167. )
  168. # construct variable pool
  169. pool = VariablePool(system_variables={
  170. SystemVariable.QUERY: 'what\'s the weather today?',
  171. SystemVariable.FILES: [],
  172. SystemVariable.CONVERSATION_ID: 'abababa',
  173. SystemVariable.USER_ID: 'aaa'
  174. }, user_inputs={})
  175. pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
  176. credentials = {
  177. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  178. }
  179. provider_instance = ModelProviderFactory().get_provider_instance('openai')
  180. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  181. provider_model_bundle = ProviderModelBundle(
  182. configuration=ProviderConfiguration(
  183. tenant_id='1',
  184. provider=provider_instance.get_provider_schema(),
  185. preferred_provider_type=ProviderType.CUSTOM,
  186. using_provider_type=ProviderType.CUSTOM,
  187. system_configuration=SystemConfiguration(
  188. enabled=False
  189. ),
  190. custom_configuration=CustomConfiguration(
  191. provider=CustomProviderConfiguration(
  192. credentials=credentials
  193. )
  194. ),
  195. model_settings=[]
  196. ),
  197. provider_instance=provider_instance,
  198. model_type_instance=model_type_instance,
  199. )
  200. model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
  201. model_config = ModelConfigWithCredentialsEntity(
  202. model='gpt-3.5-turbo',
  203. provider='openai',
  204. mode='chat',
  205. credentials=credentials,
  206. parameters={},
  207. model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
  208. provider_model_bundle=provider_model_bundle
  209. )
  210. # Mock db.session.close()
  211. db.session.close = MagicMock()
  212. node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
  213. # execute node
  214. result = node.run(pool)
  215. assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  216. assert 'sunny' in json.dumps(result.process_data)
  217. assert 'what\'s the weather today?' in json.dumps(result.process_data)