# workflow_entry.py (14 KB; 397 lines in the original listing)
  1. import logging
  2. import time
  3. import uuid
  4. from collections.abc import Generator, Mapping, Sequence
  5. from typing import Any, Optional, cast
  6. from configs import dify_config
  7. from core.app.app_config.entities import FileExtraConfig
  8. from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
  9. from core.app.entities.app_invoke_entities import InvokeFrom
  10. from core.file.file_obj import FileTransferMethod, FileType, FileVar
  11. from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
  12. from core.workflow.entities.base_node_data_entities import BaseNodeData
  13. from core.workflow.entities.node_entities import NodeType, UserFrom
  14. from core.workflow.entities.variable_pool import VariablePool
  15. from core.workflow.errors import WorkflowNodeRunFailedError
  16. from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
  17. from core.workflow.graph_engine.entities.graph import Graph
  18. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  19. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  20. from core.workflow.graph_engine.graph_engine import GraphEngine
  21. from core.workflow.nodes.base_node import BaseNode
  22. from core.workflow.nodes.event import RunEvent
  23. from core.workflow.nodes.llm.entities import LLMNodeData
  24. from core.workflow.nodes.node_mapping import node_classes
  25. from models.workflow import (
  26. Workflow,
  27. WorkflowType,
  28. )
  29. logger = logging.getLogger(__name__)
  30. class WorkflowEntry:
  31. def __init__(
  32. self,
  33. tenant_id: str,
  34. app_id: str,
  35. workflow_id: str,
  36. workflow_type: WorkflowType,
  37. graph_config: Mapping[str, Any],
  38. graph: Graph,
  39. user_id: str,
  40. user_from: UserFrom,
  41. invoke_from: InvokeFrom,
  42. call_depth: int,
  43. variable_pool: VariablePool,
  44. thread_pool_id: Optional[str] = None,
  45. ) -> None:
  46. """
  47. Init workflow entry
  48. :param tenant_id: tenant id
  49. :param app_id: app id
  50. :param workflow_id: workflow id
  51. :param workflow_type: workflow type
  52. :param graph_config: workflow graph config
  53. :param graph: workflow graph
  54. :param user_id: user id
  55. :param user_from: user from
  56. :param invoke_from: invoke from
  57. :param call_depth: call depth
  58. :param variable_pool: variable pool
  59. :param thread_pool_id: thread pool id
  60. """
  61. # check call depth
  62. workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
  63. if call_depth > workflow_call_max_depth:
  64. raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
  65. # init workflow run state
  66. self.graph_engine = GraphEngine(
  67. tenant_id=tenant_id,
  68. app_id=app_id,
  69. workflow_type=workflow_type,
  70. workflow_id=workflow_id,
  71. user_id=user_id,
  72. user_from=user_from,
  73. invoke_from=invoke_from,
  74. call_depth=call_depth,
  75. graph=graph,
  76. graph_config=graph_config,
  77. variable_pool=variable_pool,
  78. max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
  79. max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
  80. thread_pool_id=thread_pool_id,
  81. )
  82. def run(
  83. self,
  84. *,
  85. callbacks: Sequence[WorkflowCallback],
  86. ) -> Generator[GraphEngineEvent, None, None]:
  87. """
  88. :param callbacks: workflow callbacks
  89. """
  90. graph_engine = self.graph_engine
  91. try:
  92. # run workflow
  93. generator = graph_engine.run()
  94. for event in generator:
  95. if callbacks:
  96. for callback in callbacks:
  97. callback.on_event(event=event)
  98. yield event
  99. except GenerateTaskStoppedError:
  100. pass
  101. except Exception as e:
  102. logger.exception("Unknown Error when workflow entry running")
  103. if callbacks:
  104. for callback in callbacks:
  105. callback.on_event(event=GraphRunFailedEvent(error=str(e)))
  106. return
  107. @classmethod
  108. def single_step_run(
  109. cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict
  110. ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
  111. """
  112. Single step run workflow node
  113. :param workflow: Workflow instance
  114. :param node_id: node id
  115. :param user_id: user id
  116. :param user_inputs: user inputs
  117. :return:
  118. """
  119. # fetch node info from workflow graph
  120. graph = workflow.graph_dict
  121. if not graph:
  122. raise ValueError("workflow graph not found")
  123. nodes = graph.get("nodes")
  124. if not nodes:
  125. raise ValueError("nodes not found in workflow graph")
  126. # fetch node config from node id
  127. node_config = None
  128. for node in nodes:
  129. if node.get("id") == node_id:
  130. node_config = node
  131. break
  132. if not node_config:
  133. raise ValueError("node id not found in workflow graph")
  134. # Get node class
  135. node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
  136. node_cls = node_classes.get(node_type)
  137. node_cls = cast(type[BaseNode], node_cls)
  138. if not node_cls:
  139. raise ValueError(f"Node class not found for node type {node_type}")
  140. # init variable pool
  141. variable_pool = VariablePool(
  142. system_variables={},
  143. user_inputs={},
  144. environment_variables=workflow.environment_variables,
  145. )
  146. # init graph
  147. graph = Graph.init(graph_config=workflow.graph_dict)
  148. # init workflow run state
  149. node_instance: BaseNode = node_cls(
  150. id=str(uuid.uuid4()),
  151. config=node_config,
  152. graph_init_params=GraphInitParams(
  153. tenant_id=workflow.tenant_id,
  154. app_id=workflow.app_id,
  155. workflow_type=WorkflowType.value_of(workflow.type),
  156. workflow_id=workflow.id,
  157. graph_config=workflow.graph_dict,
  158. user_id=user_id,
  159. user_from=UserFrom.ACCOUNT,
  160. invoke_from=InvokeFrom.DEBUGGER,
  161. call_depth=0,
  162. ),
  163. graph=graph,
  164. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  165. )
  166. try:
  167. # variable selector to variable mapping
  168. try:
  169. variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
  170. graph_config=workflow.graph_dict, config=node_config
  171. )
  172. except NotImplementedError:
  173. variable_mapping = {}
  174. cls.mapping_user_inputs_to_variable_pool(
  175. variable_mapping=variable_mapping,
  176. user_inputs=user_inputs,
  177. variable_pool=variable_pool,
  178. tenant_id=workflow.tenant_id,
  179. node_type=node_type,
  180. node_data=node_instance.node_data,
  181. )
  182. # run node
  183. generator = node_instance.run()
  184. return node_instance, generator
  185. except Exception as e:
  186. raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
  187. @classmethod
  188. def run_free_node(
  189. cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
  190. ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
  191. """
  192. Run free node
  193. NOTE: only parameter_extractor/question_classifier are supported
  194. :param node_data: node data
  195. :param user_id: user id
  196. :param user_inputs: user inputs
  197. :return:
  198. """
  199. # generate a fake graph
  200. node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
  201. start_node_config = {
  202. "id": "start",
  203. "width": 114,
  204. "height": 514,
  205. "type": "custom",
  206. "data": {
  207. "type": NodeType.START.value,
  208. "title": "Start",
  209. "desc": "Start",
  210. },
  211. }
  212. graph_dict = {
  213. "nodes": [start_node_config, node_config],
  214. "edges": [
  215. {
  216. "source": "start",
  217. "target": node_id,
  218. "sourceHandle": "source",
  219. "targetHandle": "target",
  220. }
  221. ],
  222. }
  223. node_type = NodeType.value_of(node_data.get("type", ""))
  224. if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
  225. raise ValueError(f"Node type {node_type} not supported")
  226. node_cls = node_classes.get(node_type)
  227. if not node_cls:
  228. raise ValueError(f"Node class not found for node type {node_type}")
  229. graph = Graph.init(graph_config=graph_dict)
  230. # init variable pool
  231. variable_pool = VariablePool(
  232. system_variables={},
  233. user_inputs={},
  234. environment_variables=[],
  235. )
  236. node_cls = cast(type[BaseNode], node_cls)
  237. # init workflow run state
  238. node_instance: BaseNode = node_cls(
  239. id=str(uuid.uuid4()),
  240. config=node_config,
  241. graph_init_params=GraphInitParams(
  242. tenant_id=tenant_id,
  243. app_id="",
  244. workflow_type=WorkflowType.WORKFLOW,
  245. workflow_id="",
  246. graph_config=graph_dict,
  247. user_id=user_id,
  248. user_from=UserFrom.ACCOUNT,
  249. invoke_from=InvokeFrom.DEBUGGER,
  250. call_depth=0,
  251. ),
  252. graph=graph,
  253. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  254. )
  255. try:
  256. # variable selector to variable mapping
  257. try:
  258. variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
  259. graph_config=graph_dict, config=node_config
  260. )
  261. except NotImplementedError:
  262. variable_mapping = {}
  263. cls.mapping_user_inputs_to_variable_pool(
  264. variable_mapping=variable_mapping,
  265. user_inputs=user_inputs,
  266. variable_pool=variable_pool,
  267. tenant_id=tenant_id,
  268. node_type=node_type,
  269. node_data=node_instance.node_data,
  270. )
  271. # run node
  272. generator = node_instance.run()
  273. return node_instance, generator
  274. except Exception as e:
  275. raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
  276. @classmethod
  277. def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
  278. """
  279. Handle special values
  280. :param value: value
  281. :return:
  282. """
  283. if not value:
  284. return None
  285. new_value = dict(value) if value else {}
  286. if isinstance(new_value, dict):
  287. for key, val in new_value.items():
  288. if isinstance(val, FileVar):
  289. new_value[key] = val.to_dict()
  290. elif isinstance(val, list):
  291. new_val = []
  292. for v in val:
  293. if isinstance(v, FileVar):
  294. new_val.append(v.to_dict())
  295. else:
  296. new_val.append(v)
  297. new_value[key] = new_val
  298. return new_value
  299. @classmethod
  300. def mapping_user_inputs_to_variable_pool(
  301. cls,
  302. variable_mapping: Mapping[str, Sequence[str]],
  303. user_inputs: dict,
  304. variable_pool: VariablePool,
  305. tenant_id: str,
  306. node_type: NodeType,
  307. node_data: BaseNodeData,
  308. ) -> None:
  309. for node_variable, variable_selector in variable_mapping.items():
  310. # fetch node id and variable key from node_variable
  311. node_variable_list = node_variable.split(".")
  312. if len(node_variable_list) < 1:
  313. raise ValueError(f"Invalid node variable {node_variable}")
  314. node_variable_key = ".".join(node_variable_list[1:])
  315. if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get(
  316. variable_selector
  317. ):
  318. raise ValueError(f"Variable key {node_variable} not found in user inputs.")
  319. # fetch variable node id from variable selector
  320. variable_node_id = variable_selector[0]
  321. variable_key_list = variable_selector[1:]
  322. variable_key_list = cast(list[str], variable_key_list)
  323. # get input value
  324. input_value = user_inputs.get(node_variable)
  325. if not input_value:
  326. input_value = user_inputs.get(node_variable_key)
  327. # FIXME: temp fix for image type
  328. if node_type == NodeType.LLM:
  329. new_value = []
  330. if isinstance(input_value, list):
  331. node_data = cast(LLMNodeData, node_data)
  332. detail = node_data.vision.configs.detail if node_data.vision.configs else None
  333. for item in input_value:
  334. if isinstance(item, dict) and "type" in item and item["type"] == "image":
  335. transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
  336. file = FileVar(
  337. tenant_id=tenant_id,
  338. type=FileType.IMAGE,
  339. transfer_method=transfer_method,
  340. url=item.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
  341. related_id=item.get("upload_file_id")
  342. if transfer_method == FileTransferMethod.LOCAL_FILE
  343. else None,
  344. extra_config=FileExtraConfig(image_config={"detail": detail} if detail else None),
  345. )
  346. new_value.append(file)
  347. if new_value:
  348. value = new_value
  349. # append variable and value to variable pool
  350. variable_pool.add([variable_node_id] + variable_key_list, input_value)