variable_factory.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from collections.abc import Mapping, Sequence
  2. from typing import Any
  3. from uuid import uuid4
  4. from configs import dify_config
  5. from core.file import File
  6. from core.variables.exc import VariableError
  7. from core.variables.segments import (
  8. ArrayAnySegment,
  9. ArrayFileSegment,
  10. ArrayNumberSegment,
  11. ArrayObjectSegment,
  12. ArraySegment,
  13. ArrayStringSegment,
  14. FileSegment,
  15. FloatSegment,
  16. IntegerSegment,
  17. NoneSegment,
  18. ObjectSegment,
  19. Segment,
  20. StringSegment,
  21. )
  22. from core.variables.types import SegmentType
  23. from core.variables.variables import (
  24. ArrayAnyVariable,
  25. ArrayFileVariable,
  26. ArrayNumberVariable,
  27. ArrayObjectVariable,
  28. ArrayStringVariable,
  29. FileVariable,
  30. FloatVariable,
  31. IntegerVariable,
  32. NoneVariable,
  33. ObjectVariable,
  34. SecretVariable,
  35. StringVariable,
  36. Variable,
  37. )
  38. class InvalidSelectorError(ValueError):
  39. pass
  40. class UnsupportedSegmentTypeError(Exception):
  41. pass
  42. # Define the constant
  43. SEGMENT_TO_VARIABLE_MAP = {
  44. StringSegment: StringVariable,
  45. IntegerSegment: IntegerVariable,
  46. FloatSegment: FloatVariable,
  47. ObjectSegment: ObjectVariable,
  48. FileSegment: FileVariable,
  49. ArrayStringSegment: ArrayStringVariable,
  50. ArrayNumberSegment: ArrayNumberVariable,
  51. ArrayObjectSegment: ArrayObjectVariable,
  52. ArrayFileSegment: ArrayFileVariable,
  53. ArrayAnySegment: ArrayAnyVariable,
  54. NoneSegment: NoneVariable,
  55. }
  56. def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
  57. if (value_type := mapping.get("value_type")) is None:
  58. raise VariableError("missing value type")
  59. if not mapping.get("name"):
  60. raise VariableError("missing name")
  61. if (value := mapping.get("value")) is None:
  62. raise VariableError("missing value")
  63. match value_type:
  64. case SegmentType.STRING:
  65. result = StringVariable.model_validate(mapping)
  66. case SegmentType.SECRET:
  67. result = SecretVariable.model_validate(mapping)
  68. case SegmentType.NUMBER if isinstance(value, int):
  69. result = IntegerVariable.model_validate(mapping)
  70. case SegmentType.NUMBER if isinstance(value, float):
  71. result = FloatVariable.model_validate(mapping)
  72. case SegmentType.NUMBER if not isinstance(value, float | int):
  73. raise VariableError(f"invalid number value {value}")
  74. case SegmentType.OBJECT if isinstance(value, dict):
  75. result = ObjectVariable.model_validate(mapping)
  76. case SegmentType.ARRAY_STRING if isinstance(value, list):
  77. result = ArrayStringVariable.model_validate(mapping)
  78. case SegmentType.ARRAY_NUMBER if isinstance(value, list):
  79. result = ArrayNumberVariable.model_validate(mapping)
  80. case SegmentType.ARRAY_OBJECT if isinstance(value, list):
  81. result = ArrayObjectVariable.model_validate(mapping)
  82. case _:
  83. raise VariableError(f"not supported value type {value_type}")
  84. if result.size > dify_config.MAX_VARIABLE_SIZE:
  85. raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
  86. return result
  87. def build_segment(value: Any, /) -> Segment:
  88. if value is None:
  89. return NoneSegment()
  90. if isinstance(value, str):
  91. return StringSegment(value=value)
  92. if isinstance(value, int):
  93. return IntegerSegment(value=value)
  94. if isinstance(value, float):
  95. return FloatSegment(value=value)
  96. if isinstance(value, dict):
  97. return ObjectSegment(value=value)
  98. if isinstance(value, File):
  99. return FileSegment(value=value)
  100. if isinstance(value, list):
  101. items = [build_segment(item) for item in value]
  102. types = {item.value_type for item in items}
  103. if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
  104. return ArrayAnySegment(value=value)
  105. match types.pop():
  106. case SegmentType.STRING:
  107. return ArrayStringSegment(value=value)
  108. case SegmentType.NUMBER:
  109. return ArrayNumberSegment(value=value)
  110. case SegmentType.OBJECT:
  111. return ArrayObjectSegment(value=value)
  112. case SegmentType.FILE:
  113. return ArrayFileSegment(value=value)
  114. case SegmentType.NONE:
  115. return ArrayAnySegment(value=value)
  116. case _:
  117. raise ValueError(f"not supported value {value}")
  118. raise ValueError(f"not supported value {value}")
  119. def segment_to_variable(
  120. *,
  121. segment: Segment,
  122. selector: Sequence[str],
  123. id: str | None = None,
  124. name: str | None = None,
  125. description: str = "",
  126. ) -> Variable:
  127. if isinstance(segment, Variable):
  128. return segment
  129. name = name or selector[-1]
  130. id = id or str(uuid4())
  131. segment_type = type(segment)
  132. if segment_type not in SEGMENT_TO_VARIABLE_MAP:
  133. raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
  134. variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
  135. return variable_class(
  136. id=id,
  137. name=name,
  138. description=description,
  139. value=segment.value,
  140. selector=selector,
  141. )