treex.map
Applies a function to all leaves in a pytree using `jax.tree_map`. If `filters` are given, the function is applied only to the subset of leaves that match the filters. For more information see [map's user guide](https://cgarciae.github.io/treeo/user-guide/api/map).
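When filters are given, `map` runs three steps: `filter` the tree, `jax.tree_map` over what is left, then `merge` the result back into the original. The sketch below models those steps in pure Python over nested dicts, lists and tuples, so it runs without JAX. It is an illustration under assumptions, not treeo's implementation: `_NOTHING`, `tree_filter`, `tree_merge` and `filtered_map` are hypothetical names, filters are plain predicates on leaf values (treeo's filters match on types or `FieldInfo` instead), and multiple filters are combined with logical AND.

```python
from typing import Any, Callable

# Hypothetical placeholder for a leaf removed by filtering
# (it stands in for whatever marker the real library uses).
_NOTHING = object()


def tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
    """Apply f to every leaf of a nested dict/list/tuple, leaving _NOTHING placeholders alone."""
    if isinstance(tree, dict):
        return {k: tree_map(f, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(f, v) for v in tree)
    return tree if tree is _NOTHING else f(tree)


def tree_filter(tree: Any, predicate: Callable[[Any], bool]) -> Any:
    """Keep leaves that satisfy predicate; replace every other leaf with _NOTHING."""
    return tree_map(lambda x: x if predicate(x) else _NOTHING, tree)


def tree_merge(original: Any, updated: Any) -> Any:
    """Rebuild original, taking each leaf from updated unless it is _NOTHING."""
    if isinstance(original, dict):
        return {k: tree_merge(v, updated[k]) for k, v in original.items()}
    if isinstance(original, (list, tuple)):
        return type(original)(tree_merge(o, u) for o, u in zip(original, updated))
    return original if updated is _NOTHING else updated


def filtered_map(f: Callable[[Any], Any], tree: Any, *predicates: Callable[[Any], bool]) -> Any:
    """Without predicates: plain map. With predicates: filter -> map -> merge."""
    if not predicates:
        return tree_map(f, tree)
    # Assumption: a leaf is selected only if it passes every predicate.
    selected = tree_filter(tree, lambda x: all(p(x) for p in predicates))
    return tree_merge(tree, tree_map(f, selected))


params = {"w": [1.0, 2.0], "step": 3, "name": "dense"}
print(filtered_map(lambda x: x * 10, params, lambda x: isinstance(x, float)))
# {'w': [10.0, 20.0], 'step': 3, 'name': 'dense'}
```

Leaves that fail the filter are never passed to `f`, so `f` does not need to handle them (here, multiplying the string `"dense"` by 10 never happens), and the merge step restores them unchanged.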
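Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | The function to apply to the leaves. | required |
| `obj` | `~A` | A pytree, possibly containing `to.Tree`s. | required |
| `*filters` | `Union[Type[Any], Callable[[FieldInfo], bool]]` | The filters used to select the leaves to which the function will be applied. | `()` |
| `inplace` | `bool` | If `True`, the input `obj` is mutated and returned. | `False` |
| `flatten_mode` | `Optional[Union[treeo.tree.FlattenMode, str]]` | Sets a new `FlattenMode` context for the operation; if `None`, the current context is used. | `None` |
| `is_leaf` | `Optional[Callable[[Any], bool]]` | Passed through to `jax.tree_map`; nodes for which it returns `True` are treated as leaves and not traversed further. | `None` |
| `field_info` | `Optional[bool]` | If `True`, represent the leaves of the tree by a `FieldInfo` type, so that field data such as `kind` and `value` can be used inside `f`. | `False` |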
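`flatten_mode` follows a common "override only if given" context pattern: a value opens a new context for the duration of the call, while `None` leaves the currently active mode untouched. A minimal sketch of that pattern with `contextvars`, assuming hypothetical names (`flatten_context`, `current_mode`, `_FLATTEN_MODE`) and a stand-in `FlattenMode` enum whose member names are illustrative; this is not treeo's `_flatten_context`.

```python
import contextlib
import contextvars
from enum import Enum
from typing import Iterator, Optional, Union


class FlattenMode(Enum):
    # Hypothetical stand-in for treeo's FlattenMode; member names are assumptions.
    normal = "normal"
    all_fields = "all_fields"


# Hypothetical context variable holding the active mode.
_FLATTEN_MODE = contextvars.ContextVar("flatten_mode", default=FlattenMode.normal)


def current_mode() -> FlattenMode:
    return _FLATTEN_MODE.get()


@contextlib.contextmanager
def flatten_context(mode: Optional[Union[FlattenMode, str]]) -> Iterator[None]:
    """Temporarily set the flatten mode; None keeps whatever mode is already active."""
    if mode is None:
        yield
        return
    if isinstance(mode, str):
        mode = FlattenMode(mode)  # accept the string form as well as the enum
    token = _FLATTEN_MODE.set(mode)
    try:
        yield
    finally:
        _FLATTEN_MODE.reset(token)  # restore the previous mode even on error
```

Nesting `flatten_context(None)` inside `flatten_context("all_fields")` keeps `all_fields` active, which mirrors "if `None` the current context is used".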
Returns:

| Type | Description |
|---|---|
| `~A` | A new pytree with the changes applied. If `inplace` is `True`, the input `obj` is returned. |
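With `inplace=True`, the source below copies the new object's `__dict__` into the input object, so every existing reference to `obj` sees the mapped values, and objects without a `__dict__` are rejected up front. A minimal pure-Python sketch of that contract, assuming a hypothetical helper `map_attrs` that maps only one level of attributes (not the full pytree traversal) and a hypothetical `Linear` class:

```python
import copy
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def map_attrs(f: Callable[[Any], Any], obj: T, inplace: bool = False) -> T:
    """Hypothetical helper: apply f to each attribute value of obj (one level deep)."""
    if inplace and not hasattr(obj, "__dict__"):
        # Same guard as treeo's map: in-place updates go through __dict__.
        raise ValueError(f"Cannot map inplace on objects with no __dict__ property, got {obj}")
    new_obj = copy.copy(obj)  # shallow copy, so obj is untouched when inplace=False
    new_obj.__dict__ = {k: f(v) for k, v in vars(obj).items()}
    if inplace:
        obj.__dict__.update(new_obj.__dict__)  # mutate the caller's object
        return obj
    return new_obj


class Linear:
    def __init__(self, w: float, b: float) -> None:
        self.w = w
        self.b = b


layer = Linear(1.0, 2.0)
doubled = map_attrs(lambda x: x * 2, layer)     # new object, layer unchanged
same = map_attrs(lambda x: x + 1, layer, inplace=True)  # returns layer itself
```

Updating `__dict__` instead of rebinding a name is what lets the in-place mode return the very same object the caller passed in.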
Source code in treeo/api.py
def map(
    f: tp.Callable,
    obj: A,
    *filters: Filter,
    inplace: bool = False,
    flatten_mode: tp.Union[FlattenMode, str, None] = None,
    is_leaf: tp.Optional[tp.Callable[[tp.Any], bool]] = None,
    field_info: tp.Optional[bool] = False,
) -> A:
    """
    Applies a function to all leaves in a pytree using `jax.tree_map`. If `filters` are given, then
    the function will be applied only to the subset of leaves that match the filters. For more information see
    [map's user guide](https://cgarciae.github.io/treeo/user-guide/api/map).

    Arguments:
        f: The function to apply to the leaves.
        obj: A pytree possibly containing `to.Tree`s.
        *filters: The filters used to select the leaves to which the function will be applied.
        inplace: If `True`, the input `obj` is mutated and returned.
        flatten_mode: Sets a new `FlattenMode` context for the operation; if `None`, the current context is used.
        is_leaf: Passed through to `jax.tree_map`; nodes for which it returns `True` are treated as leaves.
        field_info: Represent the leaves of the tree by a `FieldInfo` type. This enables values of the field such as
            kind and value to be used within the `map` function.

    Returns:
        A new pytree with the changes applied. If `inplace` is `True`, the input `obj` is returned.
    """
    if inplace and not hasattr(obj, "__dict__"):
        raise ValueError(
            f"Cannot map inplace on objects with no __dict__ property, got {obj}"
        )

    input_obj = obj
    has_filters = len(filters) > 0

    with _flatten_context(flatten_mode):
        if has_filters:
            new_obj = filter(obj, *filters)
        else:
            new_obj = obj

        # Conditionally build map function with, or without, the leaf nodes' field info.
        if field_info:
            with add_field_info():
                new_obj: A = jax.tree_map(f, new_obj, is_leaf=is_leaf)
        else:
            new_obj: A = jax.tree_map(f, new_obj, is_leaf=is_leaf)

        if has_filters:
            new_obj = merge(obj, new_obj)

    if inplace:
        input_obj.__dict__.update(new_obj.__dict__)
        return input_obj
    else:
        return new_obj
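The `field_info` branch changes what `f` receives: instead of the raw leaf, it sees a `FieldInfo` carrying the field's metadata, so one function can decide per leaf based on its `kind`. A minimal pure-Python sketch of that idea, assuming a hypothetical frozen `FieldInfo` dataclass (the `name` attribute is an assumption; `kind` and `value` come from the docstring), hypothetical kind classes `Parameter` and `BatchStat`, and a flat dict in place of a real pytree:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


class Parameter:  # hypothetical kind marker
    pass


class BatchStat:  # hypothetical kind marker
    pass


@dataclass(frozen=True)
class FieldInfo:
    # Hypothetical stand-in for treeo's FieldInfo: the leaf value plus its field metadata.
    name: str
    value: Any
    kind: type


def map_fields(f: Callable[[Any], Any], fields: Dict[str, FieldInfo], field_info: bool = False) -> Dict[str, Any]:
    """With field_info=True, f receives the whole FieldInfo; otherwise only the value."""
    if field_info:
        return {name: f(info) for name, info in fields.items()}
    return {name: f(info.value) for name, info in fields.items()}


fields = {
    "w": FieldInfo("w", 1.5, Parameter),
    "mean": FieldInfo("mean", 0.25, BatchStat),
}
# Zero the parameters but keep batch statistics, deciding from each leaf's kind.
zeroed = map_fields(lambda info: 0.0 if info.kind is Parameter else info.value, fields, field_info=True)
```

Without `field_info`, the same decision would need a separate filtered call per kind; passing the metadata lets a single function branch on it.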