treeo.toplevel_mutable

A decorator that transforms a stateful function f that receives a Tree instance as its first argument into a mutable function. It differs from mutable in the following ways:

  • It always applies mutability to the top-level object only.
  • f is expected to return the new state either as the only output or as the last element of a tuple.
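The return convention in the second bullet can be made concrete with two small helpers. They are hypothetical and not part of treeo's API. They mirror the `*ys, last = output` unpacking and repacking that the wrapper performs in the source code on this page.

```python
from typing import Any, Tuple


def split_state(output: Any) -> Tuple[tuple, Any]:
    """Hypothetical helper: separate extra outputs from the new state.

    The new state is the whole output, or its last element when the
    output is a tuple. An empty tuple has no last element and raises
    ValueError during unpacking.
    """
    if isinstance(output, tuple):
        *ys, last = output
        return tuple(ys), last
    return (), output


def join_state(ys: tuple, last: Any, was_tuple: bool) -> Any:
    """Hypothetical helper: rebuild the output in its original shape."""
    return (*ys, last) if was_tuple else last


# f returns only the new state:
print(split_state("state"))  # ((), 'state')
# f returns (loss, metrics, new_state):
print(split_state((0.5, {"acc": 1.0}, "state")))  # ((0.5, {'acc': 1.0}), 'state')
```

Keeping the state last lets `f` return any number of auxiliary values, such as a loss or metrics, without changing how the new state is located.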

Examples:

```python
from dataclasses import dataclass

import treeo as to


@dataclass
class Child(to.Tree, to.Immutable):
    n: int = to.node()


@dataclass
class Parent(to.Tree, to.Immutable):
    child: Child

    @to.toplevel_mutable
    def update(self) -> "Parent":
        # self is mutable inside this method...
        # ...but child is still immutable, so it is updated with replace
        self.child = self.child.replace(n=self.child.n + 1)

        return self


tree = Parent(child=Child(n=4))
tree = tree.update()
```

This behaviour is useful when the top-level tree mostly manipulates sub-trees that have well-defined immutable APIs. It removes the need to call replace on the top-level tree just to propagate sub-tree updates, which makes the top-level tree easier to manage.
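For comparison, the sketch below shows the same update when the parent is also kept strictly immutable. It is an analogy, not treeo code: it uses standard-library frozen dataclasses and `dataclasses.replace` as stand-ins for treeo's `Immutable` mixin and `Tree.replace`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Child:
    n: int


@dataclass(frozen=True)
class Parent:
    child: Child

    def update(self) -> "Parent":
        # Without top-level mutability both levels are rebuilt explicitly:
        # first the child, then a new parent wrapping the new child.
        new_child = replace(self.child, n=self.child.n + 1)
        return replace(self, child=new_child)


tree = Parent(child=Child(n=4))
tree = tree.update()
print(tree.child.n)  # 5
```

With `toplevel_mutable`, the outer `replace(self, ...)` call becomes a plain assignment to `self.child`. Sub-trees are still updated through their own immutable API.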

Note: Any Trees found in the output of f are set to immutable. However, if the last element of the output has the same type as the input tree, it is given the same mutability status as the input tree.
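The status rules in this note can be modelled with a toy object that carries a `_mutable` flag. Everything here is hypothetical: `ToyTree` stands in for a treeo `Tree`, and `toy_toplevel_mutable` only mimics the observable rules. It is not treeo's implementation, which delegates to `mutable(f, toplevel_only=True)`. The toy also only inspects the top-level elements of the output.

```python
import copy
import functools
from typing import Any, Callable


class ToyTree:
    """Hypothetical stand-in for a Tree that carries a mutability flag."""

    def __init__(self, **fields: Any) -> None:
        object.__setattr__(self, "_mutable", False)
        for name, value in fields.items():
            object.__setattr__(self, name, value)

    def __setattr__(self, name: str, value: Any) -> None:
        # Plain attribute assignment is only allowed while the flag is set.
        if not self._mutable:
            raise AttributeError(f"cannot set '{name}' on an immutable tree")
        object.__setattr__(self, name, value)


def toy_toplevel_mutable(f: Callable) -> Callable:
    """Hypothetical decorator mimicking the status rules in the note above."""

    @functools.wraps(f)
    def wrapper(tree: ToyTree, *args: Any, **kwargs: Any) -> Any:
        work = copy.copy(tree)  # shallow copy: the caller's object is untouched
        object.__setattr__(work, "_mutable", True)  # only the top level is unlocked
        output = f(work, *args, **kwargs)

        items = output if isinstance(output, tuple) else (output,)
        for item in items:  # toy: only top-level output elements are checked
            if isinstance(item, ToyTree):
                object.__setattr__(item, "_mutable", False)

        last = items[-1]
        if type(last) is type(tree):
            # The new state inherits the input tree's mutability status.
            object.__setattr__(last, "_mutable", tree._mutable)
        return output

    return wrapper


class Counter(ToyTree):
    pass


@toy_toplevel_mutable
def bump(counter: Counter) -> Counter:
    counter.n = counter.n + 1  # allowed: the top level is mutable here
    return counter


c = Counter(n=1)
c2 = bump(c)
print(c.n, c2.n, c2._mutable)  # 1 2 False
```

An immutable input therefore yields an immutable result. An input that was already mutable, for example one created inside a mutable context, yields a result that stays mutable.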

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| f | ~C | The function to be transformed. | required |

Returns:

| Type | Description |
|------|-------------|
| ~C | A function with top-level mutability. |

Source code in treeo/api.py
def toplevel_mutable(f: C) -> C:
    """
    A decorator that transforms a stateful function `f` that receives a Tree
    instance as its first argument into a mutable function. It differs from `mutable`
    in the following ways:

    * It always applies mutability to the top-level object only.
    * `f` is expected to return the new state either as the only output or
        as the last element of a tuple.

    Example:

    ```python
    @dataclass
    class Child(to.Tree, to.Immutable):
        n: int = to.node()

    @dataclass
    class Parent(to.Tree, to.Immutable):
        child: Child

        @to.toplevel_mutable
        def update(self) -> "Parent":
            # self is currently mutable
            self.child = self.child.replace(n=self.child.n + 1) # but child is immutable (so we use replace)

            return self

    tree = Parent(child=Child(n=4))
    tree = tree.update()
    ```

    This behaviour is useful when the top-level tree mostly manipulates sub-trees that have well-defined
    immutable APIs. It removes the need to call `replace` on the top-level tree just to propagate
    sub-tree updates, which makes the top-level tree easier to manage.

    **Note**: Any `Tree`s that are found in the output of `f` are set to being
    immutable. However, if the last element of the output has the same type as the
    input tree, it is given the same mutability status as the input tree.

    Arguments:
        f: The function to be transformed.

    Returns:
        A function with top-level mutability.
    """

    f0 = f

    if inspect.ismethod(f):
        tree0 = f.__self__
        f = f.__func__
    elif isinstance(f, tree_m.Tree) and callable(f):
        tree0 = f
        f = f.__class__.__call__
    else:
        tree0 = None

    if tree0 is not None and not isinstance(tree0, tree_m.Tree):
        # Use the callable's own name when it has one, else its class name.
        name = f0.__name__ if hasattr(f0, "__name__") else f0.__class__.__name__
        raise TypeError(
            f"Invalid bound method or callable '{name}', tried to infer unbound function and instance, "
            f"expected a 'Tree' instance but got '{type(tree0).__name__}' instead. Try using an unbound class method instead."
        )

    @functools.wraps(f)
    def wrapper(tree: tree_m.Tree, *args, **kwargs):
        if not isinstance(tree, tree_m.Tree):
            raise TypeError(f"Expected 'Tree' type, got '{type(tree).__name__}'")

        output, _ = mutable(f, toplevel_only=True)(tree, *args, **kwargs)

        if isinstance(output, tuple):
            *ys, last = output
        else:
            ys = ()
            last = output

        if type(last) is type(tree):
            tree_m._set_mutable(last, tree._mutable)

        if isinstance(output, tuple):
            return (*ys, last)
        else:
            return last

    wrapper._treeo_mutable = True

    if tree0 is not None:

        @functools.wraps(f)
        def obj_wrapper(*args, **kwargs):
            return wrapper(tree0, *args, **kwargs)

        return obj_wrapper

    return wrapper
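The first branch of the source (`inspect.ismethod`, `__self__`, `__func__`) and the final `obj_wrapper` let the decorator accept an already-bound method. The stand-alone sketch below isolates that unbind-and-rebind pattern with a hypothetical `shout` decorator, independent of treeo. It does not reproduce the extra branch for callable `Tree` instances that the source handles through `__call__`.

```python
import functools
import inspect
from typing import Any, Callable


def shout(f: Callable) -> Callable:
    """Hypothetical decorator that accepts a plain function or a bound method."""
    bound_to = None
    if inspect.ismethod(f):
        # Split a bound method into its instance and the underlying function.
        bound_to, f = f.__self__, f.__func__

    @functools.wraps(f)
    def wrapper(obj: Any, *args: Any, **kwargs: Any) -> str:
        # The instance is always passed explicitly as the first argument.
        return str(f(obj, *args, **kwargs)).upper()

    if bound_to is not None:
        # Re-bind: callers of the returned function no longer pass the instance.
        @functools.wraps(f)
        def obj_wrapper(*args: Any, **kwargs: Any) -> str:
            return wrapper(bound_to, *args, **kwargs)

        return obj_wrapper
    return wrapper


class Greeter:
    def __init__(self, name: str) -> None:
        self.name = name

    def greet(self, punctuation: str) -> str:
        return f"hello {self.name}{punctuation}"


g = Greeter("tree")
print(shout(Greeter.greet)(g, "!"))  # HELLO TREE!
print(shout(g.greet)("?"))           # HELLO TREE?
```

Unbinding first means a single inner `wrapper` serves both call styles. The bound case is just a thin closure that supplies the captured instance.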