Applies a function to all leaves in a pytree. If filters are given, the function is applied only to the subset of leaves that match the filters. Further, properties of each field's kind can be used within map by setting the field_info argument to True.
For example, to zero all batch stats we can do:
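    from dataclasses import dataclass

    import treeo as to  # assumed: `to` is the treeo package


    # Kind marker types, defined here so the example is self-contained
    class Parameter: pass
    class BatchStat: pass


    @dataclass
    class MyTree(to.Tree):
        a: int = to.field(node=True, kind=Parameter)
        b: int = to.field(node=True, kind=BatchStat)


    tree = MyTree(a=1, b=2)

    # Zero only the leaves whose kind is BatchStat
    to.map(lambda _: 0, tree, BatchStat)  # MyTree(a=1, b=0)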
We could also apply a custom function only to the fields of kind Parameter:
    def parameter_fn(field):
        # Increment every leaf selected by the Parameter filter
        return field + 1

    to.map(parameter_fn, tree, Parameter, field_info=True)  # MyTree(a=2, b=2)
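To make the idea behind field_info concrete, here is a plain-Python sketch that does not use the `to` library at all. It models a tree as a flat dict mapping field names to (kind, value) pairs and hands the mapped function a record of each leaf's name, value, and kind, so the function itself can branch on kind. The FieldInfo record, map_with_info, and the dict-based tree are hypothetical stand-ins chosen for illustration; they say nothing about how to.map actually packages field metadata.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple, Type


# Hypothetical kind markers standing in for Parameter / BatchStat.
class Parameter: ...
class BatchStat: ...


# Toy tree (hypothetical): field name -> (kind, value). Real pytrees are nested.
Leaves = Dict[str, Tuple[Type[Any], Any]]


@dataclass(frozen=True)
class FieldInfo:
    # Hypothetical per-leaf record passed to the mapped function.
    name: str
    value: Any
    kind: Type[Any]


def map_with_info(f: Callable[[FieldInfo], Any], tree: Leaves) -> Leaves:
    """Apply f to every leaf, giving it the leaf's name, value and kind."""
    return {
        name: (kind, f(FieldInfo(name, value, kind)))
        for name, (kind, value) in tree.items()
    }


def bump_parameters(info: FieldInfo) -> Any:
    # Branch on the field's kind: only Parameter leaves are incremented.
    return info.value + 1 if issubclass(info.kind, Parameter) else info.value


tree = {"a": (Parameter, 1), "b": (BatchStat, 2)}
result = map_with_info(bump_parameters, tree)  # "a" becomes 2, "b" stays 2
```

Here the kind check lives inside the mapped function instead of in a filter, which is the kind of decision that access to field metadata makes possible.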
map is equivalent to running filter -> tree_map -> merge in sequence.
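The filter -> tree_map -> merge decomposition can be sketched in plain Python, again without the `to` library. The flat dict tree, the NOTHING placeholder, and the functions filter_leaves, tree_map_leaves, merge_leaves, and map_leaves are all hypothetical; a real pytree is nested and a real tree_map walks it recursively, which this sketch skips.

```python
from typing import Any, Callable, Dict, Tuple, Type

# Toy tree (hypothetical): field name -> (kind, value). Real pytrees are nested.
Leaves = Dict[str, Tuple[Type[Any], Any]]

# Hypothetical placeholder left behind by filter_leaves for non-matching leaves.
NOTHING = object()


class Parameter: ...
class BatchStat: ...


def filter_leaves(tree: Leaves, *kinds: Type[Any]) -> Leaves:
    # Keep leaves whose kind matches any filter; with no filters, keep everything.
    return {
        name: (kind, value if not kinds or issubclass(kind, kinds) else NOTHING)
        for name, (kind, value) in tree.items()
    }


def tree_map_leaves(f: Callable[[Any], Any], tree: Leaves) -> Leaves:
    # Apply f only to real leaves; placeholders pass through untouched.
    return {
        name: (kind, value if value is NOTHING else f(value))
        for name, (kind, value) in tree.items()
    }


def merge_leaves(base: Leaves, update: Leaves) -> Leaves:
    # Prefer the updated leaf; fall back to the original where update holds a placeholder.
    return {
        name: (kind, value if update[name][1] is NOTHING else update[name][1])
        for name, (kind, value) in base.items()
    }


def map_leaves(f: Callable[[Any], Any], tree: Leaves, *kinds: Type[Any]) -> Leaves:
    # map == filter -> tree_map -> merge
    return merge_leaves(tree, tree_map_leaves(f, filter_leaves(tree, *kinds)))


tree = {"a": (Parameter, 1), "b": (BatchStat, 2)}
zeroed = map_leaves(lambda _: 0, tree, BatchStat)
```

Because the merge step falls back to the original leaf wherever the filtered tree holds a placeholder, leaves that fail the filter come through unchanged, which matches the MyTree(a=1, b=0) result in the batch-stat example.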
When the in-place option is set to True, the input obj is mutated and returned. In-place updates are only possible if obj has a __dict__ attribute; otherwise a TypeError is raised.
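The in-place rule can be illustrated with a plain-Python sketch built on the standard dataclasses module rather than on to.map. The helper map_dataclass, its inplace parameter, and the Plain and Slotted classes are hypothetical: the helper always computes a new object functionally and, only when inplace=True, copies the new state into the original through its __dict__. A class that declares __slots__ has no per-instance __dict__, so the in-place path raises TypeError for it, mirroring the restriction described above.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Callable, TypeVar

T = TypeVar("T")


@dataclass
class Plain:
    # Ordinary dataclass: instances carry a __dict__.
    a: int
    b: int


@dataclass
class Slotted:
    # Declaring __slots__ removes the per-instance __dict__.
    __slots__ = ("a", "b")
    a: int
    b: int


def map_dataclass(f: Callable[[Any], Any], obj: T, inplace: bool = False) -> T:
    """Hypothetical helper: apply f to every dataclass field of obj."""
    # Reject in-place updates up front so obj is never partially modified.
    if inplace and not hasattr(obj, "__dict__"):
        raise TypeError(f"cannot update {type(obj).__name__} inplace: it has no __dict__")
    # Functional update: build a new object, leaving obj untouched.
    new_obj = dataclasses.replace(
        obj, **{fld.name: f(getattr(obj, fld.name)) for fld in dataclasses.fields(obj)}
    )
    if inplace:
        # Copy the new state into the original object and return that object.
        obj.__dict__.update(new_obj.__dict__)
        return obj
    return new_obj
```

Checking for __dict__ before doing any work means a rejected in-place call leaves the object exactly as it was, while the non-in-place path still works for slotted classes.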