workflow_entry.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import logging
  2. import time
  3. import uuid
  4. from collections.abc import Generator, Mapping, Sequence
  5. from typing import Any, Optional
  6. from configs import dify_config
  7. from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
  8. from core.app.entities.app_invoke_entities import InvokeFrom
  9. from core.file.models import File
  10. from core.workflow.callbacks import WorkflowCallback
  11. from core.workflow.entities.variable_pool import VariablePool
  12. from core.workflow.errors import WorkflowNodeRunFailedError
  13. from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
  14. from core.workflow.graph_engine.entities.graph import Graph
  15. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  16. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  17. from core.workflow.graph_engine.graph_engine import GraphEngine
  18. from core.workflow.nodes import NodeType
  19. from core.workflow.nodes.base import BaseNode
  20. from core.workflow.nodes.event import NodeEvent
  21. from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
  22. from factories import file_factory
  23. from models.enums import UserFrom
  24. from models.workflow import (
  25. Workflow,
  26. WorkflowType,
  27. )
  28. logger = logging.getLogger(__name__)
  29. class WorkflowEntry:
  30. def __init__(
  31. self,
  32. tenant_id: str,
  33. app_id: str,
  34. workflow_id: str,
  35. workflow_type: WorkflowType,
  36. graph_config: Mapping[str, Any],
  37. graph: Graph,
  38. user_id: str,
  39. user_from: UserFrom,
  40. invoke_from: InvokeFrom,
  41. call_depth: int,
  42. variable_pool: VariablePool,
  43. thread_pool_id: Optional[str] = None,
  44. ) -> None:
  45. """
  46. Init workflow entry
  47. :param tenant_id: tenant id
  48. :param app_id: app id
  49. :param workflow_id: workflow id
  50. :param workflow_type: workflow type
  51. :param graph_config: workflow graph config
  52. :param graph: workflow graph
  53. :param user_id: user id
  54. :param user_from: user from
  55. :param invoke_from: invoke from
  56. :param call_depth: call depth
  57. :param variable_pool: variable pool
  58. :param thread_pool_id: thread pool id
  59. """
  60. # check call depth
  61. workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
  62. if call_depth > workflow_call_max_depth:
  63. raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
  64. # init workflow run state
  65. self.graph_engine = GraphEngine(
  66. tenant_id=tenant_id,
  67. app_id=app_id,
  68. workflow_type=workflow_type,
  69. workflow_id=workflow_id,
  70. user_id=user_id,
  71. user_from=user_from,
  72. invoke_from=invoke_from,
  73. call_depth=call_depth,
  74. graph=graph,
  75. graph_config=graph_config,
  76. variable_pool=variable_pool,
  77. max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
  78. max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
  79. thread_pool_id=thread_pool_id,
  80. )
  81. def run(
  82. self,
  83. *,
  84. callbacks: Sequence[WorkflowCallback],
  85. ) -> Generator[GraphEngineEvent, None, None]:
  86. """
  87. :param callbacks: workflow callbacks
  88. """
  89. graph_engine = self.graph_engine
  90. try:
  91. # run workflow
  92. generator = graph_engine.run()
  93. for event in generator:
  94. if callbacks:
  95. for callback in callbacks:
  96. callback.on_event(event=event)
  97. yield event
  98. except GenerateTaskStoppedError:
  99. pass
  100. except Exception as e:
  101. logger.exception("Unknown Error when workflow entry running")
  102. if callbacks:
  103. for callback in callbacks:
  104. callback.on_event(event=GraphRunFailedEvent(error=str(e)))
  105. return
  106. @classmethod
  107. def single_step_run(
  108. cls,
  109. *,
  110. workflow: Workflow,
  111. node_id: str,
  112. user_id: str,
  113. user_inputs: dict,
  114. ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
  115. """
  116. Single step run workflow node
  117. :param workflow: Workflow instance
  118. :param node_id: node id
  119. :param user_id: user id
  120. :param user_inputs: user inputs
  121. :return:
  122. """
  123. # fetch node info from workflow graph
  124. workflow_graph = workflow.graph_dict
  125. if not workflow_graph:
  126. raise ValueError("workflow graph not found")
  127. nodes = workflow_graph.get("nodes")
  128. if not nodes:
  129. raise ValueError("nodes not found in workflow graph")
  130. # fetch node config from node id
  131. try:
  132. node_config = next(filter(lambda node: node["id"] == node_id, nodes))
  133. except StopIteration:
  134. raise ValueError("node id not found in workflow graph")
  135. # Get node class
  136. node_type = NodeType(node_config.get("data", {}).get("type"))
  137. node_version = node_config.get("data", {}).get("version", "1")
  138. node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
  139. # init variable pool
  140. variable_pool = VariablePool(environment_variables=workflow.environment_variables)
  141. # init graph
  142. graph = Graph.init(graph_config=workflow.graph_dict)
  143. # init workflow run state
  144. node_instance = node_cls(
  145. id=str(uuid.uuid4()),
  146. config=node_config,
  147. graph_init_params=GraphInitParams(
  148. tenant_id=workflow.tenant_id,
  149. app_id=workflow.app_id,
  150. workflow_type=WorkflowType.value_of(workflow.type),
  151. workflow_id=workflow.id,
  152. graph_config=workflow.graph_dict,
  153. user_id=user_id,
  154. user_from=UserFrom.ACCOUNT,
  155. invoke_from=InvokeFrom.DEBUGGER,
  156. call_depth=0,
  157. ),
  158. graph=graph,
  159. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  160. )
  161. try:
  162. # variable selector to variable mapping
  163. variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
  164. graph_config=workflow.graph_dict, config=node_config
  165. )
  166. except NotImplementedError:
  167. variable_mapping = {}
  168. cls.mapping_user_inputs_to_variable_pool(
  169. variable_mapping=variable_mapping,
  170. user_inputs=user_inputs,
  171. variable_pool=variable_pool,
  172. tenant_id=workflow.tenant_id,
  173. )
  174. try:
  175. # run node
  176. generator = node_instance.run()
  177. except Exception as e:
  178. raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
  179. return node_instance, generator
  180. @staticmethod
  181. def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
  182. result = WorkflowEntry._handle_special_values(value)
  183. return result if isinstance(result, Mapping) or result is None else dict(result)
  184. @staticmethod
  185. def _handle_special_values(value: Any) -> Any:
  186. if value is None:
  187. return value
  188. if isinstance(value, dict):
  189. res = {}
  190. for k, v in value.items():
  191. res[k] = WorkflowEntry._handle_special_values(v)
  192. return res
  193. if isinstance(value, list):
  194. res_list = []
  195. for item in value:
  196. res_list.append(WorkflowEntry._handle_special_values(item))
  197. return res_list
  198. if isinstance(value, File):
  199. return value.to_dict()
  200. return value
  201. @classmethod
  202. def mapping_user_inputs_to_variable_pool(
  203. cls,
  204. *,
  205. variable_mapping: Mapping[str, Sequence[str]],
  206. user_inputs: dict,
  207. variable_pool: VariablePool,
  208. tenant_id: str,
  209. ) -> None:
  210. for node_variable, variable_selector in variable_mapping.items():
  211. # fetch node id and variable key from node_variable
  212. node_variable_list = node_variable.split(".")
  213. if len(node_variable_list) < 1:
  214. raise ValueError(f"Invalid node variable {node_variable}")
  215. node_variable_key = ".".join(node_variable_list[1:])
  216. if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get(
  217. variable_selector
  218. ):
  219. raise ValueError(f"Variable key {node_variable} not found in user inputs.")
  220. # environment variable already exist in variable pool, not from user inputs
  221. if variable_pool.get(variable_selector):
  222. continue
  223. # fetch variable node id from variable selector
  224. variable_node_id = variable_selector[0]
  225. variable_key_list = variable_selector[1:]
  226. variable_key_list = list(variable_key_list)
  227. # get input value
  228. input_value = user_inputs.get(node_variable)
  229. if not input_value:
  230. input_value = user_inputs.get(node_variable_key)
  231. if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
  232. input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
  233. if (
  234. isinstance(input_value, list)
  235. and all(isinstance(item, dict) for item in input_value)
  236. and all("type" in item and "transfer_method" in item for item in input_value)
  237. ):
  238. input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
  239. # append variable and value to variable pool
  240. variable_pool.add([variable_node_id] + variable_key_list, input_value)