Field Metadata
All metadata, whether you set it with to.field or Treeo added it by default, is available in the field_metadata: Mapping[str, FieldMetadata] property.
from dataclasses import dataclass

import treeo as to

@dataclass
class Person(to.Tree):
    height: float = to.field(node=True)
    name: str = to.field(node=False, opaque=True)

mike = Person(height=1.5, name='Mike')
# not quite true, but you get the idea
assert mike.field_metadata == {
    'height': FieldMetadata(
        node=True,
        kind=NoneType,
        opaque=False,
    ),
    'name': FieldMetadata(
        node=False,
        kind=NoneType,
        opaque=True,
    ),
}
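To make the shape of such a mapping concrete, here is a minimal standard-library sketch that collects per-field metadata from a plain dataclass. It is not Treeo's implementation: FieldMeta and collect_field_metadata are hypothetical names introduced only for illustration, and the defaults (node=False, kind=NoneType, opaque=False) are assumed from the example above.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Mapping


@dataclass(frozen=True)
class FieldMeta:
    """Hypothetical stand-in for Treeo's FieldMetadata (assumed defaults)."""
    node: bool = False
    kind: type = type(None)
    opaque: bool = False


def collect_field_metadata(obj: Any) -> Mapping[str, FieldMeta]:
    """Map each dataclass field name to a FieldMeta, filling in defaults."""
    # f.metadata is a read-only mapping; keys that are absent fall back
    # to the defaults declared on FieldMeta.
    return {f.name: FieldMeta(**f.metadata) for f in fields(obj)}


@dataclass
class Person:
    # Plain dataclasses carry user metadata through field(metadata=...).
    height: float = field(metadata={"node": True})
    name: str = field(metadata={"opaque": True})


mike = Person(height=1.5, name="Mike")
meta = collect_field_metadata(mike)
```

The point of the sketch is only that explicitly set values and defaults end up side by side in one mapping keyed by field name, which is the view that field_metadata exposes.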
Changing field metadata
If at any point you want to change the metadata of a field, you can do so with the update_field_metadata method. For example, imagine we have this definition of BatchNorm:
class BatchNorm(to.Tree):
    # nodes
    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)
    ...

    # static
    features_in: int
    momentum: float
    ...

model = BatchNorm(features_in=32, momentum=0.9)
The momentum hyperparameter field here is a float that you may want to make differentiable, e.g. in a meta-gradient setting. The original author of the class didn't declare this field as a node, but you can work around this by updating its metadata:
class DifferentiableHyperParam:
    pass

model = model.update_field_metadata(
    'momentum', node=True, kind=DifferentiableHyperParam
)
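The reassignment to model above suggests that the update returns a new object rather than mutating the existing one. Here is a rough standard-library sketch of that pattern, assuming copy-on-update semantics. Module and FieldMeta are hypothetical stand-ins, not Treeo classes, and Treeo's real method may differ in its details.

```python
import copy
from dataclasses import dataclass, replace
from typing import Any, Dict


@dataclass(frozen=True)
class FieldMeta:
    """Hypothetical stand-in for Treeo's FieldMetadata."""
    node: bool = False
    kind: type = type(None)
    opaque: bool = False


class Module:
    """Hypothetical object that carries per-field metadata."""

    def __init__(self, momentum: float) -> None:
        self.momentum = momentum
        self.field_metadata: Dict[str, FieldMeta] = {"momentum": FieldMeta()}

    def update_field_metadata(self, name: str, **changes: Any) -> "Module":
        # Shallow-copy the object, then give the copy its own metadata dict
        # so the original object is left untouched.
        new = copy.copy(self)
        new.field_metadata = {
            **self.field_metadata,
            name: replace(self.field_metadata[name], **changes),
        }
        return new


class DifferentiableHyperParam:
    pass


model = Module(momentum=0.9)
updated = model.update_field_metadata(
    "momentum", node=True, kind=DifferentiableHyperParam
)
```

In this sketch the original model keeps node=False for momentum, while updated carries the new node and kind values. Field values themselves are not changed.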