Skip to content

treex.Module

Source code in treex/module.py
class Module(Treex, Filters, metaclass=ModuleMeta):
    """
    Base class for treex modules.

    Holds a `name` plus the static `_training`, `_initialized` and `_frozen`
    flags (exposed read-only through properties), provides `init` to initialize
    fields from a seed and `tabulate` to render a summary table. Callable
    subclasses get their `__call__` wrapped by `__init_subclass__` so that
    `tabulate` can show per-module inputs and outputs.
    """

    # Class-level defaults for the module flags, declared as treeo static fields
    # via `to.static`. NOTE(review): an older comment said these are copied from
    # the class to each instance — confirm that `to.static` does so.
    _training: bool = to.static(True)
    _initialized: bool = to.static(False)
    _frozen: bool = to.static(False)

    def __init__(self, name: tp.Optional[str] = None):
        """
        Arguments:
            name: Optional module name. When `None`, the class name converted
                with `utils._lower_snake_case` is used.
        """
        self.name = (
            name
            if name is not None
            else utils._lower_snake_case(self.__class__.__name__)
        )

    def initializing(self) -> bool:
        """
        Returns whether this module is being run for the first time inside `init`.

        Returns:
            `False` if the module is already initialized, `True` if it is not yet
            initialized and an `init` context is active.
        Raises:
            RuntimeError: If the module is not initialized and no `init` context
                is active, i.e. it is run for the first time outside of `init`.
        """
        if self.initialized:
            return False

        if not _INIT_CONTEXT.initializing:
            raise RuntimeError(
                f"Trying to run {self.__class__.__name__} for the first time outside of `init`"
            )

        return True

    @property
    def initialized(self) -> bool:
        """Value of the `_initialized` flag (set to `True` by `init`)."""
        return self._initialized

    @property
    def training(self) -> bool:
        """Value of the `_training` flag (class default `True`)."""
        return self._training

    @property
    def frozen(self) -> bool:
        """Value of the `_frozen` flag (class default `False`)."""
        return self._frozen

    def __init_subclass__(cls, **init_subclass_kwargs):
        """
        Wraps `__call__` on callable subclasses so module calls can be recorded.

        While `contexts._CONTEXT.call_info` is not `None` (`tabulate` sets it to
        a fresh dict), the wrapper stores `(types.Inputs(*args, **kwargs), outputs)`
        under the module for its first call. Extra class keyword arguments are
        forwarded to the parent `__init_subclass__`.
        """
        # Marks a `__call__` that is already one of these wrappers, so a subclass
        # that merely inherits such a `__call__` does not get a second layer.
        marker = "_treex_records_call_info"

        if issubclass(cls, tp.Callable) and not getattr(cls.__call__, marker, False):
            orig_call = cls.__call__

            @functools.wraps(orig_call)
            def new_call(self: Module, *args, **kwargs):
                # Checked *before* the call: if `orig_call` reaches another
                # wrapper through `super().__call__`, that inner wrapper records
                # first, and the outermost call must still be able to replace it.
                recorded_before = (
                    contexts._CONTEXT.call_info is not None
                    and self in contexts._CONTEXT.call_info
                )

                outputs = orig_call(self, *args, **kwargs)

                call_info = contexts._CONTEXT.call_info
                if call_info is not None and not recorded_before:
                    call_info[self] = (types.Inputs(*args, **kwargs), outputs)

                return outputs

            setattr(new_call, marker, True)
            cls.__call__ = new_call

        super().__init_subclass__(**init_subclass_kwargs)

    def init(
        self: M,
        key: tp.Union[int, jnp.ndarray],
        inputs: tp.Union[types.InputLike, to.Missing] = to.MISSING,
        call_method: str = "__call__",
        *,
        inplace: bool = False,
        _set_initialize: bool = True,
    ) -> M:
        """
        Method version of `tx.init`, it applies `self` as first argument.

        `init` creates a new module with the same structure, but with its fields initialized given a seed `key`. The following
        procedure is used:

        1. The input `key` is split and iteratively updated before passing a derived value to any
            process that requires initialization.
        2. `Initializer`s are called and applied to the module first.
        3. `Module.rng_init` is then called on every module not yet initialized.
        4. If `inputs` is given, `call_method` is called with it while the init
            context is still active.
        5. Unless `_set_initialize` is `False`, every module is marked as initialized.

        Arguments:
            key: The seed to use for initialization.
            inputs: Optional inputs. When given they are converted with
                `types.Inputs.from_value` and passed to `call_method`.
            call_method: Name of the method invoked with `inputs`.
            inplace: If `True`, initialize `self` directly instead of a copy.
            _set_initialize: Internal flag; if `True`, modules that are not yet
                initialized get `_initialized = True` at the end.
        Returns:
            The new module (or `self` when `inplace=True`) with the fields initialized.
        """
        module = self if inplace else self.copy()
        key = utils.Key(key)

        with _INIT_CONTEXT.update(key=key, initializing=True):

            # replace every `Initializer` leaf with the value it produces
            module: M = module.map(
                lambda initializer: (
                    initializer(next_key())
                    if isinstance(initializer, types.Initializer)
                    else initializer
                ),
                is_leaf=lambda x: isinstance(x, types.Initializer),
                inplace=True,
            )

            def call_rng_init(module: Module):
                if isinstance(module, Module) and not module._initialized:
                    module.rng_init()

            module = to.apply(call_rng_init, module, inplace=True)

            if inputs is not to.MISSING:
                inputs = types.Inputs.from_value(inputs)
                method = getattr(module, call_method)
                method(*inputs.args, **inputs.kwargs)

        if _set_initialize:

            def set_initialized(module: Module):
                if isinstance(module, Module) and not module._initialized:
                    module._initialized = True

            module = to.apply(set_initialized, module, inplace=True)

        return module

    def rng_init(self) -> None:
        """
        Override hook called by `init` on each module not yet initialized.

        It runs after `Initializer`s have been applied, inside the init context.
        The base implementation does nothing.
        """
        pass

    def tabulate(
        self,
        inputs: tp.Union[types.InputLike, to.Missing] = to.MISSING,
        depth: int = -1,
        signature: bool = False,
        param_types: bool = True,
    ) -> str:
        """
        Returns a tabular representation of the module.

        Arguments:
            inputs: Optional example inputs. When given, the module (which must be
                callable) is traced with `jax.eval_shape` and "inputs"/"outputs"
                columns show what each module received and returned.
            depth: The maximum depth of the representation in terms of nested Modules, -1 means no limit.
            signature: Whether to show the signature of the Module.
            param_types: Whether to show the types of the parameters.
        Returns:
            A string containing the tabular representation.
        Raises:
            TypeError: If `inputs` is given but the module is not callable.
        """
        # NOTE(review): presumably copied so tracing cannot mutate the caller's module
        self = to.copy(self)

        def simplify_inputs(call_inputs):
            # most compact display form: a lone positional arg, only the args,
            # only the kwargs, or (args, kwargs)
            if len(call_inputs.kwargs) == 0:
                if len(call_inputs.args) == 1:
                    return call_inputs.args[0]
                return call_inputs.args
            if len(call_inputs.args) == 0:
                return call_inputs.kwargs
            return (call_inputs.args, call_inputs.kwargs)

        if inputs is not to.MISSING:
            inputs = types.Inputs.from_value(inputs)

            if not isinstance(self, tp.Callable):
                raise TypeError(
                    "`inputs` can only be specified if the module is a callable."
                )

            with contexts._Context(call_info={}):

                # call using self to preserve references: the recorded keys are
                # the same submodule objects that appear in the rows below
                def eval_call(args, kwargs):
                    assert isinstance(self, tp.Callable)
                    return self(*args, **kwargs)

                jax.eval_shape(eval_call, inputs.args, inputs.kwargs)
                call_info = contexts._CONTEXT.call_info
        else:
            call_info = None

        with to.add_field_info():
            flat: tp.List[to.FieldInfo]
            flat, _ = jax.tree_util.tree_flatten(self)
            # dict.fromkeys de-duplicates while keeping first-seen order, so the
            # column order is deterministic (a set of types is not)
            tree_part_types: tp.Tuple[tp.Type[types.TreePart], ...] = tuple(
                dict.fromkeys(
                    field_info.kind
                    for field_info in flat
                    if utils._generic_issubclass(field_info.kind, types.TreePart)
                )
            )

        raw_rows = list(
            utils._get_tabulate_rows(
                (), self, depth, tree_part_types, signature, param_types
            )
        )
        # each raw row is (module, *cells): keep the module for call_info lookups
        modules = [raw_row[0] for raw_row in raw_rows]
        rows = [list(raw_row[1:]) for raw_row in raw_rows]

        if call_info is not None:
            for module, row in zip(modules, rows):
                if module in call_info:
                    call_inputs, call_outputs = call_info[module]
                    inputs_repr = utils._format_param_tree(
                        simplify_inputs(call_inputs)
                    )
                    outputs_repr = utils._format_param_tree(call_outputs)
                else:
                    inputs_repr = outputs_repr = ""

                # insert the two cells right after path, module and params
                row[3:3] = [inputs_repr, outputs_repr]

        # index of the last non-TreePart column ("params", or "outputs" when
        # call info is shown); the "Total:" label goes there
        label_col = 2 if call_info is None else 4

        rows[0][0] = "*"
        rows.append(
            [""] * label_col
            + ["Total:"]
            + [
                utils._format_obj_size(self.filter(kind), add_padding=True)
                for kind in tree_part_types
            ]
        )
        utils._add_padding(rows)

        table = Table(
            show_header=True,
            show_lines=True,
            show_footer=True,
        )

        table.add_column("path")
        table.add_column("module")
        table.add_column("params")

        if call_info is not None:
            table.add_column("inputs")
            table.add_column("outputs")

        for tree_part_type in tree_part_types:
            type_name = tree_part_type.__name__
            if type_name.startswith("_"):
                type_name = type_name[1:]

            table.add_column(type_name)

        for row in rows[:-1]:
            table.add_row(*row)

        totals = rows[-1]
        table.columns[label_col].footer = Text.from_markup(
            totals[label_col], justify="right"
        )

        for col in range(label_col + 1, label_col + 1 + len(tree_part_types)):
            table.columns[col].footer = totals[col]

        table.caption_style = "bold"
        table.caption = "\nTotal Parameters: " + utils._format_obj_size(
            self, add_padding=False
        )

        return utils._get_rich_repr(table)

__init_subclass__() classmethod special

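This method is called when a class is subclassed.

Unlike the default `object.__init_subclass__`, which does nothing, `Module` uses it to wrap the `__call__` of every callable subclass. While a call-recording context is active, the wrapper stores the inputs and outputs of each module's first call in `call_info`. `tabulate` sets up this context when `inputs` are given.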

Source code in treex/module.py
def __init_subclass__(cls, **init_subclass_kwargs):
    """
    Wraps `__call__` on callable subclasses so module calls can be recorded.

    While `contexts._CONTEXT.call_info` is not `None` (`tabulate` sets it to
    a fresh dict), the wrapper stores `(types.Inputs(*args, **kwargs), outputs)`
    under the module for its first call. Extra class keyword arguments are
    forwarded to the parent `__init_subclass__`.
    """
    # Marks a `__call__` that is already one of these wrappers, so a subclass
    # that merely inherits such a `__call__` does not get a second layer.
    marker = "_treex_records_call_info"

    if issubclass(cls, tp.Callable) and not getattr(cls.__call__, marker, False):
        orig_call = cls.__call__

        @functools.wraps(orig_call)
        def new_call(self: Module, *args, **kwargs):
            # Checked *before* the call: if `orig_call` reaches another wrapper
            # through `super().__call__`, that inner wrapper records first, and
            # the outermost call must still be able to replace it.
            recorded_before = (
                contexts._CONTEXT.call_info is not None
                and self in contexts._CONTEXT.call_info
            )

            outputs = orig_call(self, *args, **kwargs)

            call_info = contexts._CONTEXT.call_info
            if call_info is not None and not recorded_before:
                call_info[self] = (types.Inputs(*args, **kwargs), outputs)

            return outputs

        setattr(new_call, marker, True)
        cls.__call__ = new_call

    super().__init_subclass__(**init_subclass_kwargs)

init(self, key, inputs=to.MISSING, call_method='__call__', *, inplace=False, _set_initialize=True)

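Method version of `tx.init`; it passes `self` as the first argument.

`init` creates a new module with the same structure, but with its fields initialized from a seed `key`. With `inplace=True`, it initializes the module itself instead. The following procedure is used:

  1. The input `key` is split and iteratively updated before passing a derived value to any process that requires initialization.
  2. `Initializer`s are called and applied to the module first.
  3. `Module.rng_init` is then called on every module that is not yet initialized.
  4. If `inputs` is given, `call_method` is called with it while the initialization context is still active.
  5. Unless `_set_initialize=False`, every module is then marked as initialized.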

Parameters:

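| Name | Type | Description | Default |
|------|------|-------------|---------|
| `key` | `Union[int, jnp.ndarray]` | The seed to use for initialization. | required |
| `inputs` | `InputLike` | Optional inputs passed to `call_method` during initialization. | `to.MISSING` |
| `call_method` | `str` | Name of the method called with `inputs`. | `'__call__'` |
| `inplace` | `bool` | If `True`, initialize the module itself instead of a copy. | `False` |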

Returns:

| Type | Description |
|------|-------------|
| `M` | The new module with the fields initialized. |

Source code in treex/module.py
def init(
    self: M,
    key: tp.Union[int, jnp.ndarray],
    inputs: tp.Union[types.InputLike, to.Missing] = to.MISSING,
    call_method: str = "__call__",
    *,
    inplace: bool = False,
    _set_initialize: bool = True,
) -> M:
    """
    Method version of `tx.init`, it applies `self` as first argument.

    `init` creates a new module with the same structure, but with its fields initialized given a seed `key`. The following
    procedure is used:

    1. The input `key` is split and iteratively updated before passing a derived value to any
        process that requires initialization.
    2. `Initializer`s are called and applied to the module first.
    3. `Module.rng_init` is then called on every module not yet initialized.
    4. If `inputs` is given, `call_method` is called with it while the init
        context is still active.
    5. Unless `_set_initialize` is `False`, every module is marked as initialized.

    Arguments:
        key: The seed to use for initialization.
        inputs: Optional inputs. When given they are converted with
            `types.Inputs.from_value` and passed to `call_method`.
        call_method: Name of the method invoked with `inputs`.
        inplace: If `True`, initialize `self` directly instead of a copy.
        _set_initialize: Internal flag; if `True`, modules that are not yet
            initialized get `_initialized = True` at the end.
    Returns:
        The new module (or `self` when `inplace=True`) with the fields initialized.
    """
    module = self if inplace else self.copy()
    key = utils.Key(key)

    with _INIT_CONTEXT.update(key=key, initializing=True):

        # replace every `Initializer` leaf with the value it produces
        module: M = module.map(
            lambda initializer: (
                initializer(next_key())
                if isinstance(initializer, types.Initializer)
                else initializer
            ),
            is_leaf=lambda x: isinstance(x, types.Initializer),
            inplace=True,
        )

        def call_rng_init(module: Module):
            if isinstance(module, Module) and not module._initialized:
                module.rng_init()

        module = to.apply(call_rng_init, module, inplace=True)

        if inputs is not to.MISSING:
            inputs = types.Inputs.from_value(inputs)
            method = getattr(module, call_method)
            method(*inputs.args, **inputs.kwargs)

    if _set_initialize:

        def set_initialized(module: Module):
            if isinstance(module, Module) and not module._initialized:
                module._initialized = True

        module = to.apply(set_initialized, module, inplace=True)

    return module

tabulate(self, inputs=to.MISSING, depth=-1, signature=False, param_types=True)

Returns a tabular representation of the module.

Parameters:

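| Name | Type | Description | Default |
|------|------|-------------|---------|
| `inputs` | `Union[InputLike, Missing]` | Optional example inputs. If given, the module must be callable; it is traced with `jax.eval_shape`, and "inputs" and "outputs" columns are added. | `to.MISSING` |
| `depth` | `int` | The maximum depth of the representation in terms of nested Modules, -1 means no limit. | `-1` |
| `signature` | `bool` | Whether to show the signature of the Module. | `False` |
| `param_types` | `bool` | Whether to show the types of the parameters. | `True` |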

Returns:

| Type | Description |
|------|-------------|
| `str` | A string containing the tabular representation. |

Source code in treex/module.py
def tabulate(
    self,
    inputs: tp.Union[types.InputLike, to.Missing] = to.MISSING,
    depth: int = -1,
    signature: bool = False,
    param_types: bool = True,
) -> str:
    """
    Returns a tabular representation of the module.

    Arguments:
        inputs: Optional example inputs. When given, the module (which must be
            callable) is traced with `jax.eval_shape` and "inputs"/"outputs"
            columns show what each module received and returned.
        depth: The maximum depth of the representation in terms of nested Modules, -1 means no limit.
        signature: Whether to show the signature of the Module.
        param_types: Whether to show the types of the parameters.
    Returns:
        A string containing the tabular representation.
    Raises:
        TypeError: If `inputs` is given but the module is not callable.
    """
    # NOTE(review): presumably copied so tracing cannot mutate the caller's module
    self = to.copy(self)

    def simplify_inputs(call_inputs):
        # most compact display form: a lone positional arg, only the args,
        # only the kwargs, or (args, kwargs)
        if len(call_inputs.kwargs) == 0:
            if len(call_inputs.args) == 1:
                return call_inputs.args[0]
            return call_inputs.args
        if len(call_inputs.args) == 0:
            return call_inputs.kwargs
        return (call_inputs.args, call_inputs.kwargs)

    if inputs is not to.MISSING:
        inputs = types.Inputs.from_value(inputs)

        if not isinstance(self, tp.Callable):
            raise TypeError(
                "`inputs` can only be specified if the module is a callable."
            )

        with contexts._Context(call_info={}):

            # call using self to preserve references: the recorded keys are
            # the same submodule objects that appear in the rows below
            def eval_call(args, kwargs):
                assert isinstance(self, tp.Callable)
                return self(*args, **kwargs)

            jax.eval_shape(eval_call, inputs.args, inputs.kwargs)
            call_info = contexts._CONTEXT.call_info
    else:
        call_info = None

    with to.add_field_info():
        flat: tp.List[to.FieldInfo]
        flat, _ = jax.tree_util.tree_flatten(self)
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # column order is deterministic (a set of types is not)
        tree_part_types: tp.Tuple[tp.Type[types.TreePart], ...] = tuple(
            dict.fromkeys(
                field_info.kind
                for field_info in flat
                if utils._generic_issubclass(field_info.kind, types.TreePart)
            )
        )

    raw_rows = list(
        utils._get_tabulate_rows(
            (), self, depth, tree_part_types, signature, param_types
        )
    )
    # each raw row is (module, *cells): keep the module for call_info lookups
    modules = [raw_row[0] for raw_row in raw_rows]
    rows = [list(raw_row[1:]) for raw_row in raw_rows]

    if call_info is not None:
        for module, row in zip(modules, rows):
            if module in call_info:
                call_inputs, call_outputs = call_info[module]
                inputs_repr = utils._format_param_tree(simplify_inputs(call_inputs))
                outputs_repr = utils._format_param_tree(call_outputs)
            else:
                inputs_repr = outputs_repr = ""

            # insert the two cells right after path, module and params
            row[3:3] = [inputs_repr, outputs_repr]

    # index of the last non-TreePart column ("params", or "outputs" when call
    # info is shown); the "Total:" label goes there
    label_col = 2 if call_info is None else 4

    rows[0][0] = "*"
    rows.append(
        [""] * label_col
        + ["Total:"]
        + [
            utils._format_obj_size(self.filter(kind), add_padding=True)
            for kind in tree_part_types
        ]
    )
    utils._add_padding(rows)

    table = Table(
        show_header=True,
        show_lines=True,
        show_footer=True,
    )

    table.add_column("path")
    table.add_column("module")
    table.add_column("params")

    if call_info is not None:
        table.add_column("inputs")
        table.add_column("outputs")

    for tree_part_type in tree_part_types:
        type_name = tree_part_type.__name__
        if type_name.startswith("_"):
            type_name = type_name[1:]

        table.add_column(type_name)

    for row in rows[:-1]:
        table.add_row(*row)

    totals = rows[-1]
    table.columns[label_col].footer = Text.from_markup(
        totals[label_col], justify="right"
    )

    for col in range(label_col + 1, label_col + 1 + len(tree_part_types)):
        table.columns[col].footer = totals[col]

    table.caption_style = "bold"
    table.caption = "\nTotal Parameters: " + utils._format_obj_size(
        self, add_padding=False
    )

    return utils._get_rich_repr(table)