Defining Fields
Tree fields are divided into two categories:
- `node` fields: they are considered part of the pytree, so JAX functions such as `tree_map` will operate over them.
- `static` fields: they are part of the `PyTreeDef`. JAX functions will not operate over them, but JAX is still aware of them; for example, JAX will recompile jitted functions if these fields change.
```python
from dataclasses import dataclass

import treeo as to
from jax.tree_util import tree_map

@dataclass
class Person(to.Tree):
    height: float = to.field(node=True)  # node field
    name: str = to.field(node=False)  # static field

person = Person(height=1.5, name='John')
tree_map(lambda x: x + 1, person)  # Person(height=2.5, name='John')
```
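To make the node/static split concrete without relying on Treeo internals, here is a hand-written sketch that uses JAX's own `jax.tree_util.register_pytree_node_class`. It assumes that a node field corresponds to a pytree child and a static field corresponds to the auxiliary data stored in the `PyTreeDef`. This mirrors the behavior described above but is not Treeo's implementation. The `trace_log` list is an illustrative helper that records each time `jit` traces the function.

```python
import jax
from jax import tree_util


@tree_util.register_pytree_node_class
class Person:
    """Hand-written pytree: `height` is a node (leaf), `name` is static."""

    def __init__(self, height, name):
        self.height = height
        self.name = name

    def tree_flatten(self):
        children = (self.height,)  # node fields: visible to tree_map, jit, grad
        aux_data = (self.name,)    # static fields: stored in the PyTreeDef
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (name,) = aux_data
        (height,) = children
        return cls(height=height, name=name)

    def __repr__(self):
        return f"Person(height={self.height}, name={self.name!r})"


person = Person(height=1.5, name="John")

# Only the node field becomes a leaf; the static field lives in the treedef.
leaves, treedef = tree_util.tree_flatten(person)
print(leaves)  # [1.5]

# tree_map only touches node fields.
mapped = tree_util.tree_map(lambda x: x + 1, person)
print(mapped)  # Person(height=2.5, name='John')

trace_log = []  # illustrative: one entry per trace of `grow`

@jax.jit
def grow(p):
    trace_log.append(p.name)  # Python side effect, runs only while tracing
    return tree_util.tree_map(lambda x: x * 2, p)

grow(Person(height=1.5, name="John"))  # first call: traces
grow(Person(height=2.0, name="John"))  # same treedef and shapes: cached
grow(Person(height=1.5, name="Jane"))  # different static value: retraces
print(trace_log)  # ['John', 'Jane']
```

Because `name` is part of the treedef, two `Person` values with different names are different pytree structures as far as `jit` is concerned, which is why the third call triggers a new trace while the second does not.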
`field` is just a wrapper over `dataclasses.field` that adds the `node` and `kind` arguments, so you can use all `dataclass` features. However, dataclasses are orthogonal to Treeo, which means you can naturally use non-dataclass classes as well:
```python
import treeo as to
from jax.tree_util import tree_map

class Person(to.Tree):
    height: float = to.field(node=True)
    name: str = to.field(node=False)

    def __init__(self, height: float, name: str):
        self.height = height
        self.name = name

person = Person(height=1.5, name='John')
tree_map(lambda x: x + 1, person)  # Person(height=2.5, name='John')
```
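Since `field` is described as a thin wrapper over `dataclasses.field`, the general mechanism can be approximated with the standard library alone. The sketch below is hypothetical: `my_field` and `split_fields` are illustrative names, not Treeo's API, and storing `node`/`kind` in `dataclasses.field(metadata=...)` is an assumption about how such a wrapper could work. It shows that ordinary dataclass options such as `default_factory` and `repr` still pass through the wrapper.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


def my_field(default: Any = dataclasses.MISSING, *, node: bool,
             kind: Optional[type] = None, **kwargs: Any) -> Any:
    """Hypothetical stand-in for `to.field`: forward every regular
    `dataclasses.field` option and record `node`/`kind` in the metadata."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata.update(node=node, kind=kind)
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


def split_fields(obj: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Partition a dataclass instance into node values and static values."""
    nodes: Dict[str, Any] = {}
    static: Dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        target = nodes if f.metadata.get("node", False) else static
        target[f.name] = getattr(obj, f.name)
    return nodes, static


@dataclass
class Person:
    height: float = my_field(node=True)
    name: str = my_field(node=False)
    # Standard dataclass options still work through the wrapper.
    tags: list = my_field(node=False, default_factory=list, repr=False)


person = Person(height=1.5, name="John")
nodes, static = split_fields(person)
print(person)  # Person(height=1.5, name='John')
print(nodes)   # {'height': 1.5}
print(static)  # {'name': 'John', 'tags': []}
```

Here `kind` is merely recorded in the field metadata; this sketch does not model what Treeo does with it.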