Apply
Applies a function `f` to all `to.Tree`s in a pytree. It works very similarly to `jax.tree_map`, but `f` receives `to.Tree` objects instead of leaves, and `f` should apply its changes in place to each `Tree` object. For example, if you have a setup like this:
```python
import treeo as to

class Child(to.Tree):
    training: bool = to.field(default=True, node=False)
    value: float = to.field(node=True)
    ...

class Parent(to.Tree):
    training: bool = to.field(default=True, node=False)
    left: Child
    right: Child
    ...

tree = Parent(...)
```
If you want to set all the `.training` fields to `False`, you can do this:
```python
def set_training_false(tree):
    if isinstance(tree, (Parent, Child)):
        tree.training = False

tree = to.apply(set_training_false, tree)
```
`tree` is copied internally, so `f` can safely mutate it during `apply` without affecting the original object. However, if `inplace` is `True`, the original object itself is mutated and returned.
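The copy-versus-`inplace` behavior described above can be illustrated with a small plain-Python sketch. It does not use treeo or JAX: `Tree`, `Child`, `Parent`, `_visit`, and `apply_sketch` are hypothetical stand-ins, and the use of `copy.deepcopy` and the parent-before-children visiting order are assumptions of this sketch, not necessarily how `to.apply` is implemented.

```python
import copy
from dataclasses import dataclass
from typing import Any, Callable


class Tree:
    """Hypothetical stand-in for to.Tree (no treeo or JAX required)."""


@dataclass
class Child(Tree):
    value: float
    training: bool = True


@dataclass
class Parent(Tree):
    left: Child
    right: Child
    training: bool = True


def _visit(obj: Any, f: Callable[[Tree], None]) -> None:
    # Call f on every Tree node; in this sketch parents are visited before children.
    if isinstance(obj, Tree):
        f(obj)
        for attr in list(vars(obj).values()):
            _visit(attr, f)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            _visit(item, f)
    elif isinstance(obj, dict):
        for item in obj.values():
            _visit(item, f)


def apply_sketch(f: Callable[[Tree], None], obj: Any, inplace: bool = False) -> Any:
    # Hypothetical: deep-copy the structure unless inplace=True,
    # so mutations made by f never reach the caller's object.
    target = obj if inplace else copy.deepcopy(obj)
    _visit(target, f)
    return target


def set_training_false(tree):
    if isinstance(tree, (Parent, Child)):
        tree.training = False


tree = Parent(left=Child(value=1.0), right=Child(value=2.0))

# Default: f runs on a copy, so `tree` keeps its original flags.
copied = apply_sketch(set_training_false, tree)
print(copied.training, copied.left.training, tree.training)  # False False True

# inplace=True: the original object itself is mutated and returned.
same = apply_sketch(set_training_false, tree, inplace=True)
print(same is tree, tree.right.training)  # True False
```

Here `f` returns nothing and only mutates the nodes it is given, which is why copying up front is enough to protect the caller's object.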