variable_pool.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from collections import defaultdict
  2. from collections.abc import Mapping, Sequence
  3. from typing import Any, Union
  4. from pydantic import BaseModel, Field, model_validator
  5. from typing_extensions import deprecated
  6. from core.app.segments import Segment, Variable, factory
  7. from core.file.file_obj import FileVar
  8. from core.workflow.enums import SystemVariableKey
  # Type alias for the plain Python values (plus FileVar) a pool variable may hold.
  VariableValue = Union[
      str,
      int,
      float,
      dict,
      list,
      FileVar,
  ]

  # Reserved first-level selector keys (pseudo node ids) used by VariablePool
  # for the built-in variable groups.
  SYSTEM_VARIABLE_NODE_ID = "sys"
  ENVIRONMENT_VARIABLE_NODE_ID = "env"
  CONVERSATION_VARIABLE_NODE_ID = "conversation"
  class VariablePool(BaseModel):
      """
      Pool of workflow variables addressed by selector.

      A selector is a sequence of strings. Its first element is a node id (or one
      of the reserved ids SYSTEM_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID,
      CONVERSATION_VARIABLE_NODE_ID); the remaining elements name a variable
      within that node. Values are stored as `Segment` objects.
      """

      # Two-level lookup table for variables:
      #   first-level key  -> selector[0] (the node id)
      #   second-level key -> hash(tuple(selector[1:]))
      # NOTE(review): keying by hash() means two distinct selectors whose tuples
      # collide would overwrite each other, and str hashes are randomized per
      # process (PYTHONHASHSEED), so these int keys are not meaningful across
      # processes -- confirm this field is never persisted and restored.
      # default_factory makes each instance get its own container explicitly,
      # instead of relying on pydantic copying a shared mutable default.
      variable_dictionary: dict[str, dict[int, Segment]] = Field(
          description="Variables mapping",
          default_factory=lambda: defaultdict(dict),
      )
      # TODO: user_inputs is stored on the model but never read by the pool itself.
      user_inputs: Mapping[str, Any] = Field(
          description="User inputs",
      )
      system_variables: Mapping[SystemVariableKey, Any] = Field(
          description="System variables",
      )
      environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list)
      conversation_variables: Sequence[Variable] | None = None

      @model_validator(mode="after")
      def val_model_after(self):
          """
          Seed the pool from the model's fields once validation has finished.

          System variables are added under SYSTEM_VARIABLE_NODE_ID, keyed by the
          enum member's `.value`. Environment and conversation variables are added
          under ENVIRONMENT_VARIABLE_NODE_ID and CONVERSATION_VARIABLE_NODE_ID
          respectively, keyed by the variable's `name`. Entries whose value is
          None are skipped (see `add`).

          Returns:
              VariablePool: self, as required for an "after" model validator.
          """
          for key, value in self.system_variables.items():
              self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
          for var in self.environment_variables or []:
              self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
          for var in self.conversation_variables or []:
              self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
          return self

      def add(self, selector: Sequence[str], value: Any, /) -> None:
          """
          Store a value in the pool under the given selector.

          Args:
              selector (Sequence[str]): Node id followed by at least one key element.
              value (Any): A `Segment` (stored as-is) or a plain value, which is
                  converted with `factory.build_segment`. None is silently ignored.

          Raises:
              ValueError: If the selector has fewer than two elements.

          Returns:
              None
          """
          if len(selector) < 2:
              raise ValueError("Invalid selector")
          if value is None:
              return
          segment = value if isinstance(value, Segment) else factory.build_segment(value)
          hash_key = hash(tuple(selector[1:]))
          # setdefault instead of relying on defaultdict: if variable_dictionary was
          # passed in explicitly it may have been validated into a plain dict.
          self.variable_dictionary.setdefault(selector[0], {})[hash_key] = segment

      def get(self, selector: Sequence[str], /) -> Segment | None:
          """
          Look up the segment stored under the given selector.

          Args:
              selector (Sequence[str]): Node id followed by at least one key element.

          Returns:
              Segment | None: The stored segment, or None if nothing is stored
              under this selector.

          Raises:
              ValueError: If the selector has fewer than two elements.
          """
          if len(selector) < 2:
              raise ValueError("Invalid selector")
          hash_key = hash(tuple(selector[1:]))
          # .get() rather than [] so a miss does not insert an empty node entry
          # into the (default)dict as a side effect of a read.
          node_variables = self.variable_dictionary.get(selector[0])
          if node_variables is None:
              return None
          return node_variables.get(hash_key)

      @deprecated("This method is deprecated, use `get` instead.")
      def get_any(self, selector: Sequence[str], /) -> Any | None:
          """
          Look up a variable and return it as a plain Python object.

          Args:
              selector (Sequence[str]): Node id followed by at least one key element.

          Returns:
              Any | None: `segment.to_object()` for the stored segment, or None if
              nothing is stored under this selector.

          Raises:
              ValueError: If the selector has fewer than two elements.
          """
          segment = self.get(selector)
          # Explicit None check: a stored segment must never be mistaken for a
          # miss just because it evaluates as falsy.
          return segment.to_object() if segment is not None else None

      def remove(self, selector: Sequence[str], /):
          """
          Remove variables from the pool.

          An empty selector is a no-op. A one-element selector clears every
          variable of that node (leaving an empty entry for it). A longer selector
          removes the single matching variable if present.

          Args:
              selector (Sequence[str]): Node id, optionally followed by key elements.

          Returns:
              None
          """
          if not selector:
              return
          node_id = selector[0]
          if len(selector) == 1:
              self.variable_dictionary[node_id] = {}
              return
          hash_key = hash(tuple(selector[1:]))
          # Avoid creating an entry for an unknown node just to pop from it.
          node_variables = self.variable_dictionary.get(node_id)
          if node_variables is not None:
              node_variables.pop(hash_key, None)

      def remove_node(self, node_id: str, /):
          """
          Remove all variables associated with a given node id.

          Args:
              node_id (str): The node id to remove. Unknown ids are ignored.

          Returns:
              None
          """
          self.variable_dictionary.pop(node_id, None)