class ModuleMeta(to.TreeMeta):
    """Module-level extension of ``to.TreeMeta`` that customizes ``construct``."""

    def construct(cls, obj: M, *args, **kwargs) -> M:
        """Construct an instance, assign a default name and eagerly init in compact mode.

        The parent ``construct`` runs inside a fresh ``_InitContext`` so that any
        surrounding init context is reset while the instance is being built.

        After construction:

        * if the instance has no ``name`` attribute, one is derived from its class
          name via ``utils._lower_snake_case``;
        * if ``to.in_compact()`` is true and the instance is not yet
          ``initialized``, it is initialized in place with a key from
          ``next_key()``.

        Raises:
            RuntimeError: when inside a compact context but ``_INIT_CONTEXT.key``
                is ``None`` (no ``init`` / ``rng_key`` context is active).
        """
        # Build the instance with the init context reset.
        with _InitContext():
            module = super().construct(obj, *args, **kwargs)

        # Only fill in a name when the instance did not set one itself.
        if not hasattr(module, "name"):
            module.name = utils._lower_snake_case(module.__class__.__name__)

        # Outside a compact context there is nothing more to do.
        if not to.in_compact():
            return module

        # Eager initialization below needs a key, so fail early without one.
        if _INIT_CONTEXT.key is None:
            raise RuntimeError(
                f"Trying to construct new module {module} with a compact context outside of `init` or an `rng_key` context."
            )

        if module.initialized:
            return module

        module.init(key=next_key(), inplace=True, _set_initialize=False)
        return module