Skip to content

treex.Tree

__init_subclass__() classmethod special

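This method is called when a class is subclassed.

The default `object.__init_subclass__` does nothing and may be overridden to extend subclasses. `Tree` overrides it so that every subclass is registered as a JAX pytree node class and gets a `__signature__` matching its `__init__` (without `self`). The override also collects each subclass's field metadata, default values, and default factories.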

Source code in treeo/tree.py
def __init_subclass__(cls, **kwargs):
    """
    Configure every subclass of `Tree` when the subclass is created.

    This hook:

    * forwards to the next ``__init_subclass__`` in the MRO so other bases
      or mixins defining the hook still run;
    * registers the subclass as a JAX pytree node class;
    * sets ``cls.__signature__`` to the signature of ``cls.__init__`` with
      its first parameter (``self``) dropped;
    * collects defaults and default factories from class attributes that
      are ``dataclasses.Field`` objects;
    * builds ``cls._field_metadata``, mapping field names to
      ``types.FieldMetadata``.

    Args:
        **kwargs: Class keyword arguments (``class A(Tree, key=...)``),
            passed on unchanged to ``super().__init_subclass__``.
    """
    # Cooperative hook: without this call, any other base/mixin that
    # defines __init_subclass__ is silently skipped.
    super().__init_subclass__(**kwargs)

    jax.tree_util.register_pytree_node_class(cls)

    # Restore the signature: expose __init__'s parameters minus `self`.
    sig = inspect.signature(cls.__init__)
    parameters = tuple(sig.parameters.values())
    cls.__signature__ = sig.replace(parameters=parameters[1:])

    # NOTE(review): presumably these helpers also walk base classes
    # (the "_all_" naming suggests it) — confirm in utils.
    annotations = utils._get_all_annotations(cls)
    class_vars = utils._get_all_vars(cls)

    # Fresh per-class containers; assigning on `cls` shadows any values a
    # parent class set, so subclasses never share or mutate them.
    cls._field_metadata = {}
    cls._factory_fields = {}
    cls._default_field_values = {}
    # Presumably a lazily computed cache of subtree fields — TODO confirm.
    cls._subtrees = None

    for field, value in class_vars.items():
        if not isinstance(value, dataclasses.Field):
            continue

        # save defaults (a Field has either a default or a factory, or neither)
        if value.default is not dataclasses.MISSING:
            cls._default_field_values[field] = value.default
        elif value.default_factory is not dataclasses.MISSING:
            cls._factory_fields[field] = value.default_factory

        # Extract explicit metadata. `Field.metadata` is always a mapping
        # (an empty one when none was given), so a membership test suffices.
        if "node" in value.metadata:
            cls._field_metadata[field] = types.FieldMetadata(
                node=value.metadata["node"],
                kind=value.metadata["kind"],
                opaque=value.metadata["opaque"],
            )

    # Annotated fields without explicit metadata: treat the field as a node
    # when any type in its annotation is a Tree subclass.
    for field, value in annotations.items():
        if field not in cls._field_metadata:
            is_node = any(issubclass(t, Tree) for t in utils._all_types(value))
            cls._field_metadata[field] = types.FieldMetadata(
                node=is_node,
                # type(None) appears to act as the "no kind declared" placeholder.
                kind=type(None),
                opaque=False,
            )

check_metadata_updates(self)

Checks for new fields, if found, adds them to the metadata.

Source code in treeo/tree.py
def check_metadata_updates(self):
    """
    Checks for new fields, if found, adds them to the metadata.

    Performs a throwaway flatten of ``self`` with the flatten context set to
    ``FlattenMode.all_fields``. The flattened leaves and treedef are
    discarded; the call is made only for the side effects of the flatten
    pass.
    """
    # NOTE(review): the actual metadata update presumably happens inside
    # this class's tree_flatten when running in all_fields mode — it is not
    # visible here.
    with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
        # Use the stable `jax.tree_util` entry point; the top-level
        # `jax.tree_flatten` alias is deprecated/removed in newer JAX.
        jax.tree_util.tree_flatten(self)