import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus

from .entities import BaseNodeData

if TYPE_CHECKING:
    from core.workflow.graph_engine.entities.event import InNodeEvent
    from core.workflow.graph_engine.entities.graph import Graph
    from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
    from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

logger = logging.getLogger(__name__)

# Type of the parsed node-data model a concrete node works with.
GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)


class BaseNode(Generic[GenericNodeData]):
    """Common base class for workflow nodes.

    Concrete subclasses must set ``_node_data_cls`` (the model used to parse
    the ``"data"`` section of a node config) and ``_node_type``, and must
    implement :meth:`_run`. The public :meth:`run` wraps ``_run`` so callers
    always receive a generator of events.
    """

    # Model class used to parse ``config["data"]``. Not assigned here, so every
    # subclass must provide it before ``__init__`` or the mapping extraction runs.
    _node_data_cls: type[BaseNodeData]
    # Value returned by the ``node_type`` property. Not assigned here; set per subclass.
    _node_type: NodeType

    def __init__(
        self,
        id: str,
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph: "Graph",
        graph_runtime_state: "GraphRuntimeState",
        previous_node_id: Optional[str] = None,
        thread_pool_id: Optional[str] = None,
    ) -> None:
        """Bind the node to its graph and parse its configuration.

        :param id: identifier stored as ``self.id``. It is kept separate from
            ``config["id"]``, which becomes ``self.node_id``. Presumably this is
            a per-execution id; confirm against callers.
        :param config: node config mapping. ``"id"`` is required and ``"data"``
            (default ``{}``) is passed as keyword arguments to ``_node_data_cls``.
        :param graph_init_params: source of tenant/app/workflow/user/invoke
            metadata, copied onto the instance.
        :param graph: the graph this node belongs to.
        :param graph_runtime_state: shared runtime state for the graph run.
        :param previous_node_id: id of the preceding node, if any.
        :param thread_pool_id: thread pool id, if any.
        :raises ValueError: if ``config`` has no truthy ``"id"``.
        """
        self.id = id

        # Copy execution context from the init params onto the node.
        self.tenant_id = graph_init_params.tenant_id
        self.app_id = graph_init_params.app_id
        self.workflow_type = graph_init_params.workflow_type
        self.workflow_id = graph_init_params.workflow_id
        self.graph_config = graph_init_params.graph_config
        self.user_id = graph_init_params.user_id
        self.user_from = graph_init_params.user_from
        self.invoke_from = graph_init_params.invoke_from
        self.workflow_call_depth = graph_init_params.call_depth

        self.graph = graph
        self.graph_runtime_state = graph_runtime_state
        self.previous_node_id = previous_node_id
        self.thread_pool_id = thread_pool_id

        node_id = config.get("id")
        if not node_id:
            raise ValueError("Node ID is required.")

        self.node_id = node_id
        # The cast narrows the runtime model (typed as BaseNodeData) to the
        # subclass's GenericNodeData for type checkers; it has no runtime effect.
        self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))

    # NOTE(review): BaseNode does not use ABCMeta (it derives only from Generic),
    # so this decorator is not enforced at instantiation time. A subclass that
    # omits _run fails only when run() calls it.
    @abstractmethod
    def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
        """Execute the node's logic.

        :return: either a single ``NodeRunResult`` or a generator of events.
        """
        raise NotImplementedError

    def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
        """Run the node and yield its events.

        A ``NodeRunResult`` returned by ``_run`` is wrapped in a single
        ``RunCompletedEvent``. A generator returned by ``_run`` is re-yielded
        as-is. If calling ``_run`` raises, the error is logged with its
        traceback and reported as a FAILED ``NodeRunResult``.
        """
        try:
            result = self._run()
        except Exception as e:
            # logger.exception records the traceback. The lazy %-args keep the
            # message text identical to "Node <id> failed to run: <error>".
            logger.exception("Node %s failed to run: %s", self.node_id, e)
            result = NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                error=str(e),
            )

        if isinstance(result, NodeRunResult):
            yield RunCompletedEvent(run_result=result)
        else:
            # If _run is a generator function, calling it only creates the
            # generator. Exceptions raised while iterating it surface here,
            # outside the try above, and propagate to the caller unconverted.
            yield from result

    @classmethod
    def extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        config: Mapping[str, Any],
    ) -> Mapping[str, Sequence[str]]:
        """Extract the variable-selector-to-variable mapping for a node config.

        Parses ``config["data"]`` with ``_node_data_cls`` and delegates to the
        subclass hook ``_extract_variable_selector_to_variable_mapping``.

        :param graph_config: graph config.
        :param config: node config; must contain a truthy ``"id"``.
        :return: mapping produced by the subclass hook (``{}`` by default).
        :raises ValueError: if ``config`` has no truthy ``"id"``.
        """
        node_id = config.get("id")
        if not node_id:
            raise ValueError("Node ID is required when extracting variable selector to variable mapping.")

        node_data = cls._node_data_cls(**config.get("data", {}))
        return cls._extract_variable_selector_to_variable_mapping(
            graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        *,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: GenericNodeData,
    ) -> Mapping[str, Sequence[str]]:
        """Subclass hook for variable-selector extraction.

        :param graph_config: graph config.
        :param node_id: node id taken from the node config.
        :param node_data: parsed node data.
        :return: an empty mapping in the base class; override to provide one.
        """
        return {}

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """Return the node's default config.

        :param filters: filter by node config parameters (unused in the base class).
        :return: an empty dict in the base class; override to provide defaults.
        """
        return {}

    @property
    def node_type(self) -> NodeType:
        """The node's type, as set by the subclass in ``_node_type``."""
        return self._node_type