node.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import logging
  2. from abc import abstractmethod
  3. from collections.abc import Generator, Mapping, Sequence
  4. from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
  5. from core.workflow.entities.node_entities import NodeRunResult
  6. from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
  7. from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
  8. from models.workflow import WorkflowNodeExecutionStatus
  9. from .entities import BaseNodeData
  10. if TYPE_CHECKING:
  11. from core.workflow.graph_engine.entities.event import InNodeEvent
  12. from core.workflow.graph_engine.entities.graph import Graph
  13. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  14. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  15. logger = logging.getLogger(__name__)
  16. GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
  17. class BaseNode(Generic[GenericNodeData]):
  18. _node_data_cls: type[BaseNodeData]
  19. _node_type: NodeType
  20. def __init__(
  21. self,
  22. id: str,
  23. config: Mapping[str, Any],
  24. graph_init_params: "GraphInitParams",
  25. graph: "Graph",
  26. graph_runtime_state: "GraphRuntimeState",
  27. previous_node_id: Optional[str] = None,
  28. thread_pool_id: Optional[str] = None,
  29. ) -> None:
  30. self.id = id
  31. self.tenant_id = graph_init_params.tenant_id
  32. self.app_id = graph_init_params.app_id
  33. self.workflow_type = graph_init_params.workflow_type
  34. self.workflow_id = graph_init_params.workflow_id
  35. self.graph_config = graph_init_params.graph_config
  36. self.user_id = graph_init_params.user_id
  37. self.user_from = graph_init_params.user_from
  38. self.invoke_from = graph_init_params.invoke_from
  39. self.workflow_call_depth = graph_init_params.call_depth
  40. self.graph = graph
  41. self.graph_runtime_state = graph_runtime_state
  42. self.previous_node_id = previous_node_id
  43. self.thread_pool_id = thread_pool_id
  44. node_id = config.get("id")
  45. if not node_id:
  46. raise ValueError("Node ID is required.")
  47. self.node_id = node_id
  48. node_data = self._node_data_cls.model_validate(config.get("data", {}))
  49. self.node_data = cast(GenericNodeData, node_data)
  50. @abstractmethod
  51. def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
  52. """
  53. Run node
  54. :return:
  55. """
  56. raise NotImplementedError
  57. def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
  58. try:
  59. result = self._run()
  60. except Exception as e:
  61. logger.exception(f"Node {self.node_id} failed to run")
  62. result = NodeRunResult(
  63. status=WorkflowNodeExecutionStatus.FAILED,
  64. error=str(e),
  65. error_type="WorkflowNodeError",
  66. )
  67. if isinstance(result, NodeRunResult):
  68. yield RunCompletedEvent(run_result=result)
  69. else:
  70. yield from result
  71. @classmethod
  72. def extract_variable_selector_to_variable_mapping(
  73. cls,
  74. *,
  75. graph_config: Mapping[str, Any],
  76. config: Mapping[str, Any],
  77. ) -> Mapping[str, Sequence[str]]:
  78. """
  79. Extract variable selector to variable mapping
  80. :param graph_config: graph config
  81. :param config: node config
  82. :return:
  83. """
  84. node_id = config.get("id")
  85. if not node_id:
  86. raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
  87. node_data = cls._node_data_cls(**config.get("data", {}))
  88. return cls._extract_variable_selector_to_variable_mapping(
  89. graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
  90. )
  91. @classmethod
  92. def _extract_variable_selector_to_variable_mapping(
  93. cls,
  94. *,
  95. graph_config: Mapping[str, Any],
  96. node_id: str,
  97. node_data: GenericNodeData,
  98. ) -> Mapping[str, Sequence[str]]:
  99. """
  100. Extract variable selector to variable mapping
  101. :param graph_config: graph config
  102. :param node_id: node id
  103. :param node_data: node data
  104. :return:
  105. """
  106. return {}
  107. @classmethod
  108. def get_default_config(cls, filters: Optional[dict] = None) -> dict:
  109. """
  110. Get default config of node.
  111. :param filters: filter by node config parameters.
  112. :return:
  113. """
  114. return {}
  115. @property
  116. def node_type(self) -> NodeType:
  117. """
  118. Get node type
  119. :return:
  120. """
  121. return self._node_type
  122. @property
  123. def should_continue_on_error(self) -> bool:
  124. """judge if should continue on error
  125. Returns:
  126. bool: if should continue on error
  127. """
  128. return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
  129. @property
  130. def should_retry(self) -> bool:
  131. """judge if should retry
  132. Returns:
  133. bool: if should retry
  134. """
  135. return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE