treex.Tree
Source code in treeo/tree.py
class Tree(metaclass=TreeMeta):
    """
    Base class whose subclasses are registered as JAX pytree nodes.

    Each subclass is passed to `jax.tree_util.register_pytree_node_class`
    in `__init_subclass__`. `tree_flatten` / `tree_unflatten` implement the
    pytree protocol, using the per-field `types.FieldMetadata` entries to
    decide which instance attributes are children (nodes) and which are
    carried as static data.
    """

    # Per-field metadata (node / kind / opaque), keyed by attribute name.
    _field_metadata: tp.Dict[str, types.FieldMetadata]
    # `default_factory` callables collected from `dataclasses.Field` class vars.
    _factory_fields: tp.Dict[str, tp.Callable[[], tp.Any]]
    # Plain default values collected from `dataclasses.Field` class vars.
    _default_field_values: tp.Dict[str, tp.Any]
    # Reset to None for every subclass; presumably filled in elsewhere
    # (nothing in this class body assigns it) -- TODO confirm.
    _subtrees: tp.Optional[tp.Tuple[str, ...]]

    @property
    def field_metadata(self) -> tp.Mapping[str, types.FieldMetadata]:
        """Read-only view of the per-field metadata mapping."""
        return MappingProxyType(self._field_metadata)

    def update_field_metadata(
        self: T,
        field: str,
        node: tp.Optional[bool] = None,
        kind: tp.Optional[type] = None,
        opaque: tp.Union[bool, utils.OpaquePredicate, None] = None,
    ) -> T:
        """
        Return a shallow copy of this tree with updated metadata for `field`.

        Arguments:
            field: name of the field whose metadata is updated.
            node: new value for the `node` flag, `None` leaves it unchanged.
            kind: new value for `kind`, `None` leaves it unchanged.
            opaque: new value for `opaque`, `None` leaves it unchanged.

        Returns:
            A copy of the tree; the original tree's metadata is left untouched.

        Raises:
            KeyError: if `field` has no metadata entry.
        """
        module = copy(self)
        # looked up unconditionally so that unknown fields still raise KeyError
        field_metadata = module._field_metadata[field]

        updates = {}
        if node is not None:
            updates.update(node=node)
        if kind is not None:
            updates.update(kind=kind)
        if opaque is not None:
            updates.update(opaque=opaque)

        if updates:
            # `copy` is shallow: the copy still shares the metadata dict with
            # `self`, so detach it before writing to avoid mutating the original.
            module._field_metadata = dict(module._field_metadata)
            module._field_metadata[field] = field_metadata.update(**updates)

        return module

    def check_metadata_updates(self):
        """
        Checks for new fields, if found, adds them to the metadata.
        """
        # Flattening invokes `tree_flatten`, which calls `_update_local_metadata`;
        # `all_fields` mode makes every attribute a child so nested values are
        # traversed as well.
        with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
            jax.tree_util.tree_flatten(self)

    def _update_local_metadata(self):
        """Add a metadata entry for every instance attribute that lacks one."""
        for field, value in vars(self).items():
            if field not in self._field_metadata:
                self._field_metadata[field] = types.FieldMetadata(
                    node=isinstance(value, Tree),
                    kind=type(value),
                    opaque=False,
                )

    def __init_subclass__(cls):
        """
        Register the subclass as a pytree node and collect its field metadata.

        Metadata comes from `dataclasses.Field` class variables whose
        `metadata` contains a `"node"` key; remaining annotated fields get a
        default entry whose `node` flag is inferred from the annotation.
        """
        jax.tree_util.register_pytree_node_class(cls)

        # Restore the signature (drop the first parameter, i.e. `self`)
        sig = inspect.signature(cls.__init__)
        parameters = tuple(sig.parameters.values())
        cls.__signature__ = sig.replace(parameters=parameters[1:])

        annotations = utils._get_all_annotations(cls)
        class_vars = utils._get_all_vars(cls)

        # init class variables
        cls._field_metadata = {}
        cls._factory_fields = {}
        cls._default_field_values = {}
        cls._subtrees = None

        for field, value in class_vars.items():
            if isinstance(value, dataclasses.Field):
                # save defaults
                if value.default is not dataclasses.MISSING:
                    cls._default_field_values[field] = value.default
                elif value.default_factory is not dataclasses.MISSING:
                    cls._factory_fields[field] = value.default_factory

                # extract metadata
                if value.metadata is not None and "node" in value.metadata:
                    # "kind" / "opaque" fall back to the same defaults used for
                    # plain annotations instead of raising KeyError when absent
                    cls._field_metadata[field] = types.FieldMetadata(
                        node=value.metadata["node"],
                        kind=value.metadata.get("kind", type(None)),
                        opaque=value.metadata.get("opaque", False),
                    )

        for field, value in annotations.items():
            if field not in cls._field_metadata:
                # `issubclass` raises TypeError for non-class objects, so skip those
                is_node = any(
                    isinstance(t, type) and issubclass(t, Tree)
                    for t in utils._all_types(value)
                )
                cls._field_metadata[field] = types.FieldMetadata(
                    node=is_node,
                    kind=type(None),
                    opaque=False,
                )

    def tree_flatten(self):
        """
        Split the instance attributes into pytree children and static data.

        Depending on `_CONTEXT.flatten_mode` either all attributes are
        children (`all_fields`), all are static (`no_fields`), or they are
        split according to each field's `node` flag. Non-node fields whose
        `opaque` metadata is not `False` are wrapped in `utils.Opaque`.

        Returns:
            `(children, static_fields)` where `children` is a 1-tuple holding
            the dict of node fields and `static_fields` is a dict.
        """
        fields = vars(self).copy()

        node_fields = {}
        static_fields = {}

        # auto-annotations
        self._update_local_metadata()

        if _CONTEXT.flatten_mode == FlattenMode.all_fields:
            node_fields = fields
        elif _CONTEXT.flatten_mode == FlattenMode.no_fields:
            static_fields = fields
        else:  # normal or None
            for field, value in fields.items():
                field_annotation = self._field_metadata[field]

                if field_annotation.node:
                    node_fields[field] = value
                elif field_annotation.opaque != False:  # noqa: E712 - opaque may be a predicate
                    opaque = field_annotation.opaque
                    static_fields[field] = utils.Opaque(
                        value,
                        predicate=None if isinstance(opaque, bool) else opaque,
                    )
                else:
                    static_fields[field] = value

        # maybe convert to FieldInfo
        if _CONTEXT.add_field_info:
            for field in node_fields.keys():
                if field in TREE_PRIVATE_FIELDS:
                    continue

                kind = self._field_metadata[field].kind
                # leaves, treedef
                # (the lambda reads `field` / `kind` immediately because
                # tree_map runs eagerly inside this iteration)
                node_fields[field] = jax.tree_util.tree_map(
                    lambda x: FieldInfo(
                        name=field,
                        value=x,
                        kind=kind,
                        module=self,
                    )
                    if not isinstance(x, Tree)
                    else x,
                    node_fields[field],
                    is_leaf=lambda x: isinstance(x, Tree),
                )

        children = (node_fields,)

        return children, static_fields

    @classmethod
    def tree_unflatten(cls, static_fields, children):
        """
        Rebuild an instance from the output of `tree_flatten`.

        The instance is created with `cls.__new__` (no `__init__` call) and its
        `__dict__` is populated from the node and static fields; `FieldInfo`
        leaves and `utils.Opaque` wrappers are unwrapped back to their values.
        """
        module = cls.__new__(cls)
        (node_fields,) = children

        if _CONTEXT.add_field_info:
            for field in node_fields.keys():
                node_fields[field] = jax.tree_util.tree_map(
                    lambda x: x.value if isinstance(x, FieldInfo) else x,
                    node_fields[field],
                    is_leaf=lambda x: isinstance(x, Tree),
                )

        module.__dict__.update(node_fields, **static_fields)

        # extract value from Opaque
        for field, value in static_fields.items():
            if (
                isinstance(value, utils.Opaque)
                and field in module._field_metadata
                and module._field_metadata[field].opaque
            ):
                setattr(module, field, value.value)

        return module
