entities.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import json
  2. from abc import ABC
  3. from enum import StrEnum
  4. from typing import Any, Optional, Union
  5. from pydantic import BaseModel, model_validator
  6. from core.workflow.nodes.base.exc import DefaultValueTypeError
  7. from core.workflow.nodes.enums import ErrorStrategy
  8. class DefaultValueType(StrEnum):
  9. STRING = "string"
  10. NUMBER = "number"
  11. OBJECT = "object"
  12. ARRAY_NUMBER = "array[number]"
  13. ARRAY_STRING = "array[string]"
  14. ARRAY_OBJECT = "array[object]"
  15. ARRAY_FILES = "array[file]"
# Types accepted as a "number": DefaultValue passes this to ``isinstance`` for
# the NUMBER value check and the ARRAY_NUMBER element check. Because ``bool``
# is a subclass of ``int``, booleans also satisfy it.
NumberType = Union[int, float]
  17. class DefaultValue(BaseModel):
  18. value: Any
  19. type: DefaultValueType
  20. key: str
  21. @staticmethod
  22. def _parse_json(value: str) -> Any:
  23. """Unified JSON parsing handler"""
  24. try:
  25. return json.loads(value)
  26. except json.JSONDecodeError:
  27. raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
  28. @staticmethod
  29. def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
  30. """Unified array type validation"""
  31. # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
  32. return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
  33. @staticmethod
  34. def _convert_number(value: str) -> float:
  35. """Unified number conversion handler"""
  36. try:
  37. return float(value)
  38. except ValueError:
  39. raise DefaultValueTypeError(f"Cannot convert to number: {value}")
  40. @model_validator(mode="after")
  41. def validate_value_type(self) -> "DefaultValue":
  42. if self.type is None:
  43. raise DefaultValueTypeError("type field is required")
  44. # Type validation configuration
  45. type_validators = {
  46. DefaultValueType.STRING: {
  47. "type": str,
  48. "converter": lambda x: x,
  49. },
  50. DefaultValueType.NUMBER: {
  51. "type": NumberType,
  52. "converter": self._convert_number,
  53. },
  54. DefaultValueType.OBJECT: {
  55. "type": dict,
  56. "converter": self._parse_json,
  57. },
  58. DefaultValueType.ARRAY_NUMBER: {
  59. "type": list,
  60. "element_type": NumberType,
  61. "converter": self._parse_json,
  62. },
  63. DefaultValueType.ARRAY_STRING: {
  64. "type": list,
  65. "element_type": str,
  66. "converter": self._parse_json,
  67. },
  68. DefaultValueType.ARRAY_OBJECT: {
  69. "type": list,
  70. "element_type": dict,
  71. "converter": self._parse_json,
  72. },
  73. }
  74. validator: dict[str, Any] = type_validators.get(self.type, {})
  75. if not validator:
  76. if self.type == DefaultValueType.ARRAY_FILES:
  77. # Handle files type
  78. return self
  79. raise DefaultValueTypeError(f"Unsupported type: {self.type}")
  80. # Handle string input cases
  81. if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
  82. self.value = validator["converter"](self.value)
  83. # Validate base type
  84. if not isinstance(self.value, validator["type"]):
  85. raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
  86. # Validate array element types
  87. if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
  88. raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
  89. return self
  90. class RetryConfig(BaseModel):
  91. """node retry config"""
  92. max_retries: int = 0 # max retry times
  93. retry_interval: int = 0 # retry interval in milliseconds
  94. retry_enabled: bool = False # whether retry is enabled
  95. @property
  96. def retry_interval_seconds(self) -> float:
  97. return self.retry_interval / 1000
  98. class BaseNodeData(ABC, BaseModel):
  99. title: str
  100. desc: Optional[str] = None
  101. error_strategy: Optional[ErrorStrategy] = None
  102. default_value: Optional[list[DefaultValue]] = None
  103. version: str = "1"
  104. retry_config: RetryConfig = RetryConfig()
  105. @property
  106. def default_value_dict(self):
  107. if self.default_value:
  108. return {item.key: item.value for item in self.default_value}
  109. return {}
  110. class BaseIterationNodeData(BaseNodeData):
  111. start_node_id: Optional[str] = None
  112. class BaseIterationState(BaseModel):
  113. iteration_node_id: str
  114. index: int
  115. inputs: dict
  116. class MetaData(BaseModel):
  117. pass
  118. metadata: MetaData
  119. class BaseLoopNodeData(BaseNodeData):
  120. start_node_id: Optional[str] = None
  121. class BaseLoopState(BaseModel):
  122. loop_node_id: str
  123. index: int
  124. inputs: dict
  125. class MetaData(BaseModel):
  126. pass
  127. metadata: MetaData