
treeo.mutable

A decorator that transforms a stateful function f that receives a Tree instance as its first argument into a function that returns a tuple of the result and a Tree with the new state.

This is useful for 2 reasons:

* It transforms f into a pure function.
* It allows Immutable Trees to perform inline field updates without getting RuntimeErrors.

Note that since the original object is not modified, Immutable instances remain immutable.

Examples:

def accumulate_id(tree: MyTree, x: int) -> int:
    tree.n += x
    return x

tree0 = MyTree(n=4)
y, tree1 = mutable(accumulate_id)(tree0, 1)

assert tree0.n == 4
assert tree1.n == 5
assert y == 1

Note: Any Trees found in the output of f are made immutable.
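The copy-then-mutate behavior in the example above can be sketched without treeo. What follows is a simplified stand-in, not treeo's implementation: `FrozenCounter` and `mutable_sketch` are hypothetical names, and `copy.copy` is assumed here in place of treeo's own tree copy.

```python
import copy
import functools


class FrozenCounter:
    """Toy immutable object (hypothetical): assignment raises unless unlocked."""

    def __init__(self, n: int):
        # Bypass our own __setattr__ guard during construction.
        object.__setattr__(self, "_unlocked", False)
        object.__setattr__(self, "n", n)

    def __setattr__(self, name, value):
        if not self._unlocked:
            raise RuntimeError(f"cannot set '{name}' on an immutable object")
        object.__setattr__(self, name, value)


def mutable_sketch(f):
    """Wrap `f` so it mutates a copy and returns (result, updated_copy)."""

    @functools.wraps(f)
    def wrapper(obj, *args, **kwargs):
        new_obj = copy.copy(obj)  # the caller's object is never touched
        object.__setattr__(new_obj, "_unlocked", True)
        try:
            result = f(new_obj, *args, **kwargs)
        finally:
            # Lock the copy again so it leaves the wrapper immutable.
            object.__setattr__(new_obj, "_unlocked", False)
        return result, new_obj

    return wrapper


def accumulate_id(counter: FrozenCounter, x: int) -> int:
    counter.n += x
    return x


c0 = FrozenCounter(n=4)
y, c1 = mutable_sketch(accumulate_id)(c0, 1)
```

Because the wrapper only ever writes to the copy, calling it twice with the same input gives the same output, which is what makes the wrapped function pure from the caller's point of view.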

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| f | Callable[..., ~A] | The function to be transformed. | required |
| toplevel_only | bool | If True, only the top-level object is made mutable. | False |
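To illustrate what `toplevel_only` is described as doing, here is a toy sketch under stated assumptions. It does not use treeo: `Node` and `unlocked` are hypothetical names, and the assumed semantics (nested objects stay locked when `toplevel_only=True`) are inferred only from the parameter description above.

```python
from contextlib import contextmanager


class Node:
    """Toy immutable node (hypothetical) that may hold a child Node."""

    def __init__(self, value, child=None):
        object.__setattr__(self, "_unlocked", False)
        object.__setattr__(self, "value", value)
        object.__setattr__(self, "child", child)

    def __setattr__(self, name, value):
        if not self._unlocked:
            raise RuntimeError(f"cannot set '{name}' on an immutable Node")
        object.__setattr__(self, name, value)


def _nodes(node, toplevel_only):
    """Yield the nodes to unlock: only the root, or the root and its descendants."""
    while node is not None:
        yield node
        if toplevel_only:
            return
        node = node.child


@contextmanager
def unlocked(node, *, toplevel_only=False):
    """Temporarily allow attribute assignment on the selected nodes."""
    targets = list(_nodes(node, toplevel_only))
    for n in targets:
        object.__setattr__(n, "_unlocked", True)
    try:
        yield node
    finally:
        for n in targets:
            object.__setattr__(n, "_unlocked", False)


root = Node(1, child=Node(2))

with unlocked(root):
    root.child.value = 20  # nested update is allowed

nested_update_blocked = False
with unlocked(root, toplevel_only=True):
    root.value = 10  # top-level update is allowed
    try:
        root.child.value = 99  # the child was not unlocked
    except RuntimeError:
        nested_update_blocked = True
```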

Returns:

| Type | Description |
|---|---|
| Callable[..., Tuple[~A, Any]] | A function that returns a tuple of the result and a Tree with the new state. |

Source code in treeo/api.py
def mutable(
    f: tp.Callable[..., A],
    *,
    toplevel_only: bool = False,
) -> tp.Callable[..., tp.Tuple[A, tp.Any]]:
    """
    A decorator that transforms a stateful function `f` that receives a Tree
    instance as its first argument into a function that returns a tuple of the result and a Tree
    with the new state.

    This is useful for 2 reasons:
    * It transforms `f` into a pure function.
    * It allows `Immutable` Trees to perform inline field updates without getting `RuntimeError`s.

    Note that since the original object is not modified, `Immutable` instances remain immutable.

    Example:

    ```python
    def accumulate_id(tree: MyTree, x: int) -> int:
        tree.n += x
        return x

    tree0 = MyTree(n=4)
    y, tree1 = mutable(accumulate_id)(tree0, 1)

    assert tree0.n == 4
    assert tree1.n == 5
    assert y == 1
    ```

    **Note**: Any `Tree`s found in the output of `f` are made immutable.

    Arguments:
        f: The function to be transformed.
        toplevel_only: If `True`, only the top-level object is made mutable.

    Returns:
        A function that returns a tuple of the result and a Tree with the new state.
    """

    f0 = f

    if inspect.ismethod(f):
        tree0 = f.__self__
        f = f.__func__
    elif isinstance(f, tree_m.Tree) and callable(f):
        tree0 = f
        f = f.__class__.__call__
    else:
        tree0 = None

    if tree0 is not None and not isinstance(tree0, Tree):
        name = f0.__name__ if hasattr(f0, "__name__") else f0.__class__.__name__
        raise TypeError(
            f"Invalid bound method or callable '{name}', tried to infer unbound function and instance, "
            f"expected a 'Tree' instance but got '{type(tree0).__name__}' instead. Try using an unbound class method instead."
        )

    @functools.wraps(f)
    def wrapper(tree, *args, **kwargs) -> tp.Tuple[A, tp.Any]:

        tree = tree_m.copy(tree)

        with tree_m.make_mutable(tree, toplevel_only=toplevel_only):
            output = f(tree, *args, **kwargs)

        def _make_output_immutable(a: Tree):
            tree_m._set_mutable(a, None)

        output = tree_m.apply(_make_output_immutable, output)

        return output, tree

    wrapper._treeo_mutable = True

    if tree0 is not None:

        @functools.wraps(f)
        def obj_wrapper(*args, **kwargs):
            return wrapper(tree0, *args, **kwargs)

        return obj_wrapper

    return wrapper
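The first branch of the source above separates a bound method into its instance and its underlying function. A minimal sketch of that step, using only the standard library, is shown here; `Counter` and `split_bound_method` are hypothetical names introduced for illustration and are not part of treeo.

```python
import inspect


class Counter:
    def __init__(self, n):
        self.n = n

    def add(self, x):
        self.n += x
        return self.n


def split_bound_method(f):
    """Return (instance, plain_function); instance is None if `f` is not bound."""
    if inspect.ismethod(f):
        return f.__self__, f.__func__
    return None, f


counter = Counter(4)
instance, func = split_bound_method(counter.add)
result = func(instance, 1)  # equivalent to counter.add(1)

# A builtin function is not a bound Python method, so it passes through unchanged.
free_instance, free_func = split_bound_method(len)
```

Once the instance is recovered this way, a wrapper can take it as an explicit first argument, which is what lets the same code path serve both plain functions and bound methods.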