__init_subclass__()
classmethod
special
This method is called when a class is subclassed.
The default implementation does nothing. It may be overridden to extend subclasses.
Source code in treeo/tree.py
def __init_subclass__(cls):
    """
    Register the subclass as a pytree node and collect its field metadata.

    Metadata comes from `dataclasses.Field` class variables whose
    `metadata` contains a `"node"` key; remaining annotated fields get a
    default entry whose `node` flag is inferred from the annotation.
    """
    jax.tree_util.register_pytree_node_class(cls)

    # Restore the signature (drop the first parameter, i.e. `self`)
    sig = inspect.signature(cls.__init__)
    parameters = tuple(sig.parameters.values())
    cls.__signature__ = sig.replace(parameters=parameters[1:])

    annotations = utils._get_all_annotations(cls)
    class_vars = utils._get_all_vars(cls)

    # init class variables
    cls._field_metadata = {}
    cls._factory_fields = {}
    cls._default_field_values = {}
    cls._subtrees = None

    for field, value in class_vars.items():
        if isinstance(value, dataclasses.Field):
            # save defaults
            if value.default is not dataclasses.MISSING:
                cls._default_field_values[field] = value.default
            elif value.default_factory is not dataclasses.MISSING:
                cls._factory_fields[field] = value.default_factory

            # extract metadata
            if value.metadata is not None and "node" in value.metadata:
                # "kind" / "opaque" fall back to the same defaults used for
                # plain annotations instead of raising KeyError when absent
                cls._field_metadata[field] = types.FieldMetadata(
                    node=value.metadata["node"],
                    kind=value.metadata.get("kind", type(None)),
                    opaque=value.metadata.get("opaque", False),
                )

    for field, value in annotations.items():
        if field not in cls._field_metadata:
            # `issubclass` raises TypeError for non-class objects, so skip those
            is_node = any(
                isinstance(t, type) and issubclass(t, Tree)
                for t in utils._all_types(value)
            )
            cls._field_metadata[field] = types.FieldMetadata(
                node=is_node,
                kind=type(None),
                opaque=False,
            )
check_metadata_updates(self)
Checks for new fields, if found, adds them to the metadata.
Source code in treeo/tree.py
def check_metadata_updates(self):
    """
    Checks for new fields, if found, adds them to the metadata.
    """
    # Flattening invokes the class's `tree_flatten`, which registers metadata
    # for any attribute that lacks it; `all_fields` mode makes every attribute
    # a child so nested values are traversed as well.
    # `jax.tree_util.tree_flatten` replaces the deprecated `jax.tree_flatten` alias.
    with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
        jax.tree_util.tree_flatten(self)