base_node.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from abc import ABC, abstractmethod
  2. from collections.abc import Generator, Mapping, Sequence
  3. from typing import Any, Optional
  4. from core.workflow.entities.base_node_data_entities import BaseNodeData
  5. from core.workflow.entities.node_entities import NodeRunResult, NodeType
  6. from core.workflow.graph_engine.entities.event import InNodeEvent
  7. from core.workflow.graph_engine.entities.graph import Graph
  8. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  9. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  10. from core.workflow.nodes.event import RunCompletedEvent, RunEvent
  11. class BaseNode(ABC):
  12. _node_data_cls: type[BaseNodeData]
  13. _node_type: NodeType
  14. def __init__(self,
  15. id: str,
  16. config: Mapping[str, Any],
  17. graph_init_params: GraphInitParams,
  18. graph: Graph,
  19. graph_runtime_state: GraphRuntimeState,
  20. previous_node_id: Optional[str] = None,
  21. thread_pool_id: Optional[str] = None) -> None:
  22. self.id = id
  23. self.tenant_id = graph_init_params.tenant_id
  24. self.app_id = graph_init_params.app_id
  25. self.workflow_type = graph_init_params.workflow_type
  26. self.workflow_id = graph_init_params.workflow_id
  27. self.graph_config = graph_init_params.graph_config
  28. self.user_id = graph_init_params.user_id
  29. self.user_from = graph_init_params.user_from
  30. self.invoke_from = graph_init_params.invoke_from
  31. self.workflow_call_depth = graph_init_params.call_depth
  32. self.graph = graph
  33. self.graph_runtime_state = graph_runtime_state
  34. self.previous_node_id = previous_node_id
  35. self.thread_pool_id = thread_pool_id
  36. node_id = config.get("id")
  37. if not node_id:
  38. raise ValueError("Node ID is required.")
  39. self.node_id = node_id
  40. self.node_data = self._node_data_cls(**config.get("data", {}))
  41. @abstractmethod
  42. def _run(self) \
  43. -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
  44. """
  45. Run node
  46. :return:
  47. """
  48. raise NotImplementedError
  49. def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
  50. """
  51. Run node entry
  52. :return:
  53. """
  54. result = self._run()
  55. if isinstance(result, NodeRunResult):
  56. yield RunCompletedEvent(
  57. run_result=result
  58. )
  59. else:
  60. yield from result
  61. @classmethod
  62. def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]:
  63. """
  64. Extract variable selector to variable mapping
  65. :param graph_config: graph config
  66. :param config: node config
  67. :return:
  68. """
  69. node_id = config.get("id")
  70. if not node_id:
  71. raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
  72. node_data = cls._node_data_cls(**config.get("data", {}))
  73. return cls._extract_variable_selector_to_variable_mapping(
  74. graph_config=graph_config,
  75. node_id=node_id,
  76. node_data=node_data
  77. )
  78. @classmethod
  79. def _extract_variable_selector_to_variable_mapping(
  80. cls,
  81. graph_config: Mapping[str, Any],
  82. node_id: str,
  83. node_data: BaseNodeData
  84. ) -> Mapping[str, Sequence[str]]:
  85. """
  86. Extract variable selector to variable mapping
  87. :param graph_config: graph config
  88. :param node_id: node id
  89. :param node_data: node data
  90. :return:
  91. """
  92. return {}
  93. @classmethod
  94. def get_default_config(cls, filters: Optional[dict] = None) -> dict:
  95. """
  96. Get default config of node.
  97. :param filters: filter by node config parameters.
  98. :return:
  99. """
  100. return {}
  101. @property
  102. def node_type(self) -> NodeType:
  103. """
  104. Get node type
  105. :return:
  106. """
  107. return self._node_type