treeo.apply
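Applies a function to all `to.Tree`s in a Pytree. It works very similarly to `jax.tree_map`, but the values it visits are `to.Tree`s instead of leaves. `f` receives each Tree (plus the matching values from `rest`) and should apply its changes in place to that Tree object; its return value is ignored.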
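The visiting order follows from the source listing: nested Trees are processed first (through the recursive call inside `jax.tree_map`), and only then is `f` called on the top-level Tree. The sketch below reproduces that children-before-parent, in-place pattern in plain Python. It assumes a hypothetical `Node` class standing in for `to.Tree` and a hypothetical `apply_sketch` helper that only walks lists, tuples and dicts. The real function delegates traversal to `jax.tree_map`, so its sibling order follows JAX's flattening order rather than insertion order.

```python
from typing import Any, Callable, List


class Node:
    """Hypothetical stand-in for a treeo Tree: a name plus nested children."""

    def __init__(self, name: str, children: Any = None) -> None:
        self.name = name
        self.children = children if children is not None else []
        self.visited = False


def apply_sketch(f: Callable[[Node], None], obj: Any) -> Any:
    """Call f on every Node in obj, children before parents, mutating in place."""
    if isinstance(obj, (list, tuple)):
        for item in obj:
            apply_sketch(f, item)
    elif isinstance(obj, dict):
        for value in obj.values():
            apply_sketch(f, value)
    elif isinstance(obj, Node):
        # Recurse into the Node's own contents first...
        apply_sketch(f, obj.children)
        # ...then let f mutate the Node itself; its return value is ignored.
        f(obj)
    # Any other value is a plain leaf and is left untouched.
    return obj


order: List[str] = []


def mark(node: Node) -> None:
    node.visited = True
    order.append(node.name)


tree = {"model": Node("outer", [Node("inner_a"), Node("inner_b")]), "lr": 0.1}
apply_sketch(mark, tree)
print(order)  # ['inner_a', 'inner_b', 'outer']
```

Plain leaves such as `0.1` are never passed to `f`; only Tree-like nodes are.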
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[..., NoneType]` | The function to apply. | required |
| `obj` | `~A` | A pytree possibly containing Trees. | required |
| `*rest` | `~A` | Additional pytrees. | `()` |
| `inplace` | `bool` | If `True`, the input `obj` is mutated. | `False` |
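As with `jax.tree_map`, the additional pytrees in `*rest` are walked in lockstep with `obj`, and `f` is called as `f(tree, *matching_values)` (the source calls `f(obj, *rest)`). A minimal sketch of that pairing, assuming a hypothetical `Param` class and a hypothetical `apply_parallel` helper that only handles lists; here `f` copies values from a second tree into the first in place.

```python
from typing import Any, Callable


class Param:
    """Hypothetical stand-in for a Tree holding a single value."""

    def __init__(self, value: float) -> None:
        self.value = value


def apply_parallel(f: Callable[..., None], obj: Any, *rest: Any) -> Any:
    """Walk obj and rest in lockstep, calling f(node, *matching_nodes) on each Param."""
    if isinstance(obj, list):
        for items in zip(obj, *rest):
            apply_parallel(f, *items)
    elif isinstance(obj, Param):
        f(obj, *rest)
    return obj


def copy_from(dst: Param, src: Param) -> None:
    # Mutate the Param from the first tree using the matching Param from the second.
    dst.value = src.value


target = [Param(0.0), Param(0.0)]
source = [Param(1.5), Param(2.5)]
apply_parallel(copy_from, target, source)
print([p.value for p in target])  # [1.5, 2.5]
```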
Returns:

| Type | Description |
|---|---|
| `~A` | A new pytree with the updated Trees, or the same input `obj` if `inplace` is `True`. |
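Whether the caller's object changes depends on `inplace`: the source calls `copy(obj)` when `inplace` is `False` and returns whichever object `f` was applied to. The sketch below shows only that top-level behaviour, assuming a hypothetical `Counter` class and a hypothetical `apply_top_level` helper; note that `copy.copy` makes a shallow copy, so in this sketch only the top-level object is duplicated.

```python
from copy import copy
from typing import Any, Callable


class Counter:
    """Hypothetical stand-in for a Tree with one mutable field."""

    def __init__(self) -> None:
        self.count = 0


def apply_top_level(f: Callable[[Any], None], obj: Any, inplace: bool = False) -> Any:
    # Mirror the source: work on copy(obj) unless inplace=True.
    if not inplace:
        obj = copy(obj)
    f(obj)
    return obj


def bump(c: Counter) -> None:
    c.count += 1


original = Counter()
result = apply_top_level(bump, original)  # inplace=False
print(original.count, result.count, result is original)  # 0 1 False

result = apply_top_level(bump, original, inplace=True)
print(original.count, result is original)  # 1 True
```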
Source code in treeo/tree.py
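```python
import typing as tp
from copy import copy

import jax

# `A` (a type variable), `Tree` and `_make_mutable_toplevel` are defined elsewhere in treeo.


def apply(
    f: tp.Callable[..., None],
    obj: A,
    *rest: A,
    inplace: bool = False,
    _top_inplace: tp.Optional[bool] = None,
    _top_level: bool = True,
) -> A:
    """
    Applies a function to all `to.Tree`s in a Pytree. Works very similarly to `jax.tree_map`,
    but its values are `to.Tree`s instead of leaves; also, `f` should apply the changes in place to the Tree object.

    Arguments:
        f: The function to apply.
        obj: a pytree possibly containing Trees.
        *rest: additional pytrees.
        inplace: If `True`, the input `obj` is mutated.

    Returns:
        A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
    """
    # Remember whether the outermost call was in place; recursive calls pass it down.
    if _top_inplace is None:
        _top_inplace = inplace

    if _top_level:
        # copy.copy returns a tuple unchanged, so `rest` still holds the caller's pytrees.
        rest = copy(rest)
        if not inplace:
            # Not in place: work on copy(obj) instead of the caller's object.
            obj = copy(obj)

    objs = (obj,) + rest

    def nested_fn(obj, *rest):
        # Recurse into each nested Tree; nested calls always mutate in place.
        if isinstance(obj, Tree):
            apply(
                f,
                obj,
                *rest,
                inplace=True,
                _top_inplace=_top_inplace,
                _top_level=False,
            )

    # Treat nested Trees as leaves so nested_fn receives them whole; the top-level
    # objects themselves are excluded so that their contents are traversed.
    jax.tree_util.tree_map(
        nested_fn,
        *objs,
        is_leaf=lambda x: isinstance(x, Tree) and x not in objs,
    )

    # Nested Trees are done; now apply f to the top-level Tree itself.
    if isinstance(obj, Tree):
        if _top_inplace or obj._mutable:
            f(obj, *rest)
        else:
            with _make_mutable_toplevel(obj):
                f(obj, *rest)

    return obj
```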
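The last branch calls `f` directly when the outermost call used `inplace=True` or the Tree is already mutable (`obj._mutable`); otherwise it wraps the call in `_make_mutable_toplevel(obj)`, whose implementation is not part of this listing. The sketch below illustrates the general pattern of a context manager that temporarily unlocks a frozen object and restores its previous state afterwards. `FrozenNode` and `make_mutable` are hypothetical names; this is not treeo's implementation.

```python
from contextlib import contextmanager
from typing import Iterator


class FrozenNode:
    """Hypothetical object that rejects attribute writes unless _mutable is True."""

    def __init__(self, value: int) -> None:
        object.__setattr__(self, "_mutable", False)
        object.__setattr__(self, "value", value)

    def __setattr__(self, name: str, val: object) -> None:
        if not self._mutable:
            raise AttributeError(f"cannot set {name!r} on an immutable node")
        object.__setattr__(self, name, val)


@contextmanager
def make_mutable(node: FrozenNode) -> Iterator[FrozenNode]:
    """Hypothetical analogue of _make_mutable_toplevel: unlock, then restore."""
    previous = node._mutable
    object.__setattr__(node, "_mutable", True)
    try:
        yield node
    finally:
        object.__setattr__(node, "_mutable", previous)


def double(node: FrozenNode) -> None:
    node.value *= 2


node = FrozenNode(3)
try:
    double(node)  # fails: the node is frozen
except AttributeError as err:
    print(err)  # cannot set 'value' on an immutable node

with make_mutable(node):
    double(node)  # succeeds inside the context
print(node.value, node._mutable)  # 6 False
```

Restoring the flag in a `finally` clause keeps the object frozen again even if `f` raises.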