treeo.Map
Mixin that adds a `.map()` method to the class.
map(self, f, *filters, flatten_mode=None, is_leaf=None)
`map` is a wrapper over `treeo.map` that passes `self` as the second argument.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
f | Callable | The function to apply to the leaves. | required |
*filters | Union[Type[Any], Callable[[FieldInfo], bool]] | The filters used to select the leaves to which the function will be applied. | () |
flatten_mode | Union[treeo.tree.FlattenMode, str] | Sets a new `FlattenMode` context for the operation; if `None`, the current context is used. | None |
Returns:
Type | Description |
---|---|
~A | A new pytree with the changes applied. |
Source code in treeo/mixins.py
def map(
    self: A,
    f: tp.Callable,
    *filters: api.Filter,
    flatten_mode: tp.Union[api.FlattenMode, str, None] = None,
    is_leaf: tp.Callable[[tp.Any], bool] = None,
) -> A:
    """
    `map` is a wrapper over `treeo.map` that passes `self` as the second argument.

    Arguments:
        f: The function to apply to the leaves.
        *filters: The filters used to select the leaves to which the function will be applied.
        flatten_mode: Sets a new `FlattenMode` context for the operation, if `None` the current context is used.

    Returns:
        A new pytree with the changes applied.
    """
    return api.map(
        f,
        self,
        *filters,
        flatten_mode=flatten_mode,
        is_leaf=is_leaf,
    )
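
The delegation pattern in the source above (a method that forwards to a free function with `self` as the second argument) can be illustrated with a minimal, library-free sketch. Everything here is a hypothetical stand-in, not treeo's implementation: `map_leaves` plays the role of `treeo.map`, `MapMixin` plays the role of `treeo.Map`, and plain dataclass fields stand in for pytree leaves. Filters, `flatten_mode`, and `is_leaf` are deliberately left out.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Callable, TypeVar

A = TypeVar("A")


def map_leaves(f: Callable[[Any], Any], obj: A) -> A:
    """Hypothetical stand-in for a free `map` function.

    Applies `f` to every dataclass field of `obj` and returns a new
    instance; `obj` itself is left unchanged.
    """
    updates = {fld.name: f(getattr(obj, fld.name)) for fld in fields(obj)}
    return replace(obj, **updates)


class MapMixin:
    """Hypothetical mixin that mirrors the delegation pattern of `Map`."""

    def map(self: A, f: Callable[[Any], Any]) -> A:
        # Forward to the free function, passing `self` as the second argument.
        return map_leaves(f, self)


@dataclass(frozen=True)
class Point(MapMixin):
    x: float
    y: float


p = Point(1.0, 2.0)
doubled = p.map(lambda v: v * 2)  # equivalent to map_leaves(lambda v: v * 2, p)
```

Under these assumptions, the method form and the free-function form return equal results, and the original object is not modified, which matches the "new pytree" wording in the Returns section.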