Skip to content

treex.Tree

Source code in treeo/tree.py
class Tree(metaclass=TreeMeta):
    """
    Base class for objects that JAX can flatten and unflatten as pytrees.

    Every subclass is registered as a JAX pytree node class (see
    ``__init_subclass__``). When flattened, each instance attribute is routed
    either to the dynamic children (fields whose metadata has ``node=True``)
    or to the static aux data, optionally wrapped in ``utils.Opaque``,
    according to its ``types.FieldMetadata`` entry.
    """

    # field name -> metadata (node flag, kind, opaque setting)
    _field_metadata: tp.Dict[str, types.FieldMetadata]
    # field name -> zero-arg factory taken from `dataclasses.field(default_factory=...)`
    _factory_fields: tp.Dict[str, tp.Callable[[], tp.Any]]
    # field name -> default value taken from `dataclasses.field(default=...)`
    _default_field_values: tp.Dict[str, tp.Any]
    # reset to None for every subclass; not read by the code in this class
    _subtrees: tp.Optional[tp.Tuple[str, ...]]

    @property
    def field_metadata(self) -> tp.Mapping[str, types.FieldMetadata]:
        """Read-only view of this tree's field metadata."""
        return MappingProxyType(self._field_metadata)

    def update_field_metadata(
        self: T,
        field: str,
        node: tp.Optional[bool] = None,
        kind: tp.Optional[type] = None,
        opaque: tp.Union[bool, utils.OpaquePredicate, None] = None,
    ) -> T:
        """
        Return a shallow copy of this tree with one field's metadata updated.

        Arguments:
            field: name of the field whose metadata should change. It must
                already have a metadata entry, otherwise ``KeyError`` is raised.
            node: new ``node`` flag, or ``None`` to keep the current one.
            kind: new ``kind``, or ``None`` to keep the current one.
            opaque: new ``opaque`` setting, or ``None`` to keep the current one.

        Returns:
            The updated copy. ``self`` and its metadata are left unchanged.
        """
        module = copy(self)

        field_metadata = module._field_metadata[field]
        updates = {
            name: value
            for name, value in (("node", node), ("kind", kind), ("opaque", opaque))
            if value is not None
        }

        if updates:
            field_metadata = field_metadata.update(**updates)

        # `copy` is shallow, so `module` initially shares its metadata dict
        # with `self` (or with the class). Give the copy its own dict before
        # writing so the original tree's metadata is not modified as well.
        module._field_metadata = dict(module._field_metadata)
        module._field_metadata[field] = field_metadata

        return module

    def check_metadata_updates(self):
        """
        Checks for new fields, if found, adds them to the metadata.

        Flattens the tree with every field treated as a node. Flattening runs
        each ``Tree``'s ``tree_flatten``, which registers metadata for any
        attribute that does not have an entry yet.
        """
        with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
            jax.tree_util.tree_flatten(self)

    def _update_local_metadata(self):
        """
        Add a default metadata entry for every instance attribute lacking one.

        A new entry is a node only when the value is itself a ``Tree``. Its
        ``kind`` is the value's runtime type and ``opaque`` is ``False``.
        Existing entries are never overwritten.
        """
        # NOTE(review): this mutates `self._field_metadata` in place. If that
        # dict is the class-level one (no per-instance copy), new entries are
        # visible to every instance. Confirm TreeMeta gives instances a copy.
        for field, value in vars(self).items():
            if field not in self._field_metadata:
                self._field_metadata[field] = types.FieldMetadata(
                    node=isinstance(value, Tree),
                    kind=type(value),
                    opaque=False,
                )

    def __init_subclass__(cls, **kwargs):
        """
        Register a new subclass with JAX and collect its field information.

        Runs automatically when ``Tree`` is subclassed. It:

        - registers the subclass as a JAX pytree node class;
        - rebuilds ``cls.__signature__`` from ``__init__`` without ``self``;
        - resets the per-class field dictionaries;
        - records defaults, default factories and explicit metadata found on
          ``dataclasses.Field`` class variables;
        - gives every remaining annotated field a default metadata entry.
          ``node`` is set when any type in the annotation is a ``Tree``
          subclass.

        Arguments:
            **kwargs: class keyword arguments, forwarded to the next
                ``__init_subclass__`` in the MRO.
        """
        # Cooperate with other bases in the MRO.
        super().__init_subclass__(**kwargs)

        jax.tree_util.register_pytree_node_class(cls)

        # Restore the signature: hide `self` so introspection shows the
        # constructor arguments callers actually pass.
        sig = inspect.signature(cls.__init__)
        parameters = tuple(sig.parameters.values())
        cls.__signature__ = sig.replace(parameters=parameters[1:])

        annotations = utils._get_all_annotations(cls)
        class_vars = utils._get_all_vars(cls)

        # Fresh per-class containers so a subclass never mutates its parent's.
        cls._field_metadata = {}
        cls._factory_fields = {}
        cls._default_field_values = {}
        cls._subtrees = None

        for field, value in class_vars.items():
            if isinstance(value, dataclasses.Field):

                # save defaults
                if value.default is not dataclasses.MISSING:
                    cls._default_field_values[field] = value.default
                elif value.default_factory is not dataclasses.MISSING:
                    cls._factory_fields[field] = value.default_factory

                # extract metadata
                if value.metadata is not None and "node" in value.metadata:
                    cls._field_metadata[field] = types.FieldMetadata(
                        node=value.metadata["node"],
                        kind=value.metadata["kind"],
                        opaque=value.metadata["opaque"],
                    )

        for field, value in annotations.items():

            if field not in cls._field_metadata:
                # NOTE(review): issubclass raises TypeError if `_all_types`
                # yields a non-class (e.g. a string forward reference).
                # Confirm it only yields classes.
                is_node = any(issubclass(t, Tree) for t in utils._all_types(value))
                cls._field_metadata[field] = types.FieldMetadata(
                    node=is_node,
                    kind=type(None),
                    opaque=False,
                )

    def tree_flatten(self):
        """
        Split this tree into dynamic children and static aux data for JAX.

        Routing depends on ``_CONTEXT.flatten_mode``:

        - ``all_fields``: every attribute is a child.
        - ``no_fields``: every attribute is static.
        - otherwise: ``node`` fields are children. Other fields are static
          and are wrapped in ``utils.Opaque`` when their ``opaque`` setting
          is not ``False``.

        When ``_CONTEXT.add_field_info`` is set, every non-``Tree`` leaf of a
        child field outside ``TREE_PRIVATE_FIELDS`` is wrapped in ``FieldInfo``.

        Returns:
            ``((node_fields,), static_fields)``, both dicts keyed by field name.
        """
        fields = vars(self).copy()

        node_fields = {}
        static_fields = {}

        # auto-annotations: every attribute needs a metadata entry below
        self._update_local_metadata()

        if _CONTEXT.flatten_mode == FlattenMode.all_fields:
            node_fields = fields
        elif _CONTEXT.flatten_mode == FlattenMode.no_fields:
            static_fields = fields
        else:  # normal or None
            for field, value in fields.items():
                field_annotation = self._field_metadata[field]

                if field_annotation.node:
                    node_fields[field] = value
                elif field_annotation.opaque != False:  # noqa: E712 - opaque may be a predicate
                    # A plain `True` means "opaque with no custom predicate".
                    static_fields[field] = utils.Opaque(
                        value,
                        predicate=field_annotation.opaque
                        if not isinstance(field_annotation.opaque, bool)
                        else None,
                    )
                else:
                    static_fields[field] = value

        # maybe convert to FieldInfo
        if _CONTEXT.add_field_info:
            for field in node_fields:
                if field in TREE_PRIVATE_FIELDS:
                    continue

                kind = self._field_metadata[field].kind
                # tree_map applies the lambda immediately, so capturing the
                # loop variables `field` and `kind` is safe here.
                node_fields[field] = jax.tree_util.tree_map(
                    lambda x: FieldInfo(
                        name=field,
                        value=x,
                        kind=kind,
                        module=self,
                    )
                    if not isinstance(x, Tree)
                    else x,
                    node_fields[field],
                    is_leaf=lambda x: isinstance(x, Tree),
                )

        children = (node_fields,)

        return children, static_fields

    @classmethod
    def tree_unflatten(cls, static_fields, children):
        """
        Rebuild a tree from the output of ``tree_flatten``.

        The instance is created with ``cls.__new__``, so ``__init__`` is not
        run. When ``_CONTEXT.add_field_info`` is set, ``FieldInfo`` wrappers
        are stripped back to their values. ``utils.Opaque`` static values are
        unwrapped for fields whose metadata marks them as opaque.
        """
        module = cls.__new__(cls)
        (node_fields,) = children

        if _CONTEXT.add_field_info:
            for field in node_fields:
                node_fields[field] = jax.tree_util.tree_map(
                    lambda x: x.value if isinstance(x, FieldInfo) else x,
                    node_fields[field],
                    is_leaf=lambda x: isinstance(x, Tree),
                )

        module.__dict__.update(node_fields, **static_fields)

        # extract value from Opaque
        for field, value in static_fields.items():
            if (
                isinstance(value, utils.Opaque)
                and field in module._field_metadata
                and module._field_metadata[field].opaque
            ):
                setattr(module, field, value.value)

        return module

`__init_subclass__()` — classmethod (special method)

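This method is called when a class is subclassed.

The default `object.__init_subclass__` implementation does nothing. `Tree` overrides it to register each subclass as a JAX pytree node class and to restore its `__init__` signature without `self`. It also collects the subclass's field defaults, default factories, and field metadata.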

Source code in treeo/tree.py
def __init_subclass__(cls):
    """
    Register a freshly created ``Tree`` subclass and build its field bookkeeping.

    Steps performed for ``cls``:

    1. register it as a JAX pytree node class;
    2. set ``cls.__signature__`` to the ``__init__`` signature minus its first
       parameter (``self``);
    3. reset ``_field_metadata``, ``_factory_fields``, ``_default_field_values``
       and ``_subtrees`` on the class;
    4. record defaults, default factories and explicit metadata carried by
       ``dataclasses.Field`` class variables;
    5. give each annotated field without explicit metadata a default entry
       whose ``node`` flag is set when any type in its annotation subclasses
       ``Tree``.
    """
    jax.tree_util.register_pytree_node_class(cls)

    # Drop `self` from the advertised constructor signature.
    init_signature = inspect.signature(cls.__init__)
    init_params = list(init_signature.parameters.values())
    cls.__signature__ = init_signature.replace(parameters=init_params[1:])

    annotations = utils._get_all_annotations(cls)
    class_vars = utils._get_all_vars(cls)

    # Per-class containers, rebuilt from scratch for every subclass.
    cls._field_metadata = {}
    cls._factory_fields = {}
    cls._default_field_values = {}
    cls._subtrees = None

    dataclass_fields = (
        (name, member)
        for name, member in class_vars.items()
        if isinstance(member, dataclasses.Field)
    )
    for name, member in dataclass_fields:
        if member.default is not dataclasses.MISSING:
            cls._default_field_values[name] = member.default
        elif member.default_factory is not dataclasses.MISSING:
            cls._factory_fields[name] = member.default_factory

        meta = member.metadata
        if meta is not None and "node" in meta:
            cls._field_metadata[name] = types.FieldMetadata(
                node=meta["node"],
                kind=meta["kind"],
                opaque=meta["opaque"],
            )

    for name, annotation in annotations.items():
        if name in cls._field_metadata:
            continue
        contains_tree = any(
            issubclass(candidate, Tree) for candidate in utils._all_types(annotation)
        )
        cls._field_metadata[name] = types.FieldMetadata(
            node=contains_tree,
            kind=type(None),
            opaque=False,
        )

check_metadata_updates(self)

Checks for new fields, if found, adds them to the metadata.

Source code in treeo/tree.py
def check_metadata_updates(self):
    """
    Checks for new fields, if found, adds them to the metadata.

    Flattens the tree with every field treated as a node. Flattening runs each
    ``Tree``'s ``tree_flatten``, which registers metadata for any attribute
    that does not have an entry yet.
    """
    # `jax.tree_util.tree_flatten` replaces the deprecated top-level
    # `jax.tree_flatten` alias, which newer JAX releases no longer provide.
    with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
        jax.tree_util.tree_flatten(self)