variable_pool.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from collections import defaultdict
  2. from collections.abc import Mapping, Sequence
  3. from typing import Any, Union
  4. from typing_extensions import deprecated
  5. from core.app.segments import Segment, Variable, factory
  6. from core.file.file_obj import FileVar
  7. from core.workflow.entities.node_entities import SystemVariable
  # First selector element (node id) under which VariablePool stores system variables.
  SYSTEM_VARIABLE_NODE_ID = 'sys'
  # First selector element (node id) under which VariablePool stores environment variables.
  ENVIRONMENT_VARIABLE_NODE_ID = 'env'

  # Union of the plain value types a pool variable may be given
  # (referenced by the `value` parameter docs of VariablePool.add).
  VariableValue = Union[
      str,
      int,
      float,
      dict,
      list,
      FileVar,
  ]
  class VariablePool:
      """
      Store workflow variables as segments, addressed by selector.

      A selector is a sequence of strings. Its first element is the node id and
      serves as the first-level key; the remaining elements, taken together as a
      tuple, serve as the second-level key.
      """

      def __init__(
          self,
          system_variables: Mapping[SystemVariable, Any],
          user_inputs: Mapping[str, Any],
          environment_variables: Sequence[Variable],
      ) -> None:
          """
          Build the pool and preload system and environment variables.

          Args:
              system_variables: Mapping of SystemVariable members to values. Each
                  value is stored under (SYSTEM_VARIABLE_NODE_ID, key.value); for
                  example, values such as 'abc' for 'query' or [] for 'files'.
              user_inputs: Stored on the instance as-is; NOT added to the pool.
              environment_variables: Each variable is stored under
                  (ENVIRONMENT_VARIABLE_NODE_ID, var.name). A falsy value (e.g.
                  None) is treated as empty.
          """
          # Variable dictionary for looking up variables by their selector.
          # First-level key: the node id (selector[0]).
          # Second-level key: the remaining selector elements as a tuple. The tuple
          # itself is used rather than its hash() so that two different selectors
          # can never alias the same slot on a hash collision.
          self._variable_dictionary: dict[str, dict[tuple[str, ...], Segment]] = defaultdict(dict)

          # TODO: These user inputs are not used by the pool.
          self.user_inputs = user_inputs

          # Add system variables to the variable pool.
          self.system_variables = system_variables
          for key, value in system_variables.items():
              self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)

          # Add environment variables to the variable pool.
          for var in environment_variables or []:
              self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)

      @staticmethod
      def _split_selector(selector: Sequence[str]) -> tuple[str, tuple[str, ...]]:
          """
          Split a selector into (node_id, key) for the two-level dictionary.

          Raises:
              ValueError: If the selector has fewer than two elements.
          """
          if len(selector) < 2:
              raise ValueError('Invalid selector')
          return selector[0], tuple(selector[1:])

      def add(self, selector: Sequence[str], value: Any, /) -> None:
          """
          Adds a variable to the variable pool.

          Args:
              selector (Sequence[str]): The selector for the variable; must have at
                  least two elements.
              value (VariableValue): The value of the variable. A Segment is stored
                  as-is; any other value is converted with factory.build_segment.
                  None is ignored and nothing is stored.

          Raises:
              ValueError: If the selector is invalid.

          Returns:
              None
          """
          node_id, key = self._split_selector(selector)
          if value is None:
              return
          segment = value if isinstance(value, Segment) else factory.build_segment(value)
          self._variable_dictionary[node_id][key] = segment

      def get(self, selector: Sequence[str], /) -> Segment | None:
          """
          Retrieves the segment stored under the given selector.

          Args:
              selector (Sequence[str]): The selector used to identify the variable.

          Returns:
              Segment | None: The stored segment, or None if nothing is stored
              under the selector.

          Raises:
              ValueError: If the selector is invalid.
          """
          node_id, key = self._split_selector(selector)
          # Use .get on the outer mapping: indexing the defaultdict would insert an
          # empty entry for every unknown node id that is merely looked up.
          node_variables = self._variable_dictionary.get(node_id)
          if node_variables is None:
              return None
          return node_variables.get(key)

      @deprecated('This method is deprecated, use `get` instead.')
      def get_any(self, selector: Sequence[str], /) -> Any | None:
          """
          Retrieves the plain-object value stored under the given selector.

          Args:
              selector (Sequence[str]): The selector used to identify the variable.

          Returns:
              Any: The stored segment converted with `to_object()`, or None if
              nothing is stored under the selector.

          Raises:
              ValueError: If the selector is invalid.
          """
          segment = self.get(selector)
          # Compare against None explicitly instead of relying on the segment's
          # truthiness, so a stored segment is always converted.
          return segment.to_object() if segment is not None else None

      def remove(self, selector: Sequence[str], /) -> None:
          """
          Remove variables from the variable pool based on the given selector.

          Args:
              selector (Sequence[str]): A sequence of strings representing the
                  selector. An empty selector is a no-op; a one-element selector
                  clears every variable stored under that node id; a longer
                  selector removes the single matching variable, if present.

          Returns:
              None
          """
          if not selector:
              return
          if len(selector) == 1:
              self._variable_dictionary[selector[0]] = {}
              return
          node_id, key = self._split_selector(selector)
          # Avoid creating an empty entry for an unknown node id.
          node_variables = self._variable_dictionary.get(node_id)
          if node_variables is not None:
              node_variables.pop(key, None)