
treeo.Map

Mixin that adds a .map() method to the class.

map(self, f, *filters, flatten_mode=None, is_leaf=None)

map is a thin wrapper over treeo.map that passes self as the second (tree) argument: obj.map(f, *filters) is equivalent to treeo.map(f, obj, *filters).

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
f | Callable | The function to apply to the leaves. | required
*filters | Union[Type[Any], Callable[[FieldInfo], bool]] | The filters used to select the leaves to which the function will be applied. | ()
flatten_mode | Union[treeo.tree.FlattenMode, str, None] | Sets a new FlattenMode context for the operation; if None, the current context is used. | None
is_leaf | Optional[Callable[[Any], bool]] | Optional predicate forwarded unchanged to treeo.map. | None

Returns:

Type | Description
--- | ---
A (same type as self) | A new pytree with the changes applied.
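To illustrate the delegation and the two kinds of filters, here is a minimal pure-Python sketch. It does not use treeo: `FieldInfo`, `tree_map`, `Parameter`, `BatchStat` and `Layer` are hypothetical stand-ins, leaf "kinds" are assumed to live in dataclass field metadata, and multiple filters are assumed to combine with AND. Nesting, `flatten_mode` and `is_leaf` are not modelled.

```python
import dataclasses
import typing as tp

A = tp.TypeVar("A")


# Hypothetical stand-in for treeo's FieldInfo (the real class carries more data).
@dataclasses.dataclass(frozen=True)
class FieldInfo:
    name: str
    value: tp.Any
    kind: type


Filter = tp.Union[type, tp.Callable[[FieldInfo], bool]]


def _matches(info: FieldInfo, flt: Filter) -> bool:
    # A type filter matches fields whose declared kind is a subclass of it;
    # anything else is treated as a predicate on FieldInfo.
    if isinstance(flt, type):
        return issubclass(info.kind, flt)
    return flt(info)


def tree_map(f: tp.Callable, obj: A, *filters: Filter) -> A:
    """Hypothetical flat (non-recursive) analogue of treeo.map."""
    updates = {}
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        info = FieldInfo(field.name, value, field.metadata.get("kind", object))
        # Apply f only where every filter matches (no filters: every field).
        if all(_matches(info, flt) for flt in filters):
            updates[field.name] = f(value)
    # Build a new object; the original is left unchanged.
    return dataclasses.replace(obj, **updates)


class Map:
    """Mixin mirroring the documented pattern: forward to tree_map with self second."""

    def map(self: A, f: tp.Callable, *filters: Filter) -> A:
        return tree_map(f, self, *filters)


# --- usage ---
class Parameter: ...
class BatchStat: ...


@dataclasses.dataclass(frozen=True)
class Layer(Map):
    w: float = dataclasses.field(default=1.0, metadata={"kind": Parameter})
    mean: float = dataclasses.field(default=0.5, metadata={"kind": BatchStat})


layer = Layer()
scaled = layer.map(lambda x: x * 10, Parameter)  # type filter
print(scaled)  # Layer(w=10.0, mean=0.5)
neg = layer.map(lambda x: -x, lambda info: info.name == "mean")  # predicate filter
print(neg)  # Layer(w=1.0, mean=-0.5)
```

The mixin holds no logic of its own; keeping the traversal in a free function means the method form and the function form cannot drift apart.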

Source code in treeo/mixins.py
def map(
    self: A,
    f: tp.Callable,
    *filters: api.Filter,
    flatten_mode: tp.Union[api.FlattenMode, str, None] = None,
    is_leaf: tp.Optional[tp.Callable[[tp.Any], bool]] = None,
) -> A:
    """
    `map` is a wrapper over `treeo.map` that passes `self` as the second argument.

    Arguments:
        f: The function to apply to the leaves.
        *filters: The filters used to select the leaves to which the function will be applied.
        flatten_mode: Sets a new `FlattenMode` context for the operation; if `None`, the current context is used.
        is_leaf: Optional predicate forwarded unchanged to `treeo.map`.

    Returns:
        A new pytree with the changes applied.
    """
    return api.map(
        f,
        self,
        *filters,
        flatten_mode=flatten_mode,
        is_leaf=is_leaf,
    )