treex.apply

Applies a function to every to.Tree in a pytree. It works much like jax.tree_map, except that f receives to.Tree objects rather than leaves and is expected to modify each Tree in place instead of returning a value.

If inplace is False, obj is copied first and the modified copy is returned, so the original is left untouched. The pytrees in rest are always copied.

Parameters:

Name | Type | Description | Default
f | Callable[..., NoneType] | The function to apply. | required
obj | ~A | A pytree possibly containing Trees. | required
*rest | ~A | Additional pytrees. | ()
inplace | bool | If True, the input obj is mutated. | False

Returns:

Type | Description
~A | A new pytree with the updated Trees, or the input obj itself if inplace is True.
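The difference between the two return modes can be illustrated with a small pure-Python model. This is only a sketch under stated assumptions: `Tree`, `_visit`, `apply_sketch` and `increment` are hypothetical stand-ins written for this example, not the treeo classes or functions, and the model handles a single pytree with no `rest` arguments.

```python
import copy


class Tree:
    """Hypothetical stand-in for `to.Tree`; not the treeo class."""

    def __init__(self, **fields):
        self.__dict__.update(fields)


def _visit(f, node):
    # Call f on nested Trees first, then on the Tree that contains them.
    if isinstance(node, Tree):
        for value in list(vars(node).values()):
            _visit(f, value)
        f(node)
    elif isinstance(node, (list, tuple)):
        for item in node:
            _visit(f, item)
    elif isinstance(node, dict):
        for item in node.values():
            _visit(f, item)


def apply_sketch(f, obj, inplace=False):
    # Simplified model: a single pytree, deep-copied unless inplace=True.
    if not inplace:
        obj = copy.deepcopy(obj)
    _visit(f, obj)
    return obj


def increment(tree):
    # f mutates the Tree it receives and returns None.
    tree.n += 1


root = Tree(n=0, child=Tree(n=10))

updated = apply_sketch(increment, root)
# root is untouched; updated is a modified copy.
print(root.n, root.child.n)        # 0 10
print(updated.n, updated.child.n)  # 1 11

same = apply_sketch(increment, root, inplace=True)
# root itself was mutated and returned.
print(same is root, root.n, root.child.n)  # True 1 11
```

With the default `inplace=False` the caller keeps the original object and gets back a separate, modified one; with `inplace=True` the returned object is the input itself.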

Source code in treeo/api.py
def apply(f: tp.Callable[..., None], obj: A, *rest: A, inplace: bool = False) -> A:
    """
    Applies a function to every `to.Tree` in a pytree. It works much like `jax.tree_map`,
    except that `f` receives `to.Tree` objects rather than leaves and is expected to modify each Tree in place.

    If `inplace` is `False`, `obj` is copied first and the modified copy is returned.
    The pytrees in `rest` are always copied.

    Arguments:
        f: The function to apply.
        obj: a pytree possibly containing Trees.
        *rest: additional pytrees.
        inplace: If `True`, the input `obj` is mutated.

    Returns:
        A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
    """
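    # The extra pytrees are always copied, so f can never mutate the caller's objects.
    rest = tree_m.copy(rest)

    if not inplace:
        obj = tree_m.copy(obj)

    objs = (obj,) + rest

    def nested_fn(obj, *rest):
        # Recurse into nested Trees; ordinary leaves are ignored.
        if isinstance(obj, Tree):
            apply(f, obj, *rest, inplace=True)

    # Nested Trees are treated as leaves so nested_fn receives them whole.
    # The top-level objects are excluded so that tree_map descends into them.
    jax.tree_map(
        nested_fn,
        *objs,
        is_leaf=lambda x: isinstance(x, Tree) and x not in objs,
    )

    # Nested Trees have been processed by now, so f sees them already updated.
    if isinstance(obj, Tree):
        f(obj, *rest)

    return obj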
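The listing implies two properties that are easy to miss: nested Trees are visited before the Tree that contains them, and `f` receives the matching Tree from each extra pytree. The following pure-Python sketch models both. It is an illustration under assumptions, not the library code: `Tree`, `apply_sketch` and `add_from` are hypothetical names, and the model assumes every Tree keeps its nested Trees in a `children` list.

```python
import copy


class Tree:
    """Hypothetical stand-in for `to.Tree`; not the treeo class."""

    def __init__(self, name, value, children=()):
        self.name = name
        self.value = value
        self.children = list(children)


def apply_sketch(f, obj, *rest, inplace=False):
    # Same shape as the listing: rest is always copied, obj only if not inplace.
    rest = copy.deepcopy(rest)
    if not inplace:
        obj = copy.deepcopy(obj)
    if isinstance(obj, Tree):
        # Recurse into matching children first, mutating the copies in place.
        for group in zip(obj.children, *(r.children for r in rest)):
            apply_sketch(f, *group, inplace=True)
        # Only then call f on the containing Tree.
        f(obj, *rest)
    return obj


visited = []


def add_from(target, source):
    # Record the visit order and accumulate the matching value.
    visited.append(target.name)
    target.value += source.value


a = Tree("root", 1, [Tree("left", 2), Tree("right", 3)])
b = Tree("root", 10, [Tree("left", 20), Tree("right", 30)])

result = apply_sketch(add_from, a, b)
print(visited)                                            # ['left', 'right', 'root']
print(result.value, [c.value for c in result.children])  # 11 [22, 33]
print(a.value, b.value)                                   # 1 10
```

Because children are processed first, a function applied to a parent can rely on its nested Trees having been updated already.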