variable_pool.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import re
  2. from collections import defaultdict
  3. from collections.abc import Mapping, Sequence
  4. from typing import Any, Union
  5. from pydantic import BaseModel, Field
  6. from typing_extensions import deprecated
  7. from core.file import File, FileAttribute, file_manager
  8. from core.variables import Segment, SegmentGroup, Variable
  9. from core.variables.segments import FileSegment
  10. from factories import variable_factory
  11. from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
  12. from ..enums import SystemVariableKey
  13. VariableValue = Union[str, int, float, dict, list, File]
  14. VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
  15. class VariablePool(BaseModel):
  16. # Variable dictionary is a dictionary for looking up variables by their selector.
  17. # The first element of the selector is the node id, it's the first-level key in the dictionary.
  18. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
  19. # elements of the selector except the first one.
  20. variable_dictionary: dict[str, dict[int, Segment]] = Field(
  21. description="Variables mapping",
  22. default=defaultdict(dict),
  23. )
  24. # TODO: This user inputs is not used for pool.
  25. user_inputs: Mapping[str, Any] = Field(
  26. description="User inputs",
  27. )
  28. system_variables: Mapping[SystemVariableKey, Any] = Field(
  29. description="System variables",
  30. )
  31. environment_variables: Sequence[Variable] = Field(
  32. description="Environment variables.",
  33. default_factory=list,
  34. )
  35. conversation_variables: Sequence[Variable] = Field(
  36. description="Conversation variables.",
  37. default_factory=list,
  38. )
  39. def __init__(
  40. self,
  41. *,
  42. system_variables: Mapping[SystemVariableKey, Any] | None = None,
  43. user_inputs: Mapping[str, Any] | None = None,
  44. environment_variables: Sequence[Variable] | None = None,
  45. conversation_variables: Sequence[Variable] | None = None,
  46. **kwargs,
  47. ):
  48. environment_variables = environment_variables or []
  49. conversation_variables = conversation_variables or []
  50. user_inputs = user_inputs or {}
  51. system_variables = system_variables or {}
  52. super().__init__(
  53. system_variables=system_variables,
  54. user_inputs=user_inputs,
  55. environment_variables=environment_variables,
  56. conversation_variables=conversation_variables,
  57. **kwargs,
  58. )
  59. for key, value in self.system_variables.items():
  60. self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
  61. # Add environment variables to the variable pool
  62. for var in self.environment_variables:
  63. self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
  64. # Add conversation variables to the variable pool
  65. for var in self.conversation_variables:
  66. self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
  67. def add(self, selector: Sequence[str], value: Any, /) -> None:
  68. """
  69. Adds a variable to the variable pool.
  70. NOTE: You should not add a non-Segment value to the variable pool
  71. even if it is allowed now.
  72. Args:
  73. selector (Sequence[str]): The selector for the variable.
  74. value (VariableValue): The value of the variable.
  75. Raises:
  76. ValueError: If the selector is invalid.
  77. Returns:
  78. None
  79. """
  80. if len(selector) < 2:
  81. raise ValueError("Invalid selector")
  82. if isinstance(value, Segment):
  83. v = value
  84. else:
  85. v = variable_factory.build_segment(value)
  86. hash_key = hash(tuple(selector[1:]))
  87. self.variable_dictionary[selector[0]][hash_key] = v
  88. def get(self, selector: Sequence[str], /) -> Segment | None:
  89. """
  90. Retrieves the value from the variable pool based on the given selector.
  91. Args:
  92. selector (Sequence[str]): The selector used to identify the variable.
  93. Returns:
  94. Any: The value associated with the given selector.
  95. Raises:
  96. ValueError: If the selector is invalid.
  97. """
  98. if len(selector) < 2:
  99. return None
  100. hash_key = hash(tuple(selector[1:]))
  101. value = self.variable_dictionary[selector[0]].get(hash_key)
  102. if value is None:
  103. selector, attr = selector[:-1], selector[-1]
  104. value = self.get(selector)
  105. if isinstance(value, FileSegment):
  106. attr = FileAttribute(attr)
  107. attr_value = file_manager.get_attr(file=value.value, attr=attr)
  108. return variable_factory.build_segment(attr_value)
  109. return value
  110. @deprecated("This method is deprecated, use `get` instead.")
  111. def get_any(self, selector: Sequence[str], /) -> Any | None:
  112. """
  113. Retrieves the value from the variable pool based on the given selector.
  114. Args:
  115. selector (Sequence[str]): The selector used to identify the variable.
  116. Returns:
  117. Any: The value associated with the given selector.
  118. Raises:
  119. ValueError: If the selector is invalid.
  120. """
  121. if len(selector) < 2:
  122. raise ValueError("Invalid selector")
  123. hash_key = hash(tuple(selector[1:]))
  124. value = self.variable_dictionary[selector[0]].get(hash_key)
  125. return value.to_object() if value else None
  126. def remove(self, selector: Sequence[str], /):
  127. """
  128. Remove variables from the variable pool based on the given selector.
  129. Args:
  130. selector (Sequence[str]): A sequence of strings representing the selector.
  131. Returns:
  132. None
  133. """
  134. if not selector:
  135. return
  136. if len(selector) == 1:
  137. self.variable_dictionary[selector[0]] = {}
  138. return
  139. hash_key = hash(tuple(selector[1:]))
  140. self.variable_dictionary[selector[0]].pop(hash_key, None)
  141. def convert_template(self, template: str, /):
  142. parts = VARIABLE_PATTERN.split(template)
  143. segments = []
  144. for part in filter(lambda x: x, parts):
  145. if "." in part and (variable := self.get(part.split("."))):
  146. segments.append(variable)
  147. else:
  148. segments.append(variable_factory.build_segment(part))
  149. return SegmentGroup(value=segments)
  150. def get_file(self, selector: Sequence[str], /) -> FileSegment | None:
  151. segment = self.get(selector)
  152. if isinstance(segment, FileSegment):
  153. return segment
  154. return None