entities.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from enum import StrEnum
  2. from typing import Any, Optional
  3. from pydantic import Field
  4. from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
  5. class ErrorHandleMode(StrEnum):
  6. TERMINATED = "terminated"
  7. CONTINUE_ON_ERROR = "continue-on-error"
  8. REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
  9. class IterationNodeData(BaseIterationNodeData):
  10. """
  11. Iteration Node Data.
  12. """
  13. parent_loop_id: Optional[str] = None # redundant field, not used currently
  14. iterator_selector: list[str] # variable selector
  15. output_selector: list[str] # output selector
  16. is_parallel: bool = False # open the parallel mode or not
  17. parallel_nums: int = 10 # the numbers of parallel
  18. error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
  19. class IterationStartNodeData(BaseNodeData):
  20. """
  21. Iteration Start Node Data.
  22. """
  23. pass
  24. class IterationState(BaseIterationState):
  25. """
  26. Iteration State.
  27. """
  28. outputs: list[Any] = Field(default_factory=list)
  29. current_output: Optional[Any] = None
  30. class MetaData(BaseIterationState.MetaData):
  31. """
  32. Data.
  33. """
  34. iterator_length: int
  35. def get_last_output(self) -> Optional[Any]:
  36. """
  37. Get last output.
  38. """
  39. if self.outputs:
  40. return self.outputs[-1]
  41. return None
  42. def get_current_output(self) -> Optional[Any]:
  43. """
  44. Get current output.
  45. """
  46. return self.current_output