
treeo.map

Applies a function to all leaves in a pytree using `jax.tree_map`. If `filters` are given, the function is applied only to the subset of leaves that match the filters, and all other leaves keep their original values. For more information see map's user guide.
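The filter-then-merge behavior described above can be illustrated without JAX or treeo. The following is a minimal pure-Python sketch, not treeo's implementation. It assumes pytrees are plain nested dicts, lists and tuples, uses a hypothetical `NOTHING` placeholder for leaves removed by a filter, and simplifies filtering to a single predicate on leaf values (treeo's filters instead match field types or `FieldInfo` objects).

```python
from typing import Any, Callable, Optional

# Hypothetical placeholder for a leaf removed by a filter.
NOTHING = object()


def tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
    """Apply `f` to every leaf of nested dicts/lists/tuples, skipping NOTHING."""
    if isinstance(tree, dict):
        return {k: tree_map(f, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(f, v) for v in tree)
    return tree if tree is NOTHING else f(tree)


def tree_filter(tree: Any, keep: Callable[[Any], bool]) -> Any:
    """Replace every leaf that fails `keep` with NOTHING."""
    return tree_map(lambda x: x if keep(x) else NOTHING, tree)


def tree_merge(original: Any, updated: Any) -> Any:
    """Take leaves from `updated`, falling back to `original` where NOTHING."""
    if isinstance(original, dict):
        return {k: tree_merge(original[k], updated[k]) for k in original}
    if isinstance(original, (list, tuple)):
        return type(original)(tree_merge(o, u) for o, u in zip(original, updated))
    return original if updated is NOTHING else updated


def filtered_map(
    f: Callable[[Any], Any],
    tree: Any,
    keep: Optional[Callable[[Any], bool]] = None,
) -> Any:
    """Map `f` over all leaves, or only over the leaves selected by `keep`."""
    if keep is None:
        return tree_map(f, tree)
    selected = tree_filter(tree, keep)  # non-matching leaves -> NOTHING
    mapped = tree_map(f, selected)      # f never sees NOTHING
    return tree_merge(tree, mapped)     # restore the non-matching leaves


params = {"w": [1.0, 2.0], "step": 3}
print(filtered_map(lambda x: x * 10, params, keep=lambda x: isinstance(x, float)))
# {'w': [10.0, 20.0], 'step': 3}
```

Only the float leaves are scaled here; `step` fails the predicate, so the merge step restores its original value.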

Parameters:

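| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable` | The function to apply to the leaves. | required |
| `obj` | `~A` | A pytree, possibly containing `to.Tree`s. | required |
| `*filters` | `Union[Type[Any], Callable[[FieldInfo], bool]]` | The filters used to select the leaves to which the function will be applied. | `()` |
| `flatten_mode` | `Union[treeo.tree.FlattenMode, str, None]` | Sets a new `FlattenMode` context for the operation; if `None`, the current context is used. | `None` |
| `is_leaf` | `Optional[Callable[[Any], bool]]` | Passed through to `jax.tree_map`; a node for which it returns `True` is treated as a leaf. | `None` |
| `field_info` | `bool` | If `True`, represent each leaf of the tree by a `FieldInfo` object, so that field attributes such as `kind` and `value` can be used within `f`. | `False` |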
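To illustrate what `field_info=True` enables, here is a hypothetical sketch in plain Python. `FieldInfo`, `Parameter`, `BatchStat` and `map_leaves` are stand-ins invented for this example (treeo's own `FieldInfo` may carry different attributes), and the sketch assumes that `f` returns the new leaf value. With field information available, `f` can branch on a leaf's `kind` rather than only on its value.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


class Parameter:  # hypothetical field kind for trainable values
    pass


class BatchStat:  # hypothetical field kind for running statistics
    pass


@dataclass(frozen=True)
class FieldInfo:
    """Hypothetical stand-in for treeo's FieldInfo: a leaf plus its metadata."""
    name: str
    value: Any
    kind: type


def map_leaves(
    f: Callable[[Any], Any],
    values: Dict[str, Any],
    kinds: Dict[str, type],
    field_info: bool = False,
) -> Dict[str, Any]:
    """Apply `f` to each leaf; with field_info=True, `f` receives a FieldInfo."""
    out = {}
    for name, value in values.items():
        leaf = FieldInfo(name, value, kinds[name]) if field_info else value
        out[name] = f(leaf)  # assumes f returns the new leaf value
    return out


def halve_parameters(info: FieldInfo) -> Any:
    # Only Parameter leaves are changed; everything else passes through.
    return info.value / 2 if issubclass(info.kind, Parameter) else info.value


values = {"w": 4.0, "mean": 1.0}
kinds = {"w": Parameter, "mean": BatchStat}
print(map_leaves(halve_parameters, values, kinds, field_info=True))
# {'w': 2.0, 'mean': 1.0}
```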

Returns:

| Type | Description |
|---|---|
| `~A` | A new pytree with the changes applied. |

Source code in treeo/api.py
def map(
    f: tp.Callable,
    obj: A,
    *filters: Filter,
    flatten_mode: tp.Union[FlattenMode, str, None] = None,
    is_leaf: tp.Optional[tp.Callable[[tp.Any], bool]] = None,
    field_info: bool = False,
) -> A:
    """
    Applies a function to all leaves in a pytree using `jax.tree_map`, if `filters` are given then
    the function will be applied only to the subset of leaves that match the filters. For more information see
    [map's user guide](https://cgarciae.github.io/treeo/user-guide/api/map).


    Arguments:
        f: The function to apply to the leaves.
        obj: A pytree, possibly containing `to.Tree`s.
        *filters: The filters used to select the leaves to which the function will be applied.
        flatten_mode: Sets a new `FlattenMode` context for the operation; if `None`, the current context is used.
        is_leaf: Passed through to `jax.tree_map`; a node for which it returns `True` is treated as a leaf.
        field_info: If `True`, represent each leaf of the tree by a `FieldInfo` object. This makes field
            attributes such as `kind` and `value` available within the `map` function.

    Returns:
        A new pytree with the changes applied.
    """

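    has_filters = len(filters) > 0

    with _flatten_context(flatten_mode):
        # Restrict the pytree to the leaves selected by `filters` (if any).
        if has_filters:
            new_obj = filter(obj, *filters)
        else:
            new_obj = obj

        # Apply `f` to the selected leaves, optionally wrapping each leaf in
        # its `FieldInfo` so that `f` can inspect field metadata.
        if field_info:
            with add_field_info():
                new_obj = jax.tree_map(f, new_obj, is_leaf=is_leaf)
        else:
            new_obj = jax.tree_map(f, new_obj, is_leaf=is_leaf)

        # Merge the mapped leaves back into the original object so that
        # leaves not selected by the filters keep their original values.
        if has_filters:
            new_obj = merge(obj, new_obj)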

    return new_obj
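The `flatten_mode` argument is applied through `_flatten_context`, whose definition is not shown here. A common way to implement this kind of temporary, scoped override is a context manager around a `contextvars.ContextVar`; the sketch below assumes that design. `flatten_context`, `current_mode` and the string mode names are hypothetical placeholders, not treeo's API. Passing `None` leaves the active mode untouched, matching the documented behavior of `flatten_mode=None`.

```python
import contextlib
import contextvars
from typing import Iterator, Optional

# Hypothetical mode names; treeo's actual FlattenMode members may differ.
_MODES = {"normal", "all_fields", "no_fields"}
_current_mode = contextvars.ContextVar("flatten_mode", default="normal")


def current_mode() -> str:
    """Return the flatten mode active in the current context."""
    return _current_mode.get()


@contextlib.contextmanager
def flatten_context(mode: Optional[str]) -> Iterator[None]:
    """Temporarily set the flatten mode; None keeps whatever is active."""
    if mode is None:
        yield
        return
    if mode not in _MODES:
        raise ValueError(f"unknown flatten mode: {mode!r}")
    token = _current_mode.set(mode)
    try:
        yield
    finally:
        # Restore the previous mode even if the body raised.
        _current_mode.reset(token)


with flatten_context("all_fields"):
    print(current_mode())  # all_fields
print(current_mode())      # normal
```

Resetting with the saved token restores the previous mode even when the body raises, and nested contexts unwind in order.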