Defining Fields
Tree fields are divided into two categories:
- node fields: they are part of the pytree, so JAX functions such as tree_map will operate over them.
- static fields: they are part of the PyTreeDef, so JAX functions will not operate over them, but JAX is still aware of them; e.g. JAX will recompile jitted functions when these fields change.
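To make the node/static split concrete, here is a stdlib-only sketch (no JAX or Treeo) of what a pytree flatten step does with the two kinds of fields. Everything in it is hypothetical: the `flatten`, `unflatten`, `tree_map` and `toy_jit` helpers and the `node_fields`/`static_fields` class attributes are illustrative stand-ins. Node fields become leaves, static fields are folded into the treedef, and a toy jit cache keyed on the treedef shows why changing a static field forces a new trace while changing a node field does not.

```python
from typing import Any, Callable, List, Tuple


class Person:
    # Hypothetical declaration: node_fields become leaves,
    # static_fields are stored in the treedef instead.
    node_fields = ("height",)
    static_fields = ("name",)

    def __init__(self, height: float, name: str):
        self.height = height
        self.name = name


def flatten(obj: Any) -> Tuple[List[Any], Tuple]:
    """Split obj into (leaves, treedef); only node fields become leaves."""
    cls = type(obj)
    leaves = [getattr(obj, f) for f in cls.node_fields]
    # The treedef holds the type plus the static values, so it must be hashable.
    treedef = (cls, tuple(getattr(obj, f) for f in cls.static_fields))
    return leaves, treedef


def unflatten(treedef: Tuple, leaves: List[Any]) -> Any:
    """Rebuild an object from a treedef and a list of (possibly new) leaves."""
    cls, static_values = treedef
    obj = object.__new__(cls)  # skip __init__, as an unflatten step would
    for f, v in zip(cls.node_fields, leaves):
        setattr(obj, f, v)
    for f, v in zip(cls.static_fields, static_values):
        setattr(obj, f, v)
    return obj


def tree_map(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply fn to every leaf; static fields pass through inside the treedef."""
    leaves, treedef = flatten(obj)
    return unflatten(treedef, [fn(x) for x in leaves])


def toy_jit(fn: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Cache one 'compiled' entry per treedef and count how often we 'trace'.

    Real jax.jit also keys its cache on leaf shapes and dtypes; this toy
    version keys only on the treedef to isolate the effect of static fields.
    """
    cache = {}

    def wrapper(obj: Any) -> Any:
        _, treedef = flatten(obj)
        if treedef not in cache:
            wrapper.traces += 1
            cache[treedef] = fn  # stand-in for tracing and compiling fn
        return cache[treedef](obj)

    wrapper.traces = 0
    return wrapper


p = tree_map(lambda x: x + 1, Person(height=1.5, name="John"))
print(p.height, p.name)  # 2.5 John


@toy_jit
def grow(person: Person) -> Person:
    return tree_map(lambda x: x * 2, person)


grow(Person(height=1.5, name="John"))
grow(Person(height=3.0, name="John"))  # same treedef: cache hit
grow(Person(height=1.5, name="Jane"))  # static field changed: new trace
print(grow.traces)  # 2
```

Because static values end up in the cache key, this sketch needs them to be hashable; JAX places a similar requirement on the auxiliary data of custom pytree nodes.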
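import treeo as to
from dataclasses import dataclass
from jax.tree_util import tree_map

@dataclass
class Person(to.Tree):
    height: float = to.field(node=True)  # node: tree_map operates on it
    name: str = to.field(node=False)  # static: stored in the PyTreeDef

person = Person(height=1.5, name='John')

tree_map(lambda x: x + 1, person)  # Person(height=2.5, name='John')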
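One way to see what node=True and node=False mean for tree_map is a stdlib-only re-creation of the example above. This is not Treeo's implementation: the `field` helper, the `tree_map` function and the `"node"` metadata key are hypothetical, the `kind` argument is not modeled, and `tree_map` only handles a flat dataclass rather than nested trees.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Callable


def field(*, node: bool, **kwargs: Any) -> Any:
    """Hypothetical helper: record the node flag in dataclasses.field metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["node"] = node
    return dataclasses.field(metadata=metadata, **kwargs)


def tree_map(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply fn to fields flagged as nodes; copy static fields unchanged."""
    changes = {
        f.name: fn(getattr(obj, f.name))
        for f in dataclasses.fields(obj)
        if f.metadata.get("node", False)
    }
    # replace() builds a new instance, leaving the original untouched.
    return dataclasses.replace(obj, **changes)


@dataclass
class Person:
    height: float = field(node=True)
    name: str = field(node=False)


person = Person(height=1.5, name="John")
print(tree_map(lambda x: x + 1, person))  # Person(height=2.5, name='John')
```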
to.field is just a wrapper over dataclasses.field that adds the node and kind arguments, so you can use all the usual dataclass features. However, dataclasses are orthogonal to Treeo, which means you can just as naturally use classes that are not dataclasses:
class Person(to.Tree):
    height: float = to.field(node=True)
    name: str = to.field(node=False)

    def __init__(self, height: float, name: str):
        self.height = height
        self.name = name

person = Person(height=1.5, name='John')

tree_map(lambda x: x + 1, person)  # Person(height=2.5, name='John')
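In plain Python, without @dataclass the Field objects returned by dataclasses.field simply stay on the class as attributes, while __init__ sets the instance values that shadow them. The stdlib-only sketch below (not Treeo's code) shows one plausible way a library could still discover the declarations: scan the class and its bases for dataclasses.Field objects and read a node flag from their metadata. `declared_fields`, `tree_map` and the `"node"` metadata key are hypothetical.

```python
import dataclasses
from typing import Any, Callable, Dict


class Person:
    # Class-level Field objects act as declarations; __init__ sets the values.
    height: float = dataclasses.field(metadata={"node": True})
    name: str = dataclasses.field(metadata={"node": False})

    def __init__(self, height: float, name: str):
        self.height = height
        self.name = name

    def __repr__(self) -> str:
        return f"Person(height={self.height!r}, name={self.name!r})"


def declared_fields(cls: type) -> Dict[str, dataclasses.Field]:
    """Collect Field objects declared on cls and its bases (base classes first)."""
    found: Dict[str, dataclasses.Field] = {}
    for klass in reversed(cls.__mro__):
        for name, value in vars(klass).items():
            if isinstance(value, dataclasses.Field):
                found[name] = value
    return found


def tree_map(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply fn to node fields and copy static fields into a new instance."""
    cls = type(obj)
    new = object.__new__(cls)  # bypass __init__, like an unflatten step
    for name, f in declared_fields(cls).items():
        value = getattr(obj, name)  # the instance value, not the class-level Field
        setattr(new, name, fn(value) if f.metadata.get("node", False) else value)
    return new


person = Person(height=1.5, name="John")
print(tree_map(lambda x: x + 1, person))  # Person(height=2.5, name='John')
```

Rebuilding with object.__new__ lets this work for any constructor signature, at the cost of skipping whatever logic __init__ runs.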