test_tool.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import time
  2. import uuid
  3. from core.app.entities.app_invoke_entities import InvokeFrom
  4. from core.workflow.entities.node_entities import NodeRunResult, UserFrom
  5. from core.workflow.entities.variable_pool import VariablePool
  6. from core.workflow.enums import SystemVariableKey
  7. from core.workflow.graph_engine.entities.graph import Graph
  8. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  9. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  10. from core.workflow.nodes.tool.tool_node import ToolNode
  11. from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
  def init_tool_node(config: dict) -> ToolNode:
      """Build a ``ToolNode`` inside a minimal ``start -> node`` workflow graph.

      The graph holds exactly two nodes: a ``start`` node and the node described
      by *config*, joined by a single edge. The node is created with an
      account/debugger invocation context. Its variable pool holds only the
      ``FILES`` (empty) and ``USER_ID`` system variables, so each test must seed
      whatever variables its tool parameters reference.

      Args:
          config: Node configuration dict, passed by the tests in this module
              with ``id`` and ``data`` keys. It is used both as the graph's node
              entry and as the ``ToolNode`` config.

      Returns:
          A ``ToolNode`` whose runtime state the caller can populate before
          invoking ``_run()``.
      """
      # Wire the start node to whatever id the config declares. A hard-coded
      # "1" would point the edge at a node absent from the graph for any other id.
      target_id = config.get("id", "1")
      graph_config = {
          "edges": [
              {
                  "id": "start-source-next-target",
                  "source": "start",
                  "target": target_id,
              },
          ],
          "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
      }
      graph = Graph.init(graph_config=graph_config)
      init_params = GraphInitParams(
          tenant_id="1",
          app_id="1",
          workflow_type=WorkflowType.WORKFLOW,
          workflow_id="1",
          graph_config=graph_config,
          user_id="1",
          user_from=UserFrom.ACCOUNT,
          invoke_from=InvokeFrom.DEBUGGER,
          call_depth=0,
      )
      # construct variable pool: only system variables are pre-populated; tests
      # add the variables their tool parameters reference after construction.
      variable_pool = VariablePool(
          system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
          user_inputs={},
          environment_variables=[],
          conversation_variables=[],
      )
      return ToolNode(
          # A fresh random instance id on every call.
          id=str(uuid.uuid4()),
          graph_init_params=init_params,
          graph=graph,
          graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
          config=config,
      )
  def test_tool_variable_invoke():
      """Evaluate ``1+1`` with the maths tool, reading the expression from a variable.

      The ``expression`` tool parameter is a ``variable`` reference to the pool
      selector ``["1", "123", "args1"]``. That selector is seeded with ``"1+1"``
      before the node runs. The run must succeed, mention ``2`` in its text
      output and produce no files.
      """
      selector = ("1", "123", "args1")
      expression_param = {"type": "variable", "value": list(selector)}
      node_data = {
          "title": "a",
          "desc": "a",
          "provider_id": "maths",
          "provider_type": "builtin",
          "provider_name": "maths",
          "tool_name": "eval_expression",
          "tool_label": "eval_expression",
          "tool_configurations": {},
          "tool_parameters": {"expression": expression_param},
      }
      node = init_tool_node(config={"id": "1", "data": node_data})

      # Seed the variable the expression parameter points at.
      pool = node.graph_runtime_state.variable_pool
      pool.add(list(selector), "1+1")

      # execute node
      outcome = node._run()

      assert isinstance(outcome, NodeRunResult)
      assert outcome.status == WorkflowNodeExecutionStatus.SUCCEEDED
      outputs = outcome.outputs
      assert outputs is not None
      assert "2" in outputs["text"]
      assert outputs["files"] == []
  def test_tool_mixed_invoke():
      """Evaluate ``1+1`` with the maths tool, using a ``mixed`` template parameter.

      The ``expression`` tool parameter is the template ``{{#1.args1#}}``, which
      refers to the pool selector ``["1", "args1"]``. That selector is seeded
      with ``"1+1"`` before the node runs. The run must succeed, mention ``2`` in
      its text output and produce no files.
      """
      template = "{{#1.args1#}}"
      expression_param = {"type": "mixed", "value": template}
      node_data = {
          "title": "a",
          "desc": "a",
          "provider_id": "maths",
          "provider_type": "builtin",
          "provider_name": "maths",
          "tool_name": "eval_expression",
          "tool_label": "eval_expression",
          "tool_configurations": {},
          "tool_parameters": {"expression": expression_param},
      }
      node = init_tool_node(config={"id": "1", "data": node_data})

      # Seed the variable that the template placeholder resolves to.
      pool = node.graph_runtime_state.variable_pool
      pool.add(["1", "args1"], "1+1")

      # execute node
      outcome = node._run()

      assert isinstance(outcome, NodeRunResult)
      assert outcome.status == WorkflowNodeExecutionStatus.SUCCEEDED
      outputs = outcome.outputs
      assert outputs is not None
      assert "2" in outputs["text"]
      assert outputs["files"] == []