# test_llm.py
import json
import os
from unittest.mock import MagicMock

import pytest

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
from models.provider import ProviderType
from models.workflow import WorkflowNodeExecutionStatus

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
  21. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  22. def test_execute_llm(setup_openai_mock):
  23. node = LLMNode(
  24. tenant_id='1',
  25. app_id='1',
  26. workflow_id='1',
  27. user_id='1',
  28. user_from=UserFrom.ACCOUNT,
  29. config={
  30. 'id': 'llm',
  31. 'data': {
  32. 'title': '123',
  33. 'type': 'llm',
  34. 'model': {
  35. 'provider': 'openai',
  36. 'name': 'gpt-3.5-turbo',
  37. 'mode': 'chat',
  38. 'completion_params': {}
  39. },
  40. 'prompt_template': [
  41. {
  42. 'role': 'system',
  43. 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.'
  44. },
  45. {
  46. 'role': 'user',
  47. 'text': '{{#sys.query#}}'
  48. }
  49. ],
  50. 'memory': None,
  51. 'context': {
  52. 'enabled': False
  53. },
  54. 'vision': {
  55. 'enabled': False
  56. }
  57. }
  58. }
  59. )
  60. # construct variable pool
  61. pool = VariablePool(system_variables={
  62. SystemVariable.QUERY: 'what\'s the weather today?',
  63. SystemVariable.FILES: [],
  64. SystemVariable.CONVERSATION_ID: 'abababa',
  65. SystemVariable.USER_ID: 'aaa'
  66. }, user_inputs={})
  67. pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
  68. credentials = {
  69. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  70. }
  71. provider_instance = ModelProviderFactory().get_provider_instance('openai')
  72. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  73. provider_model_bundle = ProviderModelBundle(
  74. configuration=ProviderConfiguration(
  75. tenant_id='1',
  76. provider=provider_instance.get_provider_schema(),
  77. preferred_provider_type=ProviderType.CUSTOM,
  78. using_provider_type=ProviderType.CUSTOM,
  79. system_configuration=SystemConfiguration(
  80. enabled=False
  81. ),
  82. custom_configuration=CustomConfiguration(
  83. provider=CustomProviderConfiguration(
  84. credentials=credentials
  85. )
  86. )
  87. ),
  88. provider_instance=provider_instance,
  89. model_type_instance=model_type_instance
  90. )
  91. model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
  92. model_config = ModelConfigWithCredentialsEntity(
  93. model='gpt-3.5-turbo',
  94. provider='openai',
  95. mode='chat',
  96. credentials=credentials,
  97. parameters={},
  98. model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
  99. provider_model_bundle=provider_model_bundle
  100. )
  101. # Mock db.session.close()
  102. db.session.close = MagicMock()
  103. node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
  104. # execute node
  105. result = node.run(pool)
  106. assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  107. assert result.outputs['text'] is not None
  108. assert result.outputs['usage']['total_tokens'] > 0
  109. @pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
  110. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  111. def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
  112. """
  113. Test execute LLM node with jinja2
  114. """
  115. node = LLMNode(
  116. tenant_id='1',
  117. app_id='1',
  118. workflow_id='1',
  119. user_id='1',
  120. user_from=UserFrom.ACCOUNT,
  121. config={
  122. 'id': 'llm',
  123. 'data': {
  124. 'title': '123',
  125. 'type': 'llm',
  126. 'model': {
  127. 'provider': 'openai',
  128. 'name': 'gpt-3.5-turbo',
  129. 'mode': 'chat',
  130. 'completion_params': {}
  131. },
  132. 'prompt_config': {
  133. 'jinja2_variables': [{
  134. 'variable': 'sys_query',
  135. 'value_selector': ['sys', 'query']
  136. }, {
  137. 'variable': 'output',
  138. 'value_selector': ['abc', 'output']
  139. }]
  140. },
  141. 'prompt_template': [
  142. {
  143. 'role': 'system',
  144. 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
  145. 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
  146. 'edition_type': 'jinja2'
  147. },
  148. {
  149. 'role': 'user',
  150. 'text': '{{#sys.query#}}',
  151. 'jinja2_text': '{{sys_query}}',
  152. 'edition_type': 'basic'
  153. }
  154. ],
  155. 'memory': None,
  156. 'context': {
  157. 'enabled': False
  158. },
  159. 'vision': {
  160. 'enabled': False
  161. }
  162. }
  163. }
  164. )
  165. # construct variable pool
  166. pool = VariablePool(system_variables={
  167. SystemVariable.QUERY: 'what\'s the weather today?',
  168. SystemVariable.FILES: [],
  169. SystemVariable.CONVERSATION_ID: 'abababa',
  170. SystemVariable.USER_ID: 'aaa'
  171. }, user_inputs={})
  172. pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
  173. credentials = {
  174. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  175. }
  176. provider_instance = ModelProviderFactory().get_provider_instance('openai')
  177. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  178. provider_model_bundle = ProviderModelBundle(
  179. configuration=ProviderConfiguration(
  180. tenant_id='1',
  181. provider=provider_instance.get_provider_schema(),
  182. preferred_provider_type=ProviderType.CUSTOM,
  183. using_provider_type=ProviderType.CUSTOM,
  184. system_configuration=SystemConfiguration(
  185. enabled=False
  186. ),
  187. custom_configuration=CustomConfiguration(
  188. provider=CustomProviderConfiguration(
  189. credentials=credentials
  190. )
  191. )
  192. ),
  193. provider_instance=provider_instance,
  194. model_type_instance=model_type_instance
  195. )
  196. model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
  197. model_config = ModelConfigWithCredentialsEntity(
  198. model='gpt-3.5-turbo',
  199. provider='openai',
  200. mode='chat',
  201. credentials=credentials,
  202. parameters={},
  203. model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
  204. provider_model_bundle=provider_model_bundle
  205. )
  206. # Mock db.session.close()
  207. db.session.close = MagicMock()
  208. node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
  209. # execute node
  210. result = node.run(pool)
  211. assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  212. assert 'sunny' in json.dumps(result.process_data)
  213. assert 'what\'s the weather today?' in json.dumps(result.process_data)