variable_pool.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import re
  2. from collections import defaultdict
  3. from collections.abc import Mapping, Sequence
  4. from typing import Any, Union
  5. from pydantic import BaseModel, Field
  6. from core.file import File, FileAttribute, file_manager
  7. from core.variables import Segment, SegmentGroup, Variable
  8. from core.variables.segments import FileSegment, NoneSegment
  9. from factories import variable_factory
  10. from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
  11. from ..enums import SystemVariableKey
  12. VariableValue = Union[str, int, float, dict, list, File]
  13. VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
  14. class VariablePool(BaseModel):
  15. # Variable dictionary is a dictionary for looking up variables by their selector.
  16. # The first element of the selector is the node id, it's the first-level key in the dictionary.
  17. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
  18. # elements of the selector except the first one.
  19. variable_dictionary: dict[str, dict[int, Segment]] = Field(
  20. description="Variables mapping",
  21. default=defaultdict(dict),
  22. )
  23. # TODO: This user inputs is not used for pool.
  24. user_inputs: Mapping[str, Any] = Field(
  25. description="User inputs",
  26. )
  27. system_variables: Mapping[SystemVariableKey, Any] = Field(
  28. description="System variables",
  29. )
  30. environment_variables: Sequence[Variable] = Field(
  31. description="Environment variables.",
  32. default_factory=list,
  33. )
  34. conversation_variables: Sequence[Variable] = Field(
  35. description="Conversation variables.",
  36. default_factory=list,
  37. )
  38. def __init__(
  39. self,
  40. *,
  41. system_variables: Mapping[SystemVariableKey, Any] | None = None,
  42. user_inputs: Mapping[str, Any] | None = None,
  43. environment_variables: Sequence[Variable] | None = None,
  44. conversation_variables: Sequence[Variable] | None = None,
  45. **kwargs,
  46. ):
  47. environment_variables = environment_variables or []
  48. conversation_variables = conversation_variables or []
  49. user_inputs = user_inputs or {}
  50. system_variables = system_variables or {}
  51. super().__init__(
  52. system_variables=system_variables,
  53. user_inputs=user_inputs,
  54. environment_variables=environment_variables,
  55. conversation_variables=conversation_variables,
  56. **kwargs,
  57. )
  58. for key, value in self.system_variables.items():
  59. self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
  60. # Add environment variables to the variable pool
  61. for var in self.environment_variables:
  62. self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
  63. # Add conversation variables to the variable pool
  64. for var in self.conversation_variables:
  65. self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
  66. def add(self, selector: Sequence[str], value: Any, /) -> None:
  67. """
  68. Adds a variable to the variable pool.
  69. NOTE: You should not add a non-Segment value to the variable pool
  70. even if it is allowed now.
  71. Args:
  72. selector (Sequence[str]): The selector for the variable.
  73. value (VariableValue): The value of the variable.
  74. Raises:
  75. ValueError: If the selector is invalid.
  76. Returns:
  77. None
  78. """
  79. if len(selector) < 2:
  80. raise ValueError("Invalid selector")
  81. if isinstance(value, Variable):
  82. variable = value
  83. if isinstance(value, Segment):
  84. variable = variable_factory.segment_to_variable(segment=value, selector=selector)
  85. else:
  86. segment = variable_factory.build_segment(value)
  87. variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
  88. hash_key = hash(tuple(selector[1:]))
  89. self.variable_dictionary[selector[0]][hash_key] = variable
  90. def get(self, selector: Sequence[str], /) -> Segment | None:
  91. """
  92. Retrieves the value from the variable pool based on the given selector.
  93. Args:
  94. selector (Sequence[str]): The selector used to identify the variable.
  95. Returns:
  96. Any: The value associated with the given selector.
  97. Raises:
  98. ValueError: If the selector is invalid.
  99. """
  100. if len(selector) < 2:
  101. return None
  102. hash_key = hash(tuple(selector[1:]))
  103. value = self.variable_dictionary[selector[0]].get(hash_key)
  104. if value is None:
  105. selector, attr = selector[:-1], selector[-1]
  106. # Python support `attr in FileAttribute` after 3.12
  107. if attr not in {item.value for item in FileAttribute}:
  108. return None
  109. value = self.get(selector)
  110. if not isinstance(value, (FileSegment, NoneSegment)):
  111. return None
  112. if isinstance(value, FileSegment):
  113. attr = FileAttribute(attr)
  114. attr_value = file_manager.get_attr(file=value.value, attr=attr)
  115. return variable_factory.build_segment(attr_value)
  116. return value
  117. return value
  118. def remove(self, selector: Sequence[str], /):
  119. """
  120. Remove variables from the variable pool based on the given selector.
  121. Args:
  122. selector (Sequence[str]): A sequence of strings representing the selector.
  123. Returns:
  124. None
  125. """
  126. if not selector:
  127. return
  128. if len(selector) == 1:
  129. self.variable_dictionary[selector[0]] = {}
  130. return
  131. hash_key = hash(tuple(selector[1:]))
  132. self.variable_dictionary[selector[0]].pop(hash_key, None)
  133. def convert_template(self, template: str, /):
  134. parts = VARIABLE_PATTERN.split(template)
  135. segments = []
  136. for part in filter(lambda x: x, parts):
  137. if "." in part and (variable := self.get(part.split("."))):
  138. segments.append(variable)
  139. else:
  140. segments.append(variable_factory.build_segment(part))
  141. return SegmentGroup(value=segments)
  142. def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
  143. segment = self.get(selector)
  144. if isinstance(segment, FileSegment):
  145. return segment
  146. return None