
Map

Applies a function to all leaves in a pytree using jax.tree_map. If filters are given, the function is applied only to the subset of leaves that match the filters. Additionally, setting the field_info argument to True makes properties of each field, such as its kind, available inside the mapped function.

For example, if we want to zero all batch stats we can do:

Example:

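from dataclasses import dataclass

import treeo as to

# Parameter and BatchStat are assumed to be user-defined kind types
class Parameter: pass
class BatchStat: pass

@dataclass
class MyTree(to.Tree):
    a: int = to.field(node=True, kind=Parameter)
    b: int = to.field(node=True, kind=BatchStat)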

tree = MyTree(a=1, b=2)

to.map(lambda _: 0, tree, BatchStat) # MyTree(a=1, b=0)

We can also apply a custom function only to the fields whose kind is Parameter:

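def parameter_fn(field):
    # With field_info=True each leaf arrives as a field-info object; .value holds the leaf itself
    return field.value + 1

to.map(parameter_fn, tree, Parameter, field_info=True)  # MyTree(a=2, b=2)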
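The idea behind field_info can be modelled without treeo: instead of the bare leaf, the mapped function receives a record describing the field, so it can branch on the field's kind. In this minimal sketch, `FieldInfo` (a NamedTuple with `name`, `value` and `kind`), the metadata-based dataclass `MyTree`, and the helper `map_with_field_info` are all hypothetical stand-ins, not treeo's API; treeo's own field-info object may carry different attributes.

```python
from dataclasses import dataclass, field, fields, replace
from typing import Any, Callable, NamedTuple, Optional


class Parameter: ...  # hypothetical kind types used only as labels
class BatchStat: ...


class FieldInfo(NamedTuple):
    """Hypothetical record passed to the mapped function instead of the bare leaf."""
    name: str
    value: Any
    kind: Optional[type]


@dataclass
class MyTree:
    # The kind of each field is stored in its dataclass metadata (an assumption of this sketch)
    a: Any = field(default=None, metadata={"kind": Parameter})
    b: Any = field(default=None, metadata={"kind": BatchStat})


def map_with_field_info(fn: Callable[[FieldInfo], Any], obj: Any, kind: Optional[type] = None) -> Any:
    """Apply fn to a FieldInfo per field; if kind is given, only to fields of that kind."""
    updates = {}
    for f in fields(obj):
        f_kind = f.metadata.get("kind")
        if kind is None or f_kind is kind:
            updates[f.name] = fn(FieldInfo(f.name, getattr(obj, f.name), f_kind))
    return replace(obj, **updates)  # builds a new object; obj itself is left unchanged


def parameter_fn(info: FieldInfo) -> Any:
    return info.value + 1


def scale_by_kind(info: FieldInfo) -> Any:
    # Branch on the field's kind inside the function instead of filtering up front
    return info.value * 10 if info.kind is Parameter else info.value


tree = MyTree(a=1, b=2)
print(map_with_field_info(parameter_fn, tree, Parameter))  # MyTree(a=2, b=2)
print(map_with_field_info(scale_by_kind, tree))            # MyTree(a=10, b=2)
```

Passing a record rather than the raw value keeps the mapped function's signature uniform while still letting it inspect per-field properties.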

Conceptually, map is equivalent to calling filter, then jax.tree_map, then merge, in sequence.
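To make the filter -> tree_map -> merge equivalence concrete, here is a minimal pure-Python sketch that uses neither treeo nor JAX. The dataclass-based tree, the `NOTHING` placeholder for filtered-out leaves, and the helpers `filter_by_kind`, `map_leaves`, `merge` and `map_by_kind` are hypothetical stand-ins for illustration only; kinds are assumed to live in each field's metadata.

```python
from dataclasses import dataclass, field, fields, replace
from typing import Any, Callable


class Parameter: ...  # hypothetical kind types used only as labels
class BatchStat: ...


NOTHING = object()  # hypothetical placeholder for leaves removed by the filter


@dataclass
class MyTree:
    a: Any = field(default=None, metadata={"kind": Parameter})
    b: Any = field(default=None, metadata={"kind": BatchStat})


def filter_by_kind(obj: Any, kind: type) -> Any:
    """Keep leaves whose field kind matches; replace the rest with NOTHING."""
    updates = {
        f.name: getattr(obj, f.name) if f.metadata.get("kind") is kind else NOTHING
        for f in fields(obj)
    }
    return replace(obj, **updates)


def map_leaves(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply fn to every leaf that survived the filter (the tree_map step)."""
    updates = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        updates[f.name] = value if value is NOTHING else fn(value)
    return replace(obj, **updates)


def merge(base: Any, update: Any) -> Any:
    """Take each leaf from update where present, otherwise from base."""
    updates = {}
    for f in fields(base):
        new_value = getattr(update, f.name)
        updates[f.name] = getattr(base, f.name) if new_value is NOTHING else new_value
    return replace(base, **updates)


def map_by_kind(fn: Callable[[Any], Any], obj: Any, kind: type) -> Any:
    # filter -> map -> merge, in sequence
    return merge(obj, map_leaves(fn, filter_by_kind(obj, kind)))


tree = MyTree(a=1, b=2)
print(map_by_kind(lambda _: 0, tree, BatchStat))  # MyTree(a=1, b=0)
```

Every step returns a new object, so `tree` keeps its original values; the merge step is what restores the leaves that the filter set aside.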

If inplace is True, the input obj is mutated and returned. Inplace updates are only possible if obj has a __dict__ attribute; otherwise a TypeError is raised.
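The __dict__ requirement follows from how an inplace update can be done: compute the new object, then copy its attributes back into the input's __dict__. The sketch below mirrors the documented behavior under that assumption; `map_field`, `MyTree` and `SlotTree` are hypothetical names, and this is not treeo's implementation.

```python
import copy
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class MyTree:  # regular dataclass: instances have a __dict__
    a: int
    b: int


class SlotTree:  # __slots__ class: instances have no __dict__
    __slots__ = ("a", "b")

    def __init__(self, a: int, b: int) -> None:
        self.a = a
        self.b = b


def map_field(fn: Callable[[Any], Any], obj: Any, name: str, inplace: bool = False) -> Any:
    """Hypothetical helper: apply fn to one attribute, optionally mutating obj."""
    new_obj = copy.copy(obj)  # the functional path never touches the input
    setattr(new_obj, name, fn(getattr(obj, name)))
    if not inplace:
        return new_obj
    if not hasattr(obj, "__dict__"):
        raise TypeError(f"{type(obj).__name__} has no __dict__; cannot update inplace")
    obj.__dict__.update(vars(new_obj))  # write the updated attributes back into obj
    return obj


tree = MyTree(a=1, b=2)
print(map_field(lambda _: 0, tree, "b"))                       # MyTree(a=1, b=0)
print(map_field(lambda _: 0, tree, "b", inplace=True) is tree)  # True
try:
    map_field(lambda _: 0, SlotTree(1, 2), "b", inplace=True)
except TypeError as err:
    print(err)  # SlotTree has no __dict__; cannot update inplace
```

The non-inplace path works for both classes, because copying and setting attributes does not need a __dict__; only the write-back step does.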