variable_pool.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from collections import defaultdict
  2. from collections.abc import Mapping, Sequence
  3. from typing import Any, Union
  4. from typing_extensions import deprecated
  5. from core.app.segments import Segment, Variable, factory
  6. from core.file.file_obj import FileVar
  7. from core.workflow.enums import SystemVariableKey
  # Reserved first-level selector keys (node ids) under which VariablePool
  # registers system, environment and conversation variables respectively.
  SYSTEM_VARIABLE_NODE_ID = "sys"
  ENVIRONMENT_VARIABLE_NODE_ID = "env"
  CONVERSATION_VARIABLE_NODE_ID = "conversation"

  # Raw (non-Segment) values that may be handed to VariablePool.add; they are
  # converted into a Segment before being stored.
  VariableValue = Union[str, int, float, dict, list, FileVar]
  class VariablePool:
      """Store workflow variables as ``Segment`` objects, addressed by selector.

      A selector is a sequence of strings ``[node_id, *path]``: the first element
      is a node id (e.g. ``SYSTEM_VARIABLE_NODE_ID``) and the remaining elements
      identify the variable within that node.
      """

      def __init__(
          self,
          system_variables: Mapping[SystemVariableKey, Any],
          user_inputs: Mapping[str, Any],
          environment_variables: Sequence[Variable],
          conversation_variables: Sequence[Variable] | None = None,
      ) -> None:
          """Create a pool pre-populated with system, environment and conversation variables.

          Args:
              system_variables: Values keyed by ``SystemVariableKey``; each one is
                  stored under ``(SYSTEM_VARIABLE_NODE_ID, key.value)``, e.g. a
                  ``query`` entry of ``'abc'`` or a ``files`` entry of ``[]``.
              user_inputs: Raw user inputs. They are kept on ``self.user_inputs``
                  but are NOT added to the pool.
              environment_variables: Each is stored under
                  ``(ENVIRONMENT_VARIABLE_NODE_ID, var.name)``.
              conversation_variables: Optional; each is stored under
                  ``(CONVERSATION_VARIABLE_NODE_ID, var.name)``.
          """
          # Two-level lookup table used to find variables by selector:
          #   first level  -> node id (selector[0])
          #   second level -> hash of the remaining selector elements
          # The defaultdict lets `add` create a node's table on first write.
          self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)

          # TODO: user inputs are stored as an attribute but are not added to the pool.
          self.user_inputs = user_inputs

          self.system_variables = system_variables
          for key, value in system_variables.items():
              self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)

          for var in environment_variables:
              self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)

          for var in conversation_variables or []:
              self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)

      @staticmethod
      def _selector_key(selector: Sequence[str]) -> tuple[str, int]:
          """Validate *selector* and split it into ``(node_id, hash_key)``.

          Raises:
              ValueError: If the selector has fewer than two elements.
          """
          if len(selector) < 2:
              raise ValueError("Invalid selector")
          # NOTE(review): keying by hash() instead of the tuple itself means two
          # distinct selectors with equal hashes would share a slot. Collisions
          # are very unlikely; the int key type is kept for compatibility.
          return selector[0], hash(tuple(selector[1:]))

      def add(self, selector: Sequence[str], value: Any, /) -> None:
          """Add a variable to the pool, replacing any value already at *selector*.

          Args:
              selector: ``[node_id, *path]``; must contain at least two elements.
              value: A ``Segment`` (stored as-is) or a raw ``VariableValue``
                  (converted with ``factory.build_segment``). ``None`` is
                  ignored: nothing is stored and any existing value is kept.

          Raises:
              ValueError: If the selector has fewer than two elements (checked
                  even when *value* is ``None``).
          """
          node_id, hash_key = self._selector_key(selector)
          if value is None:
              return
          segment = value if isinstance(value, Segment) else factory.build_segment(value)
          self._variable_dictionary[node_id][hash_key] = segment

      def get(self, selector: Sequence[str], /) -> Segment | None:
          """Return the ``Segment`` stored at *selector*, or ``None`` if there is none.

          Args:
              selector: ``[node_id, *path]``; must contain at least two elements.

          Returns:
              The stored ``Segment``, or ``None`` when nothing is stored there.

          Raises:
              ValueError: If the selector has fewer than two elements.
          """
          node_id, hash_key = self._selector_key(selector)
          # Read with .get(): indexing the defaultdict would insert an empty
          # table for every unknown node id that is merely looked up.
          node_variables = self._variable_dictionary.get(node_id)
          if node_variables is None:
              return None
          return node_variables.get(hash_key)

      @deprecated("This method is deprecated, use `get` instead.")
      def get_any(self, selector: Sequence[str], /) -> Any | None:
          """Return the plain-object form of the variable at *selector*.

          Deprecated: use ``get`` and call ``to_object()`` on the result.

          Args:
              selector: ``[node_id, *path]``; must contain at least two elements.

          Returns:
              ``segment.to_object()`` for the stored segment, or ``None`` when
              nothing is stored there.

          Raises:
              ValueError: If the selector has fewer than two elements.
          """
          segment = self.get(selector)
          return segment.to_object() if segment is not None else None

      def remove(self, selector: Sequence[str], /) -> None:
          """Remove variables from the pool.

          - empty selector: no-op;
          - ``[node_id]``: drop every variable stored under that node;
          - ``[node_id, *path]``: drop that single variable, if present.

          Args:
              selector: The selector of the node or variable to remove.
          """
          if not selector:
              return
          if len(selector) == 1:
              self._variable_dictionary[selector[0]] = {}
              return
          node_id, hash_key = self._selector_key(selector)
          # Avoid creating an empty table just to pop from it.
          node_variables = self._variable_dictionary.get(node_id)
          if node_variables is not None:
              node_variables.pop(hash_key, None)