Apply
Applies a function to all `to.Tree`s in a Pytree. It works very similarly to `jax.tree_map`, but its values are `to.Tree`s instead of leaves, and `f` should apply its changes in place to each `Tree` object. For example, if you have a setup like this:
```python
import treeo as to

class Child(to.Tree):
    training: bool = to.field(default=True, node=False)
    value: float = to.field(node=True)
    ...

class Parent(to.Tree):
    training: bool = to.field(default=True, node=False)
    left: Child
    right: Child
    ...

tree = Parent(...)
```
and you want to set all `.training` fields to `False`, you can do this:
```python
def set_training_false(tree):
    # Called on every Tree in the pytree; it mutates the Tree in place.
    if isinstance(tree, (Parent, Child)):
        tree.training = False

tree = to.apply(set_training_false, tree)
```
`tree` is internally copied, so you can mutate it safely during `apply` without affecting the original object. However, if `inplace` is `True`, the original object is mutated and returned.
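To make the traversal and copy semantics concrete, here is a minimal pure-Python sketch that does not depend on treeo. The `Tree` base class and the `apply_sketch` function are hypothetical stand-ins, not treeo's actual implementation. The sketch assumes that sub-trees live in instance attributes, lists, tuples or dicts, and it visits each `Tree` before its children (pre-order), which may differ from the order treeo uses.

```python
import copy
from typing import Any, Callable, TypeVar

A = TypeVar("A")


class Tree:
    """Hypothetical stand-in for `to.Tree`: attributes may hold sub-trees."""


def apply_sketch(f: Callable[[Tree], None], obj: A, inplace: bool = False) -> A:
    """Hypothetical illustration of the semantics above, not treeo's code."""
    # Unless inplace=True, work on a deep copy so the caller's object is untouched.
    if not inplace:
        obj = copy.deepcopy(obj)

    def visit(node: Any) -> None:
        if isinstance(node, Tree):
            f(node)  # f mutates the Tree in place; its return value is ignored
            for child in vars(node).values():
                visit(child)
        elif isinstance(node, (list, tuple)):
            for child in node:
                visit(child)
        elif isinstance(node, dict):
            for child in node.values():
                visit(child)

    visit(obj)
    return obj


# Plain-Python versions of the Child/Parent example (hypothetical).
class Child(Tree):
    def __init__(self, value: float):
        self.training = True
        self.value = value


class Parent(Tree):
    def __init__(self, left: Child, right: Child):
        self.training = True
        self.left = left
        self.right = right


def set_training_false(tree: Tree) -> None:
    if isinstance(tree, (Parent, Child)):
        tree.training = False


original = Parent(Child(1.0), Child(2.0))
copied = apply_sketch(set_training_false, original)
print(original.training, copied.training, copied.left.training)  # True False False

mutated = apply_sketch(set_training_false, original, inplace=True)
print(mutated is original, original.right.training)  # True False
```

Deep-copying up front is what lets `f` mutate freely without touching the caller's object; with `inplace=True` that copy is skipped, so the same object comes back mutated. For brevity, the sketch does not guard against reference cycles.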