123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- from collections import defaultdict
- from collections.abc import Mapping, Sequence
- from typing import Any, Union
- from typing_extensions import deprecated
- from core.app.segments import Segment, Variable, factory
- from core.file.file_obj import FileVar
- from core.workflow.entities.node_entities import SystemVariable
# Union of raw Python value types that may be stored as a variable value.
VariableValue = Union[str, int, float, dict, list, FileVar]

# Node id under which system variables are registered in the pool.
SYSTEM_VARIABLE_NODE_ID: str = "sys"
# Node id under which environment variables are registered in the pool.
ENVIRONMENT_VARIABLE_NODE_ID: str = "env"
class VariablePool:
    """Store workflow variables and look them up by selector.

    A selector is a sequence of strings. Its first element is a node id and
    the remaining elements identify one variable belonging to that node,
    e.g. ``('sys', 'query')`` or ``('env', 'API_KEY')``.
    """

    def __init__(
        self,
        system_variables: Mapping[SystemVariable, Any],
        user_inputs: Mapping[str, Any],
        environment_variables: Sequence[Variable],
    ) -> None:
        """Create a pool pre-populated with system and environment variables.

        Args:
            system_variables: Mapping from a SystemVariable to its value; each
                entry is stored under the 'sys' node id, keyed by ``key.value``.
            user_inputs: Kept on the instance as ``user_inputs``; not added to
                the pool.
            environment_variables: Variables stored under the 'env' node id,
                keyed by ``var.name``. ``None`` is tolerated and treated as empty.
        """
        # Variable dictionary for looking up variables by their selector.
        # First-level key: the selector's first element (the node id).
        # Second-level key: hash(tuple(selector[1:])), see _selector_key.
        # NOTE(review): keying by the hash rather than by the tuple itself means
        # two distinct selectors whose hashes collide would overwrite each other.
        # The layout is kept unchanged in case other code relies on it.
        self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)

        # TODO: These user inputs are stored but not added to the pool.
        self.user_inputs = user_inputs

        # System variables, e.g. entries whose key values are 'query' and 'files'.
        self.system_variables = system_variables
        for key, value in system_variables.items():
            self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)

        # Environment variables, addressed by name.
        for var in environment_variables or []:
            self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)

    @staticmethod
    def _selector_key(selector: Sequence[str]) -> tuple[str, int]:
        """Validate *selector* and split it into (node id, second-level key).

        Raises:
            ValueError: If the selector has fewer than two elements.
        """
        if len(selector) < 2:
            raise ValueError('Invalid selector')
        return selector[0], hash(tuple(selector[1:]))

    def add(self, selector: Sequence[str], value: Any, /) -> None:
        """Add a variable to the variable pool.

        Args:
            selector: Node id followed by at least one identifying element.
            value: A Segment, stored as-is, or a raw value, which is converted
                with ``factory.build_segment``. ``None`` is ignored and nothing
                is stored.

        Raises:
            ValueError: If the selector has fewer than two elements.
        """
        node_id, key = self._selector_key(selector)
        if value is None:
            return
        segment = value if isinstance(value, Segment) else factory.build_segment(value)
        self._variable_dictionary[node_id][key] = segment

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """Retrieve the Segment stored under *selector*.

        Args:
            selector: Node id followed by at least one identifying element.

        Returns:
            The stored Segment, or None if nothing is stored under the selector.

        Raises:
            ValueError: If the selector has fewer than two elements.
        """
        node_id, key = self._selector_key(selector)
        # Use .get on the outer defaultdict so that a lookup for an unknown
        # node id does not insert an empty entry as a side effect.
        node_variables = self._variable_dictionary.get(node_id)
        if node_variables is None:
            return None
        return node_variables.get(key)

    @deprecated('This method is deprecated, use `get` instead.')
    def get_any(self, selector: Sequence[str], /) -> Any | None:
        """Retrieve the plain Python object stored under *selector*.

        Args:
            selector: Node id followed by at least one identifying element.

        Returns:
            ``segment.to_object()`` for the stored Segment, or None if nothing
            is stored under the selector.

        Raises:
            ValueError: If the selector has fewer than two elements.
        """
        segment = self.get(selector)
        # Compare against None explicitly: a stored segment must not be
        # dropped just because it is falsy.
        return segment.to_object() if segment is not None else None

    def remove(self, selector: Sequence[str], /) -> None:
        """Remove variables from the variable pool.

        Args:
            selector: An empty selector is a no-op. A single node id clears
                every variable of that node. A longer selector removes only
                the matching variable, if present.
        """
        if not selector:
            return
        if len(selector) == 1:
            self._variable_dictionary[selector[0]] = {}
            return
        node_id, key = self._selector_key(selector)
        # Avoid creating an empty entry for a node id that was never added.
        node_variables = self._variable_dictionary.get(node_id)
        if node_variables is not None:
            node_variables.pop(key, None)
|