Apply

Applies a function to all `to.Tree`s in a Pytree. It works much like `jax.tree_map`, except that `f` is called on `to.Tree` instances rather than on leaves, and `f` should apply its changes in place to the `Tree` it receives. For example, if you have a setup like this:

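```python
import treeo as to

class Child(to.Tree):
    # node=False: static metadata, not part of the pytree's leaves
    training: bool = to.field(default=True, node=False)
    # node=True: a pytree node that JAX transformations will see
    value: float = to.field(node=True)
    ...

class Parent(to.Tree):
    training: bool = to.field(default=True, node=False)
    left: Child
    right: Child
    ...

tree = Parent(...)
```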
To set every `.training` field to `False`, you can do this:

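```python
def set_training_false(tree):
    # f is called on every to.Tree in the structure (the Parent and both
    # Child instances) and mutates it in place instead of returning a new one.
    if isinstance(tree, (Parent, Child)):
        tree.training = False

tree = to.apply(set_training_false, tree)
```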
Here `tree` is copied internally, so `set_training_false` can mutate it safely without affecting the original object. If you pass `inplace=True`, however, the original object is mutated and returned instead.
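To make the copy-versus-`inplace` behavior concrete, here is a minimal pure-Python sketch of the semantics described above. It does not use treeo: `Node`, `apply_sketch`, and the dataclass-based `Parent`/`Child` are hypothetical stand-ins for `to.Tree`, `to.apply`, and the classes in the example, and visiting children before their parent is a choice made for this sketch, not a statement about treeo's traversal order.

```python
import copy
from dataclasses import dataclass
from typing import Any, Callable


class Node:
    """Hypothetical stand-in for to.Tree: every instance is handed to f."""


@dataclass
class Child(Node):
    value: float
    training: bool = True


@dataclass
class Parent(Node):
    left: Child
    right: Child
    training: bool = True


def _visit(f: Callable[[Node], None], obj: Any) -> None:
    # Recurse into common containers and into the attributes of each Node.
    if isinstance(obj, (list, tuple)):
        for item in obj:
            _visit(f, item)
    elif isinstance(obj, dict):
        for item in obj.values():
            _visit(f, item)
    elif isinstance(obj, Node):
        for attr in vars(obj).values():
            _visit(f, attr)
        # Children are visited first, then f mutates this node in place.
        f(obj)
    # Anything else is a leaf and is never passed to f.


def apply_sketch(f: Callable[[Node], None], obj: Any, inplace: bool = False) -> Any:
    # Hypothetical helper: without inplace, work on a deep copy so the
    # caller's object is left untouched.
    if not inplace:
        obj = copy.deepcopy(obj)
    _visit(f, obj)
    return obj


def set_training_false(node: Node) -> None:
    if isinstance(node, (Parent, Child)):
        node.training = False


original = Parent(left=Child(1.0), right=Child(2.0))

copied = apply_sketch(set_training_false, original)
print(copied.training, copied.left.training, original.training)  # False False True

mutated = apply_sketch(set_training_false, original, inplace=True)
print(mutated is original, original.right.training)  # True False
```

Deep-copying up front is the simplest way to give `f` a private structure to mutate; in this sketch the trade-off is that every call without `inplace=True` pays for a full copy of the input.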