treex.apply
Applies a function to all `to.Tree`s in a Pytree. It works much like `jax.tree_map`, but the values it visits are `to.Tree`s instead of leaves, and `f` should apply its changes in place to the Tree object it receives.
If `inplace` is `False`, a copy of the first object is returned with the changes applied. The `rest` of the objects are always copied.
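The copy-versus-mutate behavior described above can be illustrated with a small standalone sketch. It does not use treeo or JAX: `Node`, `apply_sketch`, `_visit`, and `double` are hypothetical stand-ins invented for this illustration, and `copy.deepcopy` is only assumed to play the role of the library's own copy step.

```python
import copy


class Node:
    """Hypothetical stand-in for a to.Tree: a mutable object that may hold child nodes."""

    def __init__(self, value, children=()):
        self.value = value
        self.children = list(children)


def _visit(f, obj):
    # Walk dicts, lists, and tuples; call f on every Node, nested nodes first.
    if isinstance(obj, Node):
        for child in obj.children:
            _visit(f, child)
        f(obj)  # f returns None and mutates the node in place
    elif isinstance(obj, dict):
        for item in obj.values():
            _visit(f, item)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            _visit(f, item)


def apply_sketch(f, obj, inplace=False):
    # When inplace is False, work on a copy so the caller's object is untouched.
    if not inplace:
        obj = copy.deepcopy(obj)
    _visit(f, obj)
    return obj


def double(node):
    node.value *= 2


tree = {"a": Node(1, [Node(2)]), "b": [Node(3)]}

new_tree = apply_sketch(double, tree)                  # returns a modified copy
same_tree = apply_sketch(double, tree, inplace=True)   # mutates and returns the input
```

In this sketch `new_tree` is a distinct object from `tree`, while `same_tree` is the very object that was passed in.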
Parameters:
Name | Type | Description | Default |
---|---|---|---|
f | Callable[..., NoneType] | The function to apply. | required |
obj | ~A | A pytree possibly containing Trees. | required |
*rest | ~A | Additional pytrees. | () |
inplace | bool | If `True`, the input `obj` is mutated. | False |
Returns:
Type | Description |
---|---|
~A | A new pytree with the updated Trees, or the same input `obj` if `inplace` is `True`. |
Source code in treeo/api.py
def apply(f: tp.Callable[..., None], obj: A, *rest: A, inplace: bool = False) -> A:
    """
    Applies a function to all `to.Tree`s in a Pytree. Works very similar to `jax.tree_map`,
    but its values are `to.Tree`s instead of leaves, also `f` should apply the changes inplace to Tree object.

    If `inplace` is `False`, a copy of the first object is returned with the changes applied.
    The `rest` of the objects are always copied.

    Arguments:
        f: The function to apply.
        obj: a pytree possibly containing Trees.
        *rest: additional pytrees.
        inplace: If `True`, the input `obj` is mutated.

    Returns:
        A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
    """
    rest = tree_m.copy(rest)

    if not inplace:
        obj = tree_m.copy(obj)

    objs = (obj,) + rest

    def nested_fn(obj, *rest):
        if isinstance(obj, Tree):
            apply(f, obj, *rest, inplace=True)

    jax.tree_map(
        nested_fn,
        *objs,
        is_leaf=lambda x: isinstance(x, Tree) and x not in objs,
    )

    if isinstance(obj, Tree):
        f(obj, *rest)

    return obj
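
The source above passes the matching entries of the extra pytrees to `f` alongside each Tree, and always copies those extra pytrees first. A rough standalone sketch of that pairing, for a single extra pytree, might look like the following. It is not treeo code: `Param`, `apply_pair_sketch`, `_visit_pair`, and `add_from` are hypothetical names, and `copy.deepcopy` is assumed as a stand-in for the library's copy step.

```python
import copy


class Param:
    """Hypothetical stand-in for a to.Tree holding a single number."""

    def __init__(self, value):
        self.value = value


def _visit_pair(f, obj, other):
    # Walk both structures in lockstep and call f on each matching pair of Params.
    if isinstance(obj, Param):
        f(obj, other)
    elif isinstance(obj, dict):
        for key in obj:
            _visit_pair(f, obj[key], other[key])
    elif isinstance(obj, (list, tuple)):
        for a, b in zip(obj, other):
            _visit_pair(f, a, b)


def apply_pair_sketch(f, obj, other, inplace=False):
    other = copy.deepcopy(other)  # the extra pytree is always copied
    if not inplace:
        obj = copy.deepcopy(obj)  # the first pytree is copied unless inplace is True
    _visit_pair(f, obj, other)
    return obj


def add_from(target, source):
    target.value += source.value
    source.value = 0.0  # only touches the copy, so the caller's object is unaffected


params = {"w": Param(1.0), "b": [Param(2.0)]}
updates = {"w": Param(0.5), "b": [Param(-1.0)]}

result = apply_pair_sketch(add_from, params, updates)
```

Because `updates` is copied before traversal, a function that mutates its second argument cannot leak changes back to the caller, which is the practical consequence of the extra pytrees always being copied.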