

The `filter` method lets you select a subtree by filtering on a kind. All leaves whose field kind is a subclass of the given type are kept, and the rest are set to a special `Nothing` value.

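```python
from treex import BatchStat, Parameter

# MyModule is assumed to declare `a` as a Parameter field and `b` as a BatchStat field
module = MyModule(a=1, b=2)

module.filter(Parameter) # MyModule(a=1, b=Nothing)
module.filter(BatchStat) # MyModule(a=Nothing, b=2)
```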
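The selection rule can be illustrated with a minimal pure-Python sketch that does not use Treex. It assumes a hypothetical `MyModule` dataclass whose field kinds are declared in a `KINDS` mapping (`a` is a `Parameter`, `b` is a `BatchStat`, both deriving from a stand-in `TreePart`), plus a stand-in `Nothing` class. `filter_module` is likewise a hypothetical helper, not Treex's implementation. Because membership is decided with `issubclass`, filtering by the shared base class keeps every field.

```python
from dataclasses import dataclass, fields, replace


# Hypothetical stand-ins for the kind hierarchy (not Treex's own classes).
class TreePart: ...


class Parameter(TreePart): ...


class BatchStat(TreePart): ...


class Nothing:
    """Placeholder that replaces filtered-out values."""

    def __repr__(self) -> str:
        return "Nothing"


@dataclass
class MyModule:
    a: object
    b: object
    # Hypothetical: the kind of each field (a plain class attribute, not a dataclass field).
    KINDS = {"a": Parameter, "b": BatchStat}


def filter_module(module, kind):
    """Keep fields whose kind is a subclass of `kind`; replace the rest with Nothing."""
    updates = {
        f.name: getattr(module, f.name)
        if issubclass(module.KINDS[f.name], kind)
        else Nothing()
        for f in fields(module)
    }
    return replace(module, **updates)  # builds a new instance; the input is unchanged


module = MyModule(a=1, b=2)
print(filter_module(module, Parameter))  # MyModule(a=1, b=Nothing)
print(filter_module(module, BatchStat))  # MyModule(a=Nothing, b=2)
print(filter_module(module, TreePart))   # MyModule(a=1, b=2)
```

Returning a new instance mirrors the non-inplace behaviour described at the end of this section.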

Since `Nothing` is an empty Pytree, it is ignored by tree operations. This lets you operate on a subset of the fields easily:

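```python
import jax

negative = lambda x: -x

# jax.tree_util.tree_map is used here; the older jax.tree_map alias has been removed from recent JAX releases
jax.tree_util.tree_map(negative, module.filter(Parameter)) # MyModule(a=-1, b=Nothing)
jax.tree_util.tree_map(negative, module.filter(BatchStat)) # MyModule(a=Nothing, b=-2)
```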
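`Nothing` is skipped because it contributes no leaves when a tree is flattened, so a mapped function never sees it. The sketch below loosely mimics that flatten, map, unflatten cycle in plain Python without JAX. `tree_flatten`, `tree_unflatten`, `tree_map` and the toy `MyModule` are hypothetical simplifications, not JAX's or Treex's implementations. The mapped function records its inputs to show that it runs only on the surviving leaf.

```python
from dataclasses import dataclass, fields, replace


class Nothing:
    """Hypothetical placeholder with no leaves."""

    def __repr__(self) -> str:
        return "Nothing"


@dataclass
class MyModule:
    a: object
    b: object


def tree_flatten(module):
    """Return the leaves and a 'treedef' recording which fields hold them."""
    # Fields holding Nothing contribute zero leaves.
    leaf_names = [
        f.name for f in fields(module)
        if not isinstance(getattr(module, f.name), Nothing)
    ]
    leaves = [getattr(module, name) for name in leaf_names]
    return leaves, (module, leaf_names)


def tree_unflatten(treedef, leaves):
    template, leaf_names = treedef
    # Fields not listed in leaf_names keep their Nothing placeholder.
    return replace(template, **dict(zip(leaf_names, leaves)))


def tree_map(f, module):
    leaves, treedef = tree_flatten(module)
    return tree_unflatten(treedef, [f(x) for x in leaves])


calls = []


def negative(x):
    calls.append(x)  # record every value the function actually receives
    return -x


# Stands in for the result of filtering MyModule(a=1, b=2) by Parameter.
params_only = MyModule(a=1, b=Nothing())
print(tree_map(negative, params_only))  # MyModule(a=-1, b=Nothing)
print(calls)                            # [1]
```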


Because simple filters on the standard TreePart types are so common, some shortcuts are provided:

| Shortcut | Equivalence |
|---|---|
| `.parameters()` | `.filter(tx.Parameter)` |
| `.batch_stats()` | `.filter(tx.BatchStat)` |
| `.rngs()` | `.filter(tx.RNG)` |
| `.model_states()` | `.filter(tx.ModelState)` |
| `.states()` | `.filter(tx.State)` |
| `.metrics()` | `.filter(tx.Metric)` |
| `.losses()` | `.filter(tx.Loss)` |
| `.logs()` | `.filter(tx.Log)` |

Based on this, the first example can be written as:

```python
module.parameters()  # MyModule(a=1, b=Nothing)
module.batch_stats() # MyModule(a=Nothing, b=2)
```
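Each shortcut amounts to calling `filter` with a fixed kind. The sketch below shows that wrapper pattern on a toy class. The `MyModule` dataclass, its `KINDS` mapping, and the stand-in `Parameter` and `BatchStat` classes are hypothetical, and only two of the shortcuts are reproduced.

```python
from dataclasses import dataclass, fields, replace


class Parameter: ...


class BatchStat: ...


class Nothing:
    def __repr__(self) -> str:
        return "Nothing"


@dataclass
class MyModule:
    a: object
    b: object
    KINDS = {"a": Parameter, "b": BatchStat}  # hypothetical field -> kind mapping

    def filter(self, kind):
        """Keep fields whose kind subclasses `kind`; the rest become Nothing."""
        return replace(self, **{
            f.name: getattr(self, f.name)
            if issubclass(self.KINDS[f.name], kind)
            else Nothing()
            for f in fields(self)
        })

    # Shortcuts: thin wrappers that fix the kind argument.
    def parameters(self):
        return self.filter(Parameter)

    def batch_stats(self):
        return self.filter(BatchStat)


module = MyModule(a=1, b=2)
print(module.parameters())   # MyModule(a=1, b=Nothing)
print(module.batch_stats())  # MyModule(a=Nothing, b=2)
```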

filter predicates

If you need more complex filtering, you can pass callables with the signature `FieldInfo -> bool` instead of types:

```python
# all Parameters whose field name is "kernel"
module.filter(
    lambda field: issubclass(field.kind, Parameter)
    and field.name == "kernel"
)
# MyModule(a=Nothing, b=Nothing)
```
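Predicate filtering can be sketched in plain Python, assuming a hypothetical `FieldInfo` record with `name`, `kind` and `value` attributes (Treex's own `FieldInfo` may carry more or different data). The toy `filter_module` accepts either a kind or a callable; a bare kind is converted into the equivalent `issubclass` predicate, so both forms share one code path. A second toy class, `Linear`, has a field actually named `kernel`, which shows the predicate keeping something.

```python
from dataclasses import dataclass, fields, replace


class Parameter: ...


class BatchStat: ...


class Nothing:
    def __repr__(self) -> str:
        return "Nothing"


@dataclass
class FieldInfo:
    """Hypothetical description of one field, passed to predicates."""
    name: str
    kind: type
    value: object


@dataclass
class MyModule:
    a: object
    b: object
    KINDS = {"a": Parameter, "b": BatchStat}  # hypothetical field -> kind mapping


@dataclass
class Linear:
    kernel: object
    bias: object
    KINDS = {"kernel": Parameter, "bias": Parameter}


def filter_module(module, flt):
    # A bare type is treated as a kind filter, i.e. an issubclass predicate.
    if isinstance(flt, type):
        kind = flt
        predicate = lambda field: issubclass(field.kind, kind)
    else:
        predicate = flt

    updates = {}
    for f in fields(module):
        info = FieldInfo(f.name, module.KINDS[f.name], getattr(module, f.name))
        updates[f.name] = info.value if predicate(info) else Nothing()
    return replace(module, **updates)


# All Parameters whose field name is "kernel".
is_kernel = lambda field: issubclass(field.kind, Parameter) and field.name == "kernel"

print(filter_module(MyModule(a=1, b=2), is_kernel))        # MyModule(a=Nothing, b=Nothing)
print(filter_module(Linear(kernel=3, bias=4), is_kernel))  # Linear(kernel=3, bias=Nothing)
```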

multiple filters

You can combine several queries by passing multiple filters as `*args`. A field is kept only if all filters pass. Since types passed on their own act as "kind filters", the previous example could be written as:

```python
# all Parameters whose field name is "kernel"
module.filter(
    Parameter,
    lambda field: field.name == "kernel",
)
# MyModule(a=Nothing, b=Nothing)
```
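Multiple filters can be sketched in plain Python: every filter is first normalised to a predicate, and a field survives only if `all` of them return `True`. `FieldInfo`, `as_predicate`, `filter_module`, the `Linear` toy class and its `KINDS` mapping are hypothetical. The last lines check that passing `Parameter` plus a name predicate selects the same fields as a single combined predicate.

```python
from dataclasses import dataclass, fields, replace


class Parameter: ...


class Nothing:
    def __repr__(self) -> str:
        return "Nothing"


@dataclass
class FieldInfo:
    name: str
    kind: type
    value: object


@dataclass
class Linear:
    kernel: object
    bias: object
    KINDS = {"kernel": Parameter, "bias": Parameter}  # hypothetical field -> kind mapping


def as_predicate(flt):
    """Turn a bare type into a kind filter; leave callables unchanged."""
    if isinstance(flt, type):
        return lambda field: issubclass(field.kind, flt)
    return flt


def filter_module(module, *filters):
    predicates = [as_predicate(flt) for flt in filters]
    updates = {}
    for f in fields(module):
        info = FieldInfo(f.name, module.KINDS[f.name], getattr(module, f.name))
        # A field is kept only if every filter passes (all() of no filters is True).
        keep = all(p(info) for p in predicates)
        updates[f.name] = info.value if keep else Nothing()
    return replace(module, **updates)


layer = Linear(kernel=3, bias=4)
split = filter_module(layer, Parameter, lambda field: field.name == "kernel")
combined = filter_module(
    layer, lambda field: issubclass(field.kind, Parameter) and field.name == "kernel"
)
print(split)                          # Linear(kernel=3, bias=Nothing)
print(repr(split) == repr(combined))  # True
```

Normalising types to predicates up front keeps the "all filters must pass" rule in a single `all(...)` call.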


If `inplace` is `True`, the input `obj` is mutated and returned. In-place updates are only possible if `obj` has a `__dict__` attribute; otherwise a `TypeError` is raised.
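One way to see why `__dict__` matters is a plain-Python sketch in which the filtered state is computed on a shallow copy and, for in-place updates, copied back with `obj.__dict__.update(...)`. This is an assumption about the mechanism, not Treex's confirmed implementation; `filter_module`, `MyModule`, `SlotsModule` and the module-level `KINDS` mapping are hypothetical. A class that declares `__slots__` has no instance `__dict__`, so the in-place path rejects it.

```python
import copy


class Parameter: ...


class BatchStat: ...


class Nothing:
    def __repr__(self) -> str:
        return "Nothing"


KINDS = {"a": Parameter, "b": BatchStat}  # hypothetical field -> kind mapping


class MyModule:
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __repr__(self) -> str:
        return f"MyModule(a={self.a!r}, b={self.b!r})"


class SlotsModule:
    __slots__ = ("a", "b")  # instances of this class have no __dict__

    def __init__(self, a, b):
        self.a, self.b = a, b


def filter_module(obj, kind, inplace=False):
    if inplace and not hasattr(obj, "__dict__"):
        raise TypeError(f"cannot update {type(obj).__name__} inplace: it has no __dict__")
    new = copy.copy(obj)  # shallow copy, so the input is not touched yet
    for name, field_kind in KINDS.items():
        if not issubclass(field_kind, kind):
            setattr(new, name, Nothing())
    if inplace:
        obj.__dict__.update(new.__dict__)  # copy the filtered state back into obj
        return obj
    return new


m = MyModule(a=1, b=2)
out = filter_module(m, Parameter)
print(out is m, m)  # False MyModule(a=1, b=2)
out = filter_module(m, Parameter, inplace=True)
print(out is m, m)  # True MyModule(a=1, b=Nothing)

try:
    filter_module(SlotsModule(a=1, b=2), Parameter, inplace=True)
except TypeError as err:
    print(err)  # cannot update SlotsModule inplace: it has no __dict__
```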