# treex.map

Applies a function to all leaves in a pytree using `jax.tree_map`. If filters are given, the function is applied only to the subset of leaves that match them. For more information, see [map's user guide](https://cgarciae.github.io/treeo/user-guide/api/map).
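
The filtering behaviour can be modelled without JAX. The following is a minimal pure-Python sketch, not treeo's implementation: leaves that fail the filter are replaced by a placeholder, the function is applied to what remains, and the results are merged back over the original tree. `NOTHING`, `Field`, `Parameter`, `BatchStat`, `filter_tree`, `map_leaves`, `merge` and `filtered_map` are all hypothetical names, and only type filters are modelled.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type

# Hypothetical placeholder for a leaf removed by filtering.
NOTHING = object()


class Parameter:  # hypothetical kind marker
    pass


class BatchStat:  # hypothetical kind marker
    pass


@dataclass(frozen=True)
class Field:
    """A leaf value tagged with its kind."""
    value: Any
    kind: Type[Any]


Tree = Dict[str, Field]


def filter_tree(tree: Tree, *kinds: Type[Any]) -> Dict[str, Any]:
    # Keep values whose kind matches one of the filters; blank out the rest.
    return {
        name: field.value if issubclass(field.kind, kinds) else NOTHING
        for name, field in tree.items()
    }


def map_leaves(f: Callable[[Any], Any], values: Dict[str, Any]) -> Dict[str, Any]:
    # Apply f to every surviving leaf and skip the placeholders.
    return {name: v if v is NOTHING else f(v) for name, v in values.items()}


def merge(tree: Tree, values: Dict[str, Any]) -> Tree:
    # Write the mapped values back over the original tree.
    return {
        name: field if values[name] is NOTHING else Field(values[name], field.kind)
        for name, field in tree.items()
    }


def filtered_map(f: Callable[[Any], Any], tree: Tree, *kinds: Type[Any]) -> Tree:
    if not kinds:
        # No filters: every leaf is mapped.
        return {name: Field(f(fl.value), fl.kind) for name, fl in tree.items()}
    return merge(tree, map_leaves(f, filter_tree(tree, *kinds)))


tree = {"w": Field(1.0, Parameter), "mean": Field(5.0, BatchStat)}
out = filtered_map(lambda x: x * 10, tree, Parameter)
print(out["w"].value, out["mean"].value)
```

Only the `Parameter` leaf is scaled here; the `BatchStat` leaf keeps its value because it is filtered out before mapping and restored by the merge step.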

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | The function to apply to the leaves. | *required* |
| `obj` | `~A` | A pytree, possibly containing `to.Tree` instances. | *required* |
| `*filters` | `Union[Type[Any], Callable[[FieldInfo], bool]]` | The filters used to select the leaves to which the function is applied. | `()` |
| `inplace` | `bool` | If `True`, the input `obj` is mutated and returned. | `False` |
| `flatten_mode` | `Union[treeo.tree.FlattenMode, str, None]` | Sets a new `FlattenMode` context for the operation; if `None`, the current context is used. | `None` |
| `field_info` | `Optional[bool]` | If `True`, the leaves are represented as `FieldInfo` objects, so the map function can use field attributes such as `kind` and `value`. The docstring lists this parameter as `add_field_info`; the signature names it `field_info`. | `False` |
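
The `*filters` and `field_info` parameters interact: a filter is either a type (assumed here to be matched against the leaf's kind with `issubclass`) or a predicate that receives a `FieldInfo`. The sketch below assumes a simplified `FieldInfo` with `name`, `value` and `kind` attributes, and assumes multiple filters are combined with logical AND. `FieldInfo`, `to_predicate`, `map_fields`, `Parameter` and `BatchStat` are hypothetical stand-ins, not treeo's API.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple, Type, Union


@dataclass(frozen=True)
class FieldInfo:
    """Hypothetical stand-in for FieldInfo: a leaf's name, value and kind."""
    name: str
    value: Any
    kind: Type[Any]


Filter = Union[Type[Any], Callable[[FieldInfo], bool]]
Fields = Dict[str, Tuple[Any, Type[Any]]]  # name -> (value, kind)


def to_predicate(flt: Filter) -> Callable[[FieldInfo], bool]:
    # A type matches leaves whose kind is that type or a subclass of it;
    # any other callable is used as the predicate directly.
    if isinstance(flt, type):
        return lambda info: issubclass(info.kind, flt)
    return flt


def map_fields(f: Callable[[Any], Any], fields: Fields, *filters: Filter,
               field_info: bool = False) -> Fields:
    out: Fields = {}
    for name, (value, kind) in fields.items():
        info = FieldInfo(name, value, kind)
        # Assumption: a leaf must satisfy every filter (logical AND).
        if filters and not all(to_predicate(flt)(info) for flt in filters):
            out[name] = (value, kind)  # unmatched leaves pass through unchanged
            continue
        # With field_info=True, f sees the whole FieldInfo instead of the raw value.
        out[name] = (f(info) if field_info else f(value), kind)
    return out


class Parameter:  # hypothetical kind marker
    pass


class BatchStat:  # hypothetical kind marker
    pass


fields = {"w": (2.0, Parameter), "b": (0.5, Parameter), "mean": (3.0, BatchStat)}
only_w = map_fields(lambda x: -x, fields, Parameter, lambda info: info.name == "w")
tagged = map_fields(lambda info: f"{info.name}:{info.kind.__name__}", fields, field_info=True)
```

`only_w` negates just the `w` leaf, the only one that is both a `Parameter` and named `w`; `tagged` shows `f` receiving the whole record when `field_info=True`.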

Returns:

| Type | Description |
|---|---|
| `~A` | A new pytree with the changes applied. If `inplace` is `True`, the input `obj` is returned. |
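
The `inplace` contract can be illustrated on plain Python objects. This hypothetical `map_attrs` helper reuses the `__dict__` guard and error message from the source listing; copying the new state back with `__dict__.update` is an assumption about how the in-place write happens, not something the listing shows.

```python
import copy
from typing import Any, Callable, TypeVar

A = TypeVar("A")


def map_attrs(f: Callable[[Any], Any], obj: A, inplace: bool = False) -> A:
    """Hypothetical sketch: apply f to every instance attribute of obj."""
    # Same guard as the source listing: in-place updates need a __dict__.
    if inplace and not hasattr(obj, "__dict__"):
        raise ValueError(
            f"Cannot map inplace on objects with no __dict__ property, got {obj}"
        )
    # Work on a shallow copy so the input is untouched unless inplace=True.
    new_obj = copy.copy(obj)
    for name, value in vars(obj).items():
        setattr(new_obj, name, f(value))
    if inplace:
        # Assumption: the new state is copied back onto the caller's object.
        obj.__dict__.update(new_obj.__dict__)
        return obj
    return new_obj


class Counter:
    def __init__(self, n: int) -> None:
        self.n = n


c = Counter(1)
fresh = map_attrs(lambda v: v + 1, c)  # new object; c.n is still 1 here
same = map_attrs(lambda v: v + 1, c, inplace=True)  # mutates c and returns it

try:
    map_attrs(lambda v: v, (1, 2), inplace=True)  # tuples have no __dict__
    raised = False
except ValueError:
    raised = True
```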

Source code in `treeo/api.py`

```python
def map(
    f: tp.Callable,
    obj: A,
    *filters: Filter,
    inplace: bool = False,
    flatten_mode: tp.Union[FlattenMode, str, None] = None,
    is_leaf: tp.Callable[[tp.Any], bool] = None,
    field_info: tp.Optional[bool] = False,
) -> A:
    """
    Applies a function to all leaves in a pytree using jax.tree_map, if filters are given then
    the function will be applied only to the subset of leaves that match the filters. For more information see
    [map's user guide](https://cgarciae.github.io/treeo/user-guide/api/map).

    Arguments:
        f: The function to apply to the leaves.
        obj: a pytree possibly containing to.Trees.
        *filters: The filters used to select the leaves to which the function will be applied.
        inplace: If True, the input obj is mutated and returned.
        flatten_mode: Sets a new FlattenMode context for the operation, if None the current context is used.
        add_field_info: Represent the leaves of the tree by a FieldInfo type. This enables values of the field such as
            kind and value to be used within the map function.

    Returns:
        A new pytree with the changes applied. If inplace is True, the input obj is returned.
    """
    if inplace and not hasattr(obj, "__dict__"):
        raise ValueError(
            f"Cannot map inplace on objects with no __dict__ property, got {obj}"
        )

    input_obj = obj

    has_filters = len(filters) > 0

    with _flatten_context(flatten_mode):
        if has_filters:
            new_obj = filter(obj, *filters)
        else:
            new_obj = obj

        # Conditionally build map function with, or without, the leaf nodes' field info.
        if field_info:
```
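
The listing runs its work inside `with _flatten_context(flatten_mode):`, which matches the parameter description: a given mode applies only for the duration of the call, and `None` leaves the current mode in place. Below is a minimal sketch of such a scoped setting, assuming a `contextvars`-based design (treeo's `_flatten_context` may be implemented differently). The `FlattenMode` members, `_MODE` and `flatten_context` are hypothetical.

```python
import contextlib
import contextvars
import enum
from typing import Iterator, Union


class FlattenMode(enum.Enum):
    # Hypothetical stand-in; the member names are illustrative.
    normal = "normal"
    all_fields = "all_fields"


# Mode used when no enclosing context has overridden it.
_MODE = contextvars.ContextVar("flatten_mode", default=FlattenMode.normal)


@contextlib.contextmanager
def flatten_context(mode: Union[FlattenMode, str, None]) -> Iterator[None]:
    if mode is None:
        # None: keep whatever mode is already active.
        yield
        return
    if isinstance(mode, str):
        mode = FlattenMode(mode)  # accept the string form as well
    token = _MODE.set(mode)
    try:
        yield
    finally:
        # Restore the previous mode even if the body raises.
        _MODE.reset(token)


with flatten_context("all_fields"):
    inner = _MODE.get()
    with flatten_context(None):
        nested = _MODE.get()
after = _MODE.get()
print(inner, nested, after)
```

Using a `ContextVar` token rather than a plain global keeps nested overrides correct, since `reset` restores exactly the value that was active before the matching `set`.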