Skip to content

treex.ModuleMeta

Source code in treex/module.py
class ModuleMeta(to.TreeMeta):
    """Metaclass for treex modules.

    Extends ``to.TreeMeta`` by overriding ``construct`` so that every module
    gets a default ``name`` and, when built inside a compact context, is
    initialized eagerly with a fresh RNG key.
    """

    def construct(cls, obj: M, *args, **kwargs) -> M:
        """Build ``obj`` through the parent metaclass, then apply module defaults.

        Args:
            obj: The instance under construction; forwarded to
                ``TreeMeta.construct`` together with ``*args`` / ``**kwargs``.

        Returns:
            The object produced by ``TreeMeta.construct``. It may have a
            ``name`` attribute added. Inside a compact context it may also
            have been initialized in place.

        Raises:
            RuntimeError: If a compact context is active but
                ``_INIT_CONTEXT.key`` is ``None`` (per the message, i.e.
                outside ``init`` or an ``rng_key`` context).
        """
        # Run the parent constructor under a fresh init context so that the
        # caller's context is reset while the object is being built.
        with _InitContext():
            obj = super().construct(obj, *args, **kwargs)

        # Fall back to a snake_case version of the class name when the
        # constructor did not assign an explicit ``name``.
        if not hasattr(obj, "name"):
            obj.name = utils._lower_snake_case(obj.__class__.__name__)

        # NOTE(review): ``to.in_compact()`` presumably reports whether we are
        # executing inside a compact-style method — confirm in treeo.
        if not to.in_compact():
            return obj

        # Eager initialization below needs an RNG key from the init context.
        if _INIT_CONTEXT.key is None:
            raise RuntimeError(
                f"Trying to construct new module {obj} with a compact context outside of `init` or an `rng_key` context."
            )

        if not obj.initialized:
            # NOTE(review): ``_set_initialize=False`` presumably leaves the
            # ``initialized`` flag untouched here — confirm against ``init``.
            obj.init(key=next_key(), inplace=True, _set_initialize=False)

        return obj