Map

Applies a function to all leaves in the Module using jax.tree_map. If filters are given, the function is applied only to the leaves that match them; all other leaves are left unchanged.

For example, to zero all batch stats:

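from dataclasses import dataclass

import treex as tx

@dataclass
class MyModule(tx.Module):
    a: int = tx.Parameter.node()
    b: int = tx.BatchStat.node()

module = MyModule(a=1, b=2)

# Only BatchStat leaves are mapped; Parameter leaves keep their values.
module.map(lambda _: 0, tx.BatchStat) # MyModule(a=1, b=0)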

map is equivalent to filter -> tree_map -> merge in sequence.
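
The filter -> tree_map -> merge sequence can be illustrated with a small toy model in plain Python. This is only a sketch under stated assumptions, not the library's implementation: the `Leaf` class, the `NOTHING` sentinel, and the helpers `filter_tree`, `tree_map`, `merge`, and `map_tree` are all hypothetical names introduced here, and a "tree" is simply a dict of leaves tagged with a kind string.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict


class _Nothing:
    """Hypothetical sentinel marking a leaf that was filtered out."""

    def __repr__(self) -> str:
        return "Nothing"


NOTHING = _Nothing()


@dataclass(frozen=True)
class Leaf:
    kind: str   # e.g. "Parameter" or "BatchStat" (toy stand-in for a kind type)
    value: Any


Tree = Dict[str, Leaf]


def filter_tree(tree: Tree, kind: str) -> Tree:
    # Keep leaves of the requested kind; blank out every other leaf.
    return {
        name: leaf if leaf.kind == kind else replace(leaf, value=NOTHING)
        for name, leaf in tree.items()
    }


def tree_map(f: Callable[[Any], Any], tree: Tree) -> Tree:
    # Apply f to every leaf that survived the filter.
    return {
        name: leaf if leaf.value is NOTHING else replace(leaf, value=f(leaf.value))
        for name, leaf in tree.items()
    }


def merge(base: Tree, update: Tree) -> Tree:
    # Take updated leaves where present, otherwise fall back to the base tree.
    return {
        name: base[name] if update[name].value is NOTHING else update[name]
        for name in base
    }


def map_tree(f: Callable[[Any], Any], tree: Tree, kind: str) -> Tree:
    # The three steps in sequence: filter -> tree_map -> merge.
    return merge(tree, tree_map(f, filter_tree(tree, kind)))


tree = {"a": Leaf("Parameter", 1), "b": Leaf("BatchStat", 2)}
result = map_tree(lambda _: 0, tree, "BatchStat")
```

In this toy model, `result` holds the zeroed `b` leaf while `a` and the original `tree` are untouched, which mirrors the non-inplace behaviour described above.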

If inplace is True, the input obj is mutated and returned. You can only update inplace if the input obj has a __dict__ attribute, else a TypeError is raised.
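
The inplace rule can be sketched in plain Python as follows. This is a minimal illustration, not the library's code: `map_attrs`, `Stats`, and `SlottedStats` are hypothetical names, and the sketch maps over plain attributes rather than pytree leaves. It assumes that an object declaring `__slots__` (and therefore lacking `__dict__`) is a representative case of an input that cannot be updated inplace.

```python
import copy
from typing import Any, Callable


class Stats:
    """Ordinary object: attributes live in __dict__, so it can be mutated."""

    def __init__(self, mean: float, var: float) -> None:
        self.mean = mean
        self.var = var


class SlottedStats:
    """Object without __dict__, used to show the TypeError path."""

    __slots__ = ("mean", "var")

    def __init__(self, mean: float, var: float) -> None:
        self.mean = mean
        self.var = var


def map_attrs(f: Callable[[Any], Any], obj: Any, inplace: bool = False) -> Any:
    # Hypothetical helper; this sketch only handles objects with a __dict__.
    if inplace and not hasattr(obj, "__dict__"):
        raise TypeError(
            f"Cannot update {type(obj).__name__} inplace: it has no __dict__"
        )
    # Mutate the input itself when inplace, otherwise work on a shallow copy.
    target = obj if inplace else copy.copy(obj)
    for name, value in list(vars(obj).items()):
        setattr(target, name, f(value))
    return target


stats = Stats(mean=1.0, var=2.0)

zeroed = map_attrs(lambda _: 0.0, stats)                # new object, stats unchanged
unchanged_mean = stats.mean
same = map_attrs(lambda _: 0.0, stats, inplace=True)    # stats itself is mutated

try:
    map_attrs(lambda _: 0.0, SlottedStats(1.0, 2.0), inplace=True)
    raised = False
except TypeError:
    raised = True
```

Returning the mutated input in the inplace case keeps both call styles interchangeable: the caller can always use the return value, whether or not a copy was made.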