
treeo.apply

Applies a function to all `to.Tree`s in a pytree. It works much like `jax.tree_map`, except that `f` receives `to.Tree` objects instead of leaves and is expected to modify each Tree in place; its return value is ignored. Nested Trees are visited before the Trees that contain them.
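The traversal order can be illustrated without JAX or treeo. In the dependency-free sketch below, `Node` and `apply_nodes` are hypothetical stand-ins for `to.Tree` and `apply`, and plain lists, tuples and dicts play the role of the surrounding pytree. As in `apply`, nested nodes are visited before the nodes that contain them and `f` works purely by mutation; the sketch omits copying, `*rest` and mutability handling.

```python
from typing import Any, Callable, List


class Node:
    """Hypothetical stand-in for a to.Tree: a name plus child pytrees."""

    def __init__(self, name: str, children: Any = ()) -> None:
        self.name = name
        self.children = children
        self.visited = False


def apply_nodes(f: Callable[[Node], None], obj: Any) -> None:
    """Hypothetical helper: call f on every Node in obj, innermost first.

    Nodes are mutated in place; nothing is copied in this sketch.
    """
    if isinstance(obj, Node):
        apply_nodes(f, obj.children)  # visit nested Nodes first
        f(obj)  # then the Node itself; the return value is ignored
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            apply_nodes(f, item)
    elif isinstance(obj, dict):
        # JAX flattens dicts in sorted-key order; mimic that here.
        for key in sorted(obj):
            apply_nodes(f, obj[key])
    # Any other value is a plain leaf and is skipped.


order: List[str] = []


def mark(node: Node) -> None:
    node.visited = True  # in-place change, like f in treeo.apply
    order.append(node.name)


model = {"encoder": Node("enc", [Node("layer1"), Node("layer2")]), "head": Node("head")}
apply_nodes(mark, model)
print(order)  # ['layer1', 'layer2', 'enc', 'head']
```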

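Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., None]` | The function to apply to each Tree; it should modify the Tree in place. | required |
| `obj` | `~A` | A pytree possibly containing Trees. | required |
| `*rest` | `~A` | Additional pytrees with the same structure as `obj`; the nodes at matching positions are passed to `f` as extra arguments. | `()` |
| `inplace` | `bool` | If `True`, the Trees in `obj` are mutated directly; if `False`, `obj` is copied first and only the copy is modified. | `False` |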
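Extra pytrees in `*rest` are traversed in lockstep with `obj`, so `f` is called with a Tree from `obj` followed by the nodes at the same position in each extra pytree. The sketch below illustrates that pairing in plain Python; `Node` and `apply_pairs` are hypothetical names, and the sketch assumes all pytrees share exactly the same nesting of lists and nodes (treeo relies on `jax.tree_map` for the structure matching instead).

```python
from typing import Any, Callable, List, Optional


class Node:
    """Hypothetical stand-in for a to.Tree with one numeric field and child nodes."""

    def __init__(self, value: float, children: Optional[List["Node"]] = None) -> None:
        self.value = value
        self.children = children if children is not None else []


def apply_pairs(f: Callable[..., None], obj: Any, *rest: Any) -> None:
    """Hypothetical helper: call f(node, *matching_nodes) on each Node, innermost first."""
    if isinstance(obj, Node):
        # Recurse into the children of obj and of every matching node.
        apply_pairs(f, obj.children, *(r.children for r in rest))
        f(obj, *rest)
    elif isinstance(obj, list):
        # zip visits the same position in every pytree at once.
        for items in zip(obj, *rest):
            apply_pairs(f, *items)


def add_from(node: Node, other: Node) -> None:
    node.value += other.value  # only the node from obj is mutated


params = [Node(1.0, [Node(2.0)]), Node(3.0)]
updates = [Node(0.5, [Node(0.25)]), Node(1.0)]
apply_pairs(add_from, params, updates)
print([params[0].value, params[0].children[0].value, params[1].value])  # [1.5, 2.25, 4.0]
```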

Returns:

| Type | Description |
|---|---|
| `~A` | A new pytree with the updated Trees, or the same input `obj` if `inplace` is `True`. |
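The two return modes can be contrasted with a small sketch. `Node` and `apply_copy` are hypothetical, the pytree is reduced to a flat list, and `copy.deepcopy` stands in for the `copy` helper that the real function calls on `obj` when `inplace` is `False`.

```python
import copy
from typing import Callable, List


class Node:
    """Hypothetical stand-in for a to.Tree with a single counter."""

    def __init__(self, count: int = 0) -> None:
        self.count = count


def apply_copy(f: Callable[[Node], None], obj: List[Node], inplace: bool = False) -> List[Node]:
    """Hypothetical helper: apply f to each Node, copying the list first unless inplace."""
    if not inplace:
        obj = copy.deepcopy(obj)  # the caller's Nodes stay untouched
    for node in obj:
        f(node)
    return obj  # the copy, or the caller's own list when inplace=True


def bump(node: Node) -> None:
    node.count += 1


nodes = [Node(), Node()]
new_nodes = apply_copy(bump, nodes)
print([n.count for n in nodes], [n.count for n in new_nodes])  # [0, 0] [1, 1]

same = apply_copy(bump, nodes, inplace=True)
print(same is nodes, [n.count for n in nodes])  # True [1, 1]
```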

Source code in treeo/tree.py
```python
def apply(
    f: tp.Callable[..., None],
    obj: A,
    *rest: A,
    inplace: bool = False,
    _top_inplace: tp.Optional[bool] = None,
    _top_level: bool = True,
) -> A:
    """
    Applies a function to all `to.Tree`s in a Pytree. Works much like `jax.tree_map`,
    but its values are `to.Tree`s instead of leaves, and `f` should apply its changes to each Tree object in place.

    Arguments:
        f: The function to apply.
        obj: A pytree possibly containing Trees.
        *rest: Additional pytrees.
        inplace: If `True`, the input `obj` is mutated.

    Returns:
        A new pytree with the updated Trees, or the same input `obj` if `inplace` is `True`.
    """
    if _top_inplace is None:
        _top_inplace = inplace

    # Only the outermost call copies its inputs; recursive calls work on those copies.
    if _top_level:
        rest = copy(rest)
        if not inplace:
            obj = copy(obj)

    objs = (obj,) + rest

    def nested_fn(obj, *rest):
        if isinstance(obj, Tree):
            apply(
                f,
                obj,
                *rest,
                inplace=True,
                _top_inplace=_top_inplace,
                _top_level=False,
            )

    # Nested Trees are treated as leaves so nested_fn can recurse into them;
    # the objects in `objs` themselves are not, so their children get traversed.
    jax.tree_map(
        nested_fn,
        *objs,
        is_leaf=lambda x: isinstance(x, Tree) and x not in objs,
    )

    # Apply f to obj last; immutable Trees are wrapped in _make_mutable_toplevel while f runs.
    if isinstance(obj, Tree):
        if _top_inplace or obj._mutable:
            f(obj, *rest)
        else:
            with _make_mutable_toplevel(obj):
                f(obj, *rest)

    return obj
```
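When `inplace` is `False` and a Tree is not already mutable (`obj._mutable`), `apply` runs `f` inside `_make_mutable_toplevel(obj)`, whose implementation is not shown on this page. The sketch below shows one plausible way such a guard can work and is not treeo's code: a hypothetical `Frozen` class rejects attribute writes unless its `_mutable` flag is set, and a hypothetical `make_mutable` context manager sets the flag for the duration of the block and restores it afterwards, even if an exception is raised.

```python
from contextlib import contextmanager
from typing import Iterator


class Frozen:
    """Hypothetical object that rejects attribute writes unless _mutable is set."""

    def __init__(self, value: int) -> None:
        # Bypass the guard while initialising.
        object.__setattr__(self, "_mutable", False)
        object.__setattr__(self, "value", value)

    def __setattr__(self, name: str, val: object) -> None:
        if not self._mutable:
            raise AttributeError(f"cannot set {name!r} on an immutable object")
        object.__setattr__(self, name, val)


@contextmanager
def make_mutable(obj: Frozen) -> Iterator[Frozen]:
    """Hypothetical guard: allow writes on obj, then restore the previous flag."""
    previous = obj._mutable
    object.__setattr__(obj, "_mutable", True)
    try:
        yield obj
    finally:
        object.__setattr__(obj, "_mutable", previous)


def double(obj: Frozen) -> None:
    obj.value *= 2  # an in-place change, like the f passed to apply


x = Frozen(3)
with make_mutable(x):
    double(x)  # allowed inside the context
print(x.value)  # 6

try:
    double(x)  # rejected again outside the context
except AttributeError as err:
    print(err)  # cannot set 'value' on an immutable object
```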