from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union

from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated

from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariableKey

VariableValue = Union[str, int, float, dict, list, FileVar]


SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"


class VariablePool(BaseModel):
    # Variable dictionary is a dictionary for looking up variables by their selector.
    # The first element of the selector is the node id, it's the first-level key in the dictionary.
    # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
    # elements of the selector except the first one.
    variable_dictionary: dict[str, dict[int, Segment]] = Field(
        description="Variables mapping", default=defaultdict(dict)
    )

    # TODO: This user inputs is not used for pool.
    user_inputs: Mapping[str, Any] = Field(
        description="User inputs",
    )

    system_variables: Mapping[SystemVariableKey, Any] = Field(
        description="System variables",
    )

    environment_variables: Sequence[Variable] = Field(description="Environment variables.", default_factory=list)

    conversation_variables: Sequence[Variable] | None = None

    @model_validator(mode="after")
    def val_model_after(self):
        """
        Append system variables
        :return:
        """
        # Add system variables to the variable pool
        for key, value in self.system_variables.items():
            self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)

        # Add environment variables to the variable pool
        for var in self.environment_variables or []:
            self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)

        # Add conversation variables to the variable pool
        for var in self.conversation_variables or []:
            self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)

        return self

    def add(self, selector: Sequence[str], value: Any, /) -> None:
        """
        Adds a variable to the variable pool.

        Args:
            selector (Sequence[str]): The selector for the variable.
            value (VariableValue): The value of the variable.

        Raises:
            ValueError: If the selector is invalid.

        Returns:
            None
        """
        if len(selector) < 2:
            raise ValueError("Invalid selector")

        if value is None:
            return

        if isinstance(value, Segment):
            v = value
        else:
            v = factory.build_segment(value)

        hash_key = hash(tuple(selector[1:]))
        self.variable_dictionary[selector[0]][hash_key] = v

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieves the value from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): The selector used to identify the variable.

        Returns:
            Any: The value associated with the given selector.

        Raises:
            ValueError: If the selector is invalid.
        """
        if len(selector) < 2:
            raise ValueError("Invalid selector")
        hash_key = hash(tuple(selector[1:]))
        value = self.variable_dictionary[selector[0]].get(hash_key)

        return value

    @deprecated("This method is deprecated, use `get` instead.")
    def get_any(self, selector: Sequence[str], /) -> Any | None:
        """
        Retrieves the value from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): The selector used to identify the variable.

        Returns:
            Any: The value associated with the given selector.

        Raises:
            ValueError: If the selector is invalid.
        """
        if len(selector) < 2:
            raise ValueError("Invalid selector")
        hash_key = hash(tuple(selector[1:]))
        value = self.variable_dictionary[selector[0]].get(hash_key)
        return value.to_object() if value else None

    def remove(self, selector: Sequence[str], /):
        """
        Remove variables from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): A sequence of strings representing the selector.

        Returns:
            None
        """
        if not selector:
            return
        if len(selector) == 1:
            self.variable_dictionary[selector[0]] = {}
            return
        hash_key = hash(tuple(selector[1:]))
        self.variable_dictionary[selector[0]].pop(hash_key, None)

    def remove_node(self, node_id: str, /):
        """
        Remove all variables associated with a given node id.

        Args:
            node_id (str): The node id to remove.

        Returns:
            None
        """
        self.variable_dictionary.pop(node_id, None)