workflow_entry.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. import logging
  2. import time
  3. import uuid
  4. from collections.abc import Generator, Mapping, Sequence
  5. from typing import Any, Optional, cast
  6. from configs import dify_config
  7. from core.app.app_config.entities import FileUploadConfig
  8. from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
  9. from core.app.entities.app_invoke_entities import InvokeFrom
  10. from core.file.models import File, FileTransferMethod, ImageConfig
  11. from core.workflow.callbacks import WorkflowCallback
  12. from core.workflow.entities.variable_pool import VariablePool
  13. from core.workflow.errors import WorkflowNodeRunFailedError
  14. from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
  15. from core.workflow.graph_engine.entities.graph import Graph
  16. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  17. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  18. from core.workflow.graph_engine.graph_engine import GraphEngine
  19. from core.workflow.nodes import NodeType
  20. from core.workflow.nodes.base import BaseNode, BaseNodeData
  21. from core.workflow.nodes.event import NodeEvent
  22. from core.workflow.nodes.llm import LLMNodeData
  23. from core.workflow.nodes.node_mapping import node_type_classes_mapping
  24. from factories import file_factory
  25. from models.enums import UserFrom
  26. from models.workflow import (
  27. Workflow,
  28. WorkflowType,
  29. )
  30. logger = logging.getLogger(__name__)
  31. class WorkflowEntry:
  32. def __init__(
  33. self,
  34. tenant_id: str,
  35. app_id: str,
  36. workflow_id: str,
  37. workflow_type: WorkflowType,
  38. graph_config: Mapping[str, Any],
  39. graph: Graph,
  40. user_id: str,
  41. user_from: UserFrom,
  42. invoke_from: InvokeFrom,
  43. call_depth: int,
  44. variable_pool: VariablePool,
  45. thread_pool_id: Optional[str] = None,
  46. ) -> None:
  47. """
  48. Init workflow entry
  49. :param tenant_id: tenant id
  50. :param app_id: app id
  51. :param workflow_id: workflow id
  52. :param workflow_type: workflow type
  53. :param graph_config: workflow graph config
  54. :param graph: workflow graph
  55. :param user_id: user id
  56. :param user_from: user from
  57. :param invoke_from: invoke from
  58. :param call_depth: call depth
  59. :param variable_pool: variable pool
  60. :param thread_pool_id: thread pool id
  61. """
  62. # check call depth
  63. workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
  64. if call_depth > workflow_call_max_depth:
  65. raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
  66. # init workflow run state
  67. self.graph_engine = GraphEngine(
  68. tenant_id=tenant_id,
  69. app_id=app_id,
  70. workflow_type=workflow_type,
  71. workflow_id=workflow_id,
  72. user_id=user_id,
  73. user_from=user_from,
  74. invoke_from=invoke_from,
  75. call_depth=call_depth,
  76. graph=graph,
  77. graph_config=graph_config,
  78. variable_pool=variable_pool,
  79. max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
  80. max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
  81. thread_pool_id=thread_pool_id,
  82. )
  83. def run(
  84. self,
  85. *,
  86. callbacks: Sequence[WorkflowCallback],
  87. ) -> Generator[GraphEngineEvent, None, None]:
  88. """
  89. :param callbacks: workflow callbacks
  90. """
  91. graph_engine = self.graph_engine
  92. try:
  93. # run workflow
  94. generator = graph_engine.run()
  95. for event in generator:
  96. if callbacks:
  97. for callback in callbacks:
  98. callback.on_event(event=event)
  99. yield event
  100. except GenerateTaskStoppedError:
  101. pass
  102. except Exception as e:
  103. logger.exception("Unknown Error when workflow entry running")
  104. if callbacks:
  105. for callback in callbacks:
  106. callback.on_event(event=GraphRunFailedEvent(error=str(e)))
  107. return
  108. @classmethod
  109. def single_step_run(
  110. cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict
  111. ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
  112. """
  113. Single step run workflow node
  114. :param workflow: Workflow instance
  115. :param node_id: node id
  116. :param user_id: user id
  117. :param user_inputs: user inputs
  118. :return:
  119. """
  120. # fetch node info from workflow graph
  121. graph = workflow.graph_dict
  122. if not graph:
  123. raise ValueError("workflow graph not found")
  124. nodes = graph.get("nodes")
  125. if not nodes:
  126. raise ValueError("nodes not found in workflow graph")
  127. # fetch node config from node id
  128. node_config = None
  129. for node in nodes:
  130. if node.get("id") == node_id:
  131. node_config = node
  132. break
  133. if not node_config:
  134. raise ValueError("node id not found in workflow graph")
  135. # Get node class
  136. node_type = NodeType(node_config.get("data", {}).get("type"))
  137. node_cls = node_type_classes_mapping.get(node_type)
  138. node_cls = cast(type[BaseNode], node_cls)
  139. if not node_cls:
  140. raise ValueError(f"Node class not found for node type {node_type}")
  141. # init variable pool
  142. variable_pool = VariablePool(
  143. system_variables={},
  144. user_inputs={},
  145. environment_variables=workflow.environment_variables,
  146. )
  147. # init graph
  148. graph = Graph.init(graph_config=workflow.graph_dict)
  149. # init workflow run state
  150. node_instance = node_cls(
  151. id=str(uuid.uuid4()),
  152. config=node_config,
  153. graph_init_params=GraphInitParams(
  154. tenant_id=workflow.tenant_id,
  155. app_id=workflow.app_id,
  156. workflow_type=WorkflowType.value_of(workflow.type),
  157. workflow_id=workflow.id,
  158. graph_config=workflow.graph_dict,
  159. user_id=user_id,
  160. user_from=UserFrom.ACCOUNT,
  161. invoke_from=InvokeFrom.DEBUGGER,
  162. call_depth=0,
  163. ),
  164. graph=graph,
  165. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  166. )
  167. try:
  168. # variable selector to variable mapping
  169. try:
  170. variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
  171. graph_config=workflow.graph_dict, config=node_config
  172. )
  173. except NotImplementedError:
  174. variable_mapping = {}
  175. cls.mapping_user_inputs_to_variable_pool(
  176. variable_mapping=variable_mapping,
  177. user_inputs=user_inputs,
  178. variable_pool=variable_pool,
  179. tenant_id=workflow.tenant_id,
  180. node_type=node_type,
  181. node_data=node_instance.node_data,
  182. )
  183. # run node
  184. generator = node_instance.run()
  185. return node_instance, generator
  186. except Exception as e:
  187. raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
  188. @classmethod
  189. def run_free_node(
  190. cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
  191. ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
  192. """
  193. Run free node
  194. NOTE: only parameter_extractor/question_classifier are supported
  195. :param node_data: node data
  196. :param user_id: user id
  197. :param user_inputs: user inputs
  198. :return:
  199. """
  200. # generate a fake graph
  201. node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
  202. start_node_config = {
  203. "id": "start",
  204. "width": 114,
  205. "height": 514,
  206. "type": "custom",
  207. "data": {
  208. "type": NodeType.START.value,
  209. "title": "Start",
  210. "desc": "Start",
  211. },
  212. }
  213. graph_dict = {
  214. "nodes": [start_node_config, node_config],
  215. "edges": [
  216. {
  217. "source": "start",
  218. "target": node_id,
  219. "sourceHandle": "source",
  220. "targetHandle": "target",
  221. }
  222. ],
  223. }
  224. node_type = NodeType(node_data.get("type", ""))
  225. if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
  226. raise ValueError(f"Node type {node_type} not supported")
  227. node_cls = node_type_classes_mapping.get(node_type)
  228. if not node_cls:
  229. raise ValueError(f"Node class not found for node type {node_type}")
  230. graph = Graph.init(graph_config=graph_dict)
  231. # init variable pool
  232. variable_pool = VariablePool(
  233. system_variables={},
  234. user_inputs={},
  235. environment_variables=[],
  236. )
  237. node_cls = cast(type[BaseNode], node_cls)
  238. # init workflow run state
  239. node_instance: BaseNode = node_cls(
  240. id=str(uuid.uuid4()),
  241. config=node_config,
  242. graph_init_params=GraphInitParams(
  243. tenant_id=tenant_id,
  244. app_id="",
  245. workflow_type=WorkflowType.WORKFLOW,
  246. workflow_id="",
  247. graph_config=graph_dict,
  248. user_id=user_id,
  249. user_from=UserFrom.ACCOUNT,
  250. invoke_from=InvokeFrom.DEBUGGER,
  251. call_depth=0,
  252. ),
  253. graph=graph,
  254. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  255. )
  256. try:
  257. # variable selector to variable mapping
  258. try:
  259. variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
  260. graph_config=graph_dict, config=node_config
  261. )
  262. except NotImplementedError:
  263. variable_mapping = {}
  264. cls.mapping_user_inputs_to_variable_pool(
  265. variable_mapping=variable_mapping,
  266. user_inputs=user_inputs,
  267. variable_pool=variable_pool,
  268. tenant_id=tenant_id,
  269. node_type=node_type,
  270. node_data=node_instance.node_data,
  271. )
  272. # run node
  273. generator = node_instance.run()
  274. return node_instance, generator
  275. except Exception as e:
  276. raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
  277. @staticmethod
  278. def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
  279. return WorkflowEntry._handle_special_values(value)
  280. @staticmethod
  281. def _handle_special_values(value: Any) -> Any:
  282. if value is None:
  283. return value
  284. if isinstance(value, dict):
  285. res = {}
  286. for k, v in value.items():
  287. res[k] = WorkflowEntry._handle_special_values(v)
  288. return res
  289. if isinstance(value, list):
  290. res = []
  291. for item in value:
  292. res.append(WorkflowEntry._handle_special_values(item))
  293. return res
  294. if isinstance(value, File):
  295. return value.to_dict()
  296. return value
  297. @classmethod
  298. def mapping_user_inputs_to_variable_pool(
  299. cls,
  300. variable_mapping: Mapping[str, Sequence[str]],
  301. user_inputs: dict,
  302. variable_pool: VariablePool,
  303. tenant_id: str,
  304. node_type: NodeType,
  305. node_data: BaseNodeData,
  306. ) -> None:
  307. for node_variable, variable_selector in variable_mapping.items():
  308. # fetch node id and variable key from node_variable
  309. node_variable_list = node_variable.split(".")
  310. if len(node_variable_list) < 1:
  311. raise ValueError(f"Invalid node variable {node_variable}")
  312. node_variable_key = ".".join(node_variable_list[1:])
  313. if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get(
  314. variable_selector
  315. ):
  316. raise ValueError(f"Variable key {node_variable} not found in user inputs.")
  317. # fetch variable node id from variable selector
  318. variable_node_id = variable_selector[0]
  319. variable_key_list = variable_selector[1:]
  320. variable_key_list = cast(list[str], variable_key_list)
  321. # get input value
  322. input_value = user_inputs.get(node_variable)
  323. if not input_value:
  324. input_value = user_inputs.get(node_variable_key)
  325. # FIXME: temp fix for image type
  326. if node_type == NodeType.LLM:
  327. new_value = []
  328. if isinstance(input_value, list):
  329. node_data = cast(LLMNodeData, node_data)
  330. detail = node_data.vision.configs.detail if node_data.vision.configs else None
  331. for item in input_value:
  332. if isinstance(item, dict) and "type" in item and item["type"] == "image":
  333. transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
  334. mapping = {
  335. "id": item.get("id"),
  336. "transfer_method": transfer_method,
  337. "upload_file_id": item.get("upload_file_id"),
  338. "url": item.get("url"),
  339. }
  340. config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
  341. file = file_factory.build_from_mapping(
  342. mapping=mapping,
  343. tenant_id=tenant_id,
  344. config=config,
  345. )
  346. new_value.append(file)
  347. if new_value:
  348. input_value = new_value
  349. # append variable and value to variable pool
  350. variable_pool.add([variable_node_id] + variable_key_list, input_value)