123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- from collections import defaultdict
- from collections.abc import Mapping, Sequence
- from typing import Any, Union
- from pydantic import BaseModel, Field, model_validator
- from typing_extensions import deprecated
- from core.app.segments import Segment, Variable, factory
- from core.file.file_obj import FileVar
- from core.workflow.enums import SystemVariableKey
# Raw (non-Segment) Python value types a variable is documented to hold;
# referenced by the `value` description in `VariablePool.add`.
VariableValue = Union[str, int, float, dict, list, FileVar]

# Reserved node ids. Each one is used as the first element of a selector for a
# built-in variable namespace populated by `VariablePool.val_model_after`.
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool(BaseModel):
    """Store workflow variables as ``Segment`` objects and look them up by selector.

    A selector is a sequence of strings. Its first element is a node id, and the
    remaining elements identify one variable belonging to that node.
    """

    # Variable dictionary is a dictionary for looking up variables by their selector.
    # The first element of the selector is the node id; it is the first-level key.
    # The remaining elements are packed into a tuple and hashed to form the
    # second-level key.
    # NOTE(review): keys are hash() values, not the tuples themselves, so two
    # distinct selectors with equal hashes would overwrite each other, and str
    # hashes are randomized per process (PYTHONHASHSEED) -- keys are therefore
    # only meaningful inside the process that computed them.
    variable_dictionary: dict[str, dict[int, Segment]] = Field(
        description="Variables mapping",
        # A factory yields a fresh mapping per instance instead of relying on
        # pydantic copying a shared mutable default.
        default_factory=lambda: defaultdict(dict),
    )
    # TODO: user_inputs is stored on the model but not used by the pool itself.
    user_inputs: Mapping[str, Any] = Field(
        description="User inputs",
    )
    system_variables: Mapping[SystemVariableKey, Any] = Field(
        description="System variables",
    )
    environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list)
    conversation_variables: Sequence[Variable] | None = None

    @model_validator(mode="after")
    def val_model_after(self):
        """
        Load system, environment and conversation variables into the pool.

        Runs after field validation.
        - System variables are stored under ``(SYSTEM_VARIABLE_NODE_ID, key.value)``.
        - Environment variables are stored under ``(ENVIRONMENT_VARIABLE_NODE_ID, var.name)``.
        - Conversation variables are stored under ``(CONVERSATION_VARIABLE_NODE_ID, var.name)``.

        :return: the model instance itself, as an "after" model validator must.
        """
        for key, value in self.system_variables.items():
            self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
        for var in self.environment_variables or []:
            self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
        for var in self.conversation_variables or []:
            self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
        return self

    @staticmethod
    def _selector_key(selector: Sequence[str]) -> int:
        """Validate a full selector and return its second-level dictionary key.

        Raises:
            ValueError: If the selector has fewer than two elements.
        """
        if len(selector) < 2:
            raise ValueError("Invalid selector")
        return hash(tuple(selector[1:]))

    def add(self, selector: Sequence[str], value: Any, /) -> None:
        """
        Adds a variable to the variable pool.

        Args:
            selector (Sequence[str]): The selector for the variable; it needs at
                least two elements (node id plus variable path).
            value (VariableValue): The value of the variable. A ``Segment`` is
                stored as-is. Any other value is converted with
                ``factory.build_segment``. ``None`` is ignored and nothing is stored.

        Raises:
            ValueError: If the selector is invalid.

        Returns:
            None
        """
        hash_key = self._selector_key(selector)
        if value is None:
            return
        segment = value if isinstance(value, Segment) else factory.build_segment(value)
        # setdefault works both for the default defaultdict and for a plain dict
        # produced when variable_dictionary is supplied and validated explicitly.
        self.variable_dictionary.setdefault(selector[0], {})[hash_key] = segment

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieves the value from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): The selector used to identify the variable.

        Returns:
            Segment | None: The segment stored for the selector, or None if absent.

        Raises:
            ValueError: If the selector is invalid.
        """
        hash_key = self._selector_key(selector)
        # .get() avoids inserting an empty entry for an unknown node id (which
        # indexing a defaultdict would do) and avoids KeyError on a plain dict.
        return self.variable_dictionary.get(selector[0], {}).get(hash_key)

    @deprecated("This method is deprecated, use `get` instead.")
    def get_any(self, selector: Sequence[str], /) -> Any | None:
        """
        Retrieves the value from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): The selector used to identify the variable.

        Returns:
            Any: The stored segment converted with ``to_object()``, or None if absent.

        Raises:
            ValueError: If the selector is invalid.
        """
        value = self.get(selector)
        # Compare against None explicitly so a segment is never dropped just
        # because it happens to evaluate as falsy.
        return value.to_object() if value is not None else None

    def remove(self, selector: Sequence[str], /):
        """
        Remove variables from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): A sequence of strings representing the selector.
                - An empty selector is a no-op.
                - A one-element selector clears every variable of that node.
                - A longer selector removes a single variable.

        Returns:
            None
        """
        if not selector:
            return
        if len(selector) == 1:
            self.variable_dictionary[selector[0]] = {}
            return
        hash_key = hash(tuple(selector[1:]))
        node_variables = self.variable_dictionary.get(selector[0])
        if node_variables is not None:
            node_variables.pop(hash_key, None)

    def remove_node(self, node_id: str, /):
        """
        Remove all variables associated with a given node id.

        Args:
            node_id (str): The node id to remove. Unknown ids are ignored.

        Returns:
            None
        """
        self.variable_dictionary.pop(node_id, None)
|