base_node.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from abc import ABC, abstractmethod
  2. from collections.abc import Mapping, Sequence
  3. from enum import Enum
  4. from typing import Any, Optional
  5. from core.app.entities.app_invoke_entities import InvokeFrom
  6. from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
  7. from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
  8. from core.workflow.entities.node_entities import NodeRunResult, NodeType
  9. from core.workflow.entities.variable_pool import VariablePool
  10. class UserFrom(Enum):
  11. """
  12. User from
  13. """
  14. ACCOUNT = "account"
  15. END_USER = "end-user"
  16. @classmethod
  17. def value_of(cls, value: str) -> "UserFrom":
  18. """
  19. Value of
  20. :param value: value
  21. :return:
  22. """
  23. for item in cls:
  24. if item.value == value:
  25. return item
  26. raise ValueError(f"Invalid value: {value}")
  27. class BaseNode(ABC):
  28. _node_data_cls: type[BaseNodeData]
  29. _node_type: NodeType
  30. tenant_id: str
  31. app_id: str
  32. workflow_id: str
  33. user_id: str
  34. user_from: UserFrom
  35. invoke_from: InvokeFrom
  36. workflow_call_depth: int
  37. node_id: str
  38. node_data: BaseNodeData
  39. node_run_result: Optional[NodeRunResult] = None
  40. callbacks: Sequence[WorkflowCallback]
  41. def __init__(self, tenant_id: str,
  42. app_id: str,
  43. workflow_id: str,
  44. user_id: str,
  45. user_from: UserFrom,
  46. invoke_from: InvokeFrom,
  47. config: Mapping[str, Any],
  48. callbacks: Sequence[WorkflowCallback] | None = None,
  49. workflow_call_depth: int = 0) -> None:
  50. self.tenant_id = tenant_id
  51. self.app_id = app_id
  52. self.workflow_id = workflow_id
  53. self.user_id = user_id
  54. self.user_from = user_from
  55. self.invoke_from = invoke_from
  56. self.workflow_call_depth = workflow_call_depth
  57. # TODO: May need to check if key exists.
  58. self.node_id = config["id"]
  59. if not self.node_id:
  60. raise ValueError("Node ID is required.")
  61. self.node_data = self._node_data_cls(**config.get("data", {}))
  62. self.callbacks = callbacks or []
  63. @abstractmethod
  64. def _run(self, variable_pool: VariablePool) -> NodeRunResult:
  65. """
  66. Run node
  67. :param variable_pool: variable pool
  68. :return:
  69. """
  70. raise NotImplementedError
  71. def run(self, variable_pool: VariablePool) -> NodeRunResult:
  72. """
  73. Run node entry
  74. :param variable_pool: variable pool
  75. :return:
  76. """
  77. result = self._run(
  78. variable_pool=variable_pool
  79. )
  80. self.node_run_result = result
  81. return result
  82. def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
  83. """
  84. Publish text chunk
  85. :param text: chunk text
  86. :param value_selector: value selector
  87. :return:
  88. """
  89. if self.callbacks:
  90. for callback in self.callbacks:
  91. callback.on_node_text_chunk(
  92. node_id=self.node_id,
  93. text=text,
  94. metadata={
  95. "node_type": self.node_type,
  96. "value_selector": value_selector
  97. }
  98. )
  99. @classmethod
  100. def extract_variable_selector_to_variable_mapping(cls, config: dict):
  101. """
  102. Extract variable selector to variable mapping
  103. :param config: node config
  104. :return:
  105. """
  106. node_data = cls._node_data_cls(**config.get("data", {}))
  107. return cls._extract_variable_selector_to_variable_mapping(node_data)
  108. @classmethod
  109. def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
  110. """
  111. Extract variable selector to variable mapping
  112. :param node_data: node data
  113. :return:
  114. """
  115. return {}
  116. @classmethod
  117. def get_default_config(cls, filters: Optional[dict] = None) -> dict:
  118. """
  119. Get default config of node.
  120. :param filters: filter by node config parameters.
  121. :return:
  122. """
  123. return {}
  124. @property
  125. def node_type(self) -> NodeType:
  126. """
  127. Get node type
  128. :return:
  129. """
  130. return self._node_type
  131. class BaseIterationNode(BaseNode):
  132. @abstractmethod
  133. def _run(self, variable_pool: VariablePool) -> BaseIterationState:
  134. """
  135. Run node
  136. :param variable_pool: variable pool
  137. :return:
  138. """
  139. raise NotImplementedError
  140. def run(self, variable_pool: VariablePool) -> BaseIterationState:
  141. """
  142. Run node entry
  143. :param variable_pool: variable pool
  144. :return:
  145. """
  146. return self._run(variable_pool=variable_pool)
  147. def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
  148. """
  149. Get next iteration start node id based on the graph.
  150. :param graph: graph
  151. :return: next node id
  152. """
  153. return self._get_next_iteration(variable_pool, state)
  154. @abstractmethod
  155. def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
  156. """
  157. Get next iteration start node id based on the graph.
  158. :param graph: graph
  159. :return: next node id
  160. """
  161. raise NotImplementedError