# entities.py
import json
from abc import ABC
from enum import StrEnum
from typing import Any, Optional, Union

from pydantic import BaseModel, model_validator

from core.workflow.nodes.base.exc import DefaultValueTypeError
from core.workflow.nodes.enums import ErrorStrategy
  8. class DefaultValueType(StrEnum):
  9. STRING = "string"
  10. NUMBER = "number"
  11. OBJECT = "object"
  12. ARRAY_NUMBER = "array[number]"
  13. ARRAY_STRING = "array[string]"
  14. ARRAY_OBJECT = "array[object]"
  15. ARRAY_FILES = "array[file]"
  16. NumberType = Union[int, float]
  17. class DefaultValue(BaseModel):
  18. value: Any
  19. type: DefaultValueType
  20. key: str
  21. @staticmethod
  22. def _parse_json(value: str) -> Any:
  23. """Unified JSON parsing handler"""
  24. try:
  25. return json.loads(value)
  26. except json.JSONDecodeError:
  27. raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
  28. @staticmethod
  29. def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
  30. """Unified array type validation"""
  31. return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
  32. @staticmethod
  33. def _convert_number(value: str) -> float:
  34. """Unified number conversion handler"""
  35. try:
  36. return float(value)
  37. except ValueError:
  38. raise DefaultValueTypeError(f"Cannot convert to number: {value}")
  39. @model_validator(mode="after")
  40. def validate_value_type(self) -> "DefaultValue":
  41. if self.type is None:
  42. raise DefaultValueTypeError("type field is required")
  43. # Type validation configuration
  44. type_validators = {
  45. DefaultValueType.STRING: {
  46. "type": str,
  47. "converter": lambda x: x,
  48. },
  49. DefaultValueType.NUMBER: {
  50. "type": NumberType,
  51. "converter": self._convert_number,
  52. },
  53. DefaultValueType.OBJECT: {
  54. "type": dict,
  55. "converter": self._parse_json,
  56. },
  57. DefaultValueType.ARRAY_NUMBER: {
  58. "type": list,
  59. "element_type": NumberType,
  60. "converter": self._parse_json,
  61. },
  62. DefaultValueType.ARRAY_STRING: {
  63. "type": list,
  64. "element_type": str,
  65. "converter": self._parse_json,
  66. },
  67. DefaultValueType.ARRAY_OBJECT: {
  68. "type": list,
  69. "element_type": dict,
  70. "converter": self._parse_json,
  71. },
  72. }
  73. validator = type_validators.get(self.type)
  74. if not validator:
  75. if self.type == DefaultValueType.ARRAY_FILES:
  76. # Handle files type
  77. return self
  78. raise DefaultValueTypeError(f"Unsupported type: {self.type}")
  79. # Handle string input cases
  80. if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
  81. self.value = validator["converter"](self.value)
  82. # Validate base type
  83. if not isinstance(self.value, validator["type"]):
  84. raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
  85. # Validate array element types
  86. if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
  87. raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
  88. return self
  89. class BaseNodeData(ABC, BaseModel):
  90. title: str
  91. desc: Optional[str] = None
  92. error_strategy: Optional[ErrorStrategy] = None
  93. default_value: Optional[list[DefaultValue]] = None
  94. version: str = "1"
  95. @property
  96. def default_value_dict(self):
  97. if self.default_value:
  98. return {item.key: item.value for item in self.default_value}
  99. return {}
  100. class BaseIterationNodeData(BaseNodeData):
  101. start_node_id: Optional[str] = None
  102. class BaseIterationState(BaseModel):
  103. iteration_node_id: str
  104. index: int
  105. inputs: dict
  106. class MetaData(BaseModel):
  107. pass
  108. metadata: MetaData