treeo.Tree
__init_subclass__()
classmethod
special
This method is called when a class is subclassed.
The default implementation does nothing. It may be overridden to extend subclasses.
Source code in treeo/tree.py
def __init_subclass__(cls, **kwargs):
    """
    Registers each new subclass as a JAX pytree node and builds its field bookkeeping.

    For every subclass this hook:

    * forwards to the next `__init_subclass__` in the MRO,
    * registers `cls` through `jax.tree_util.register_pytree_node_class`,
    * sets `cls.__signature__` to the signature of `cls.__init__` without its
      first parameter,
    * resets the per-class containers `_field_metadata`, `_factory_fields` and
      `_default_field_values` to empty dicts, and `_subtrees` / `_mutable` to `None`,
    * fills those containers from the class variables and annotations.

    Arguments:
        **kwargs: class keyword arguments, passed through unchanged to the next
            `__init_subclass__` in the MRO.

    Raises:
        KeyError: if a `dataclasses.Field` carries metadata with a `"node"` entry
            but no `"kind"` entry.
    """
    # Keep the cooperative chain intact: without this call, `__init_subclass__`
    # hooks defined by other base classes later in the MRO would never run.
    super().__init_subclass__(**kwargs)

    jax.tree_util.register_pytree_node_class(cls)

    # Restore the signature: expose `__init__`'s parameters minus the first one.
    sig = inspect.signature(cls.__init__)
    parameters = tuple(sig.parameters.values())
    cls.__signature__ = sig.replace(parameters=parameters[1:])

    annotations = utils._get_all_annotations(cls)
    class_vars = utils._get_all_vars(cls)

    # init class variables (fresh containers for every subclass)
    cls._field_metadata = {}
    cls._factory_fields = {}
    cls._default_field_values = {}
    cls._subtrees = None
    cls._mutable = None

    for field, value in class_vars.items():
        if isinstance(value, dataclasses.Field):
            # save defaults: a plain default takes precedence over a factory
            if value.default is not dataclasses.MISSING:
                cls._default_field_values[field] = value.default
            elif value.default_factory is not dataclasses.MISSING:
                cls._factory_fields[field] = value.default_factory

            # extract metadata; only fields that declare "node" are recorded here,
            # and "kind" is expected to be present alongside it
            if value.metadata is not None and "node" in value.metadata:
                cls._field_metadata[field] = types.FieldMetadata(
                    node=value.metadata["node"],
                    kind=value.metadata["kind"],
                )

    # Annotated fields without explicit metadata: a field counts as a node when
    # any type yielded by `utils._all_types` for its annotation subclasses `Tree`.
    for field, value in annotations.items():
        if field not in cls._field_metadata:
            is_node = any(issubclass(t, Tree) for t in utils._all_types(value))
            cls._field_metadata[field] = types.FieldMetadata(
                node=is_node,
                kind=type(None),
            )
check_metadata_updates(self)
Checks for new fields and, if any are found, adds them to the metadata.
Source code in treeo/tree.py
def check_metadata_updates(self):
    """
    Looks for fields that are not yet tracked and adds them to the metadata.

    A copy of `self` is produced and then handed, together with `self`, to
    `utils._safe_update_fields_from`. Nothing is returned.
    """
    # The copy goes through flattening / unflattening, which is the step that
    # updates the metadata.
    refreshed = copy(self)
    utils._safe_update_fields_from(self, refreshed)