treex.Module
Source code in treex/module.py
class Module(Treex, Filters, metaclass=ModuleMeta):
    """
    Module type built on top of `Treex` and `Filters`.

    It keeps three static flags (`_training`, `_initialized`, `_frozen`) that
    are exposed read-only through the `training`, `initialized` and `frozen`
    properties, offers `init` to initialize its fields from a seed, and
    `tabulate` to render a summary table of the module.
    """

    # use to.field to copy class vars to instance
    _training: bool = to.static(True)
    _initialized: bool = to.static(False)
    _frozen: bool = to.static(False)

    def __init__(self, name: tp.Optional[str] = None):
        """
        Arguments:
            name: Name of the module. When `None`, the name is derived from the
                class name through `utils._lower_snake_case`.
        """
        self.name = (
            name
            if name is not None
            else utils._lower_snake_case(self.__class__.__name__)
        )

    def initializing(self) -> bool:
        """
        Tells whether the module is currently being initialized.

        Returns:
            `True` if the module is not initialized yet and an `init` context
            is active, `False` if the module is already initialized.

        Raises:
            RuntimeError: If the module is not initialized and
                `_INIT_CONTEXT.initializing` is not set.
        """
        if not self.initialized:
            if not _INIT_CONTEXT.initializing:
                raise RuntimeError(
                    f"Trying run {self.__class__.__name__} for the first time outside of `init`"
                )

            return True

        return False

    @property
    def initialized(self) -> bool:
        """Value of the `_initialized` flag."""
        return self._initialized

    @property
    def training(self) -> bool:
        """Value of the `_training` flag."""
        return self._training

    @property
    def frozen(self) -> bool:
        """Value of the `_frozen` flag."""
        return self._frozen

    def __init_subclass__(cls):
        """
        Wraps `__call__` of callable subclasses so that calls can be recorded.

        While `contexts._CONTEXT.call_info` is not `None`, the wrapper stores
        the `(inputs, outputs)` pair of a call under the module instance,
        unless that instance already has an entry.
        """
        if issubclass(cls, tp.Callable):
            orig_call = cls.__call__

            @functools.wraps(cls.__call__)
            def new_call(self: Module, *args, **kwargs):
                outputs = orig_call(self, *args, **kwargs)

                # only record when a recording context is active and this
                # instance has not been recorded before
                if (
                    contexts._CONTEXT.call_info is not None
                    and self not in contexts._CONTEXT.call_info
                ):
                    inputs = types.Inputs(*args, **kwargs)
                    contexts._CONTEXT.call_info[self] = (inputs, outputs)

                return outputs

            cls.__call__ = new_call

        return super().__init_subclass__()

    def init(
        self: M,
        key: tp.Union[int, jnp.ndarray],
        inputs: types.InputLike = to.MISSING,
        call_method: str = "__call__",
        *,
        inplace: bool = False,
        _set_initialize: bool = True,
    ) -> M:
        """
        Method version of `tx.init`, it applies `self` as first argument.

        `init` creates a new module with the same structure, but with its fields initialized given a seed `key`. The following
        procedure is used:

        1. The input `key` is split and iteratively updated before passing a derived value to any
            process that requires initialization.
        2. `Initializer`s are called and applied to the module first.
        3. `Module.rng_init` methods are called last.

        Arguments:
            key: The seed to use for initialization.
            inputs: Optional inputs; when given, the method named by
                `call_method` is called with them while the init context is
                still active.
            call_method: Name of the method to call with `inputs`.
            inplace: If `True` the module itself is modified instead of a copy.
            _set_initialize: If `True`, every not-yet-initialized `Module` gets
                its `_initialized` flag set to `True` at the end.

        Returns:
            The new module with the fields initialized.
        """
        module = self.copy() if not inplace else self
        key = utils.Key(key)

        with _INIT_CONTEXT.update(key=key, initializing=True):

            # replace every `Initializer` leaf by the value it produces
            module: M = module.map(
                lambda initializer: (
                    initializer(next_key())
                    if isinstance(initializer, types.Initializer)
                    else initializer
                ),
                is_leaf=lambda x: isinstance(x, types.Initializer),
                inplace=True,
            )

            def call_rng_init(module: Module):
                if isinstance(module, Module) and not module._initialized:
                    module.rng_init()

            module = to.apply(call_rng_init, module, inplace=True)

            if inputs is not to.MISSING:
                inputs = types.Inputs.from_value(inputs)
                method = getattr(module, call_method)
                method(*inputs.args, **inputs.kwargs)

        if _set_initialize:

            def set_initialized(module: Module):
                if isinstance(module, Module) and not module._initialized:
                    module._initialized = True

            module = to.apply(set_initialized, module, inplace=True)

        return module

    def rng_init(self) -> None:
        """
        Hook invoked by `init` on modules that are not initialized yet.

        The default implementation does nothing.
        """
        pass

    def tabulate(
        self,
        inputs: tp.Union[types.InputLike, to.Missing] = to.MISSING,
        depth: int = -1,
        signature: bool = False,
        param_types: bool = True,
    ) -> str:
        """
        Returns a tabular representation of the module.

        Arguments:
            inputs: Optional inputs. When given, the module is called under
                `jax.eval_shape` and the recorded inputs / outputs of each
                module are shown in two extra columns.
            depth: The maximum depth of the representation in terms of nested Modules, -1 means no limit.
            signature: Whether to show the signature of the Module.
            param_types: Whether to show the types of the parameters.

        Returns:
            A string containing the tabular representation.

        Raises:
            TypeError: If `inputs` is given but the module is not callable.
        """
        self = to.copy(self)

        if inputs is not to.MISSING:
            inputs = types.Inputs.from_value(inputs)

            if not isinstance(self, tp.Callable):
                raise TypeError(
                    "`inputs` can only be specified if the module is a callable."
                )

            with contexts._Context(call_info={}):

                # call using self to preserve references
                def eval_call(args, kwargs):
                    assert isinstance(self, tp.Callable)
                    return self(*args, **kwargs)

                jax.eval_shape(
                    eval_call,
                    inputs.args,
                    inputs.kwargs,
                )
                # grab the recorded calls before the context is left
                call_info = contexts._CONTEXT.call_info

        else:
            call_info = None

        with to.add_field_info():
            flat: tp.List[to.FieldInfo]
            flat, _ = jax.tree_flatten(self)
            tree_part_types: tp.Tuple[tp.Type[types.TreePart], ...] = tuple(
                {
                    field_info.kind
                    for field_info in flat
                    if utils._generic_issubclass(field_info.kind, types.TreePart)
                }
            )

        path = ()
        rows = list(
            utils._get_tabulate_rows(
                path, self, depth, tree_part_types, signature, param_types
            )
        )

        # the first entry of every row is the module object itself
        modules = [row[0] for row in rows]
        rows = [row[1:] for row in rows]

        if call_info is not None:
            for module, row in zip(modules, rows):
                if module in call_info:
                    call_inputs, call_outputs = call_info[module]

                    # Show the most compact form of the call arguments.
                    # Positional-only calls show the args tuple; the empty
                    # kwargs dict must not be shown in their place.
                    if len(call_inputs.kwargs) == 0 and len(call_inputs.args) == 1:
                        simplified_inputs = call_inputs.args[0]
                    elif len(call_inputs.args) == 0:
                        simplified_inputs = call_inputs.kwargs
                    elif len(call_inputs.kwargs) == 0:
                        simplified_inputs = call_inputs.args
                    else:
                        simplified_inputs = (call_inputs.args, call_inputs.kwargs)

                    inputs_repr = utils._format_param_tree(simplified_inputs)
                    outputs_repr = utils._format_param_tree(call_outputs)
                else:
                    inputs_repr = ""
                    outputs_repr = ""

                # inserting both at index 3 leaves `inputs` before `outputs`
                row.insert(3, outputs_repr)
                row.insert(3, inputs_repr)

        # index of the last column before the tree part columns; it holds the
        # "Total:" label of the footer
        n_non_treepart_cols = 2 if call_info is None else 4

        rows[0][0] = "*"
        rows.append(
            [""] * n_non_treepart_cols
            + ["Total:"]
            + [
                utils._format_obj_size(self.filter(kind), add_padding=True)
                for kind in tree_part_types
            ]
        )
        utils._add_padding(rows)

        table = Table(
            show_header=True,
            show_lines=True,
            show_footer=True,
            # box=rich.box.HORIZONTALS,
        )

        table.add_column("path")
        table.add_column("module")
        table.add_column("params")

        if call_info is not None:
            table.add_column("inputs")
            table.add_column("outputs")

        for tree_part_type in tree_part_types:
            type_name = tree_part_type.__name__
            if type_name.startswith("_"):
                type_name = type_name[1:]

            table.add_column(type_name)

        # the last row holds the totals and is rendered as the footer
        for row in rows[:-1]:
            table.add_row(*row)

        table.columns[n_non_treepart_cols].footer = Text.from_markup(
            rows[-1][n_non_treepart_cols], justify="right"
        )

        for i in range(len(tree_part_types)):
            table.columns[n_non_treepart_cols + 1 + i].footer = rows[-1][
                n_non_treepart_cols + 1 + i
            ]

        table.caption_style = "bold"
        table.caption = "\nTotal Parameters: " + utils._format_obj_size(
            self, add_padding=False
        )

        return utils._get_rich_repr(table)
__init_subclass__()
classmethod
special
This method is called when a class is subclassed.
The default implementation does nothing. It may be overridden to extend subclasses.
Source code in treex/module.py
def __init_subclass__(cls):
    """
    Wrap `__call__` of callable subclasses so that calls can be recorded.

    While `contexts._CONTEXT.call_info` is not `None`, the wrapper stores the
    `(inputs, outputs)` pair of a call under the module instance, unless that
    instance already has an entry.
    """
    if not issubclass(cls, tp.Callable):
        return super().__init_subclass__()

    wrapped_call = cls.__call__

    @functools.wraps(wrapped_call)
    def new_call(self: Module, *args, **kwargs):
        result = wrapped_call(self, *args, **kwargs)

        # nothing to record outside of a recording context, or when this
        # instance already has an entry
        recorded_calls = contexts._CONTEXT.call_info
        if recorded_calls is None or self in recorded_calls:
            return result

        recorded_calls[self] = (types.Inputs(*args, **kwargs), result)
        return result

    cls.__call__ = new_call

    return super().__init_subclass__()
init(self, key, inputs=to.MISSING, call_method='__call__', *, inplace=False, _set_initialize=True)
Method version of `tx.init`; it applies `self` as the first argument.
`init` creates a new module with the same structure, but with its fields initialized given a seed `key`. The following procedure is used:
1. The input `key` is split and iteratively updated before passing a derived value to any process that requires initialization.
2. `Initializer`s are called and applied to the module first.
3. `Module.rng_init` methods are called last.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
Union[int, jax._src.numpy.lax_numpy.ndarray] |
The seed to use for initialization. |
required |
Returns:
Type | Description |
---|---|
~M |
The new module with the fields initialized. |
Source code in treex/module.py
def init(
    self: M,
    key: tp.Union[int, jnp.ndarray],
    inputs: types.InputLike = to.MISSING,
    call_method: str = "__call__",
    *,
    inplace: bool = False,
    _set_initialize: bool = True,
) -> M:
    """
    Method version of `tx.init`, it applies `self` as first argument.

    Builds a module with the same structure as `self` whose fields are
    initialized from the seed `key`:

    1. `key` is wrapped with `utils.Key` and installed in the init context, so
        that every process that requires initialization receives a derived
        value.
    2. Every `Initializer` leaf is called and replaced by its result.
    3. `rng_init` is called on every `Module` that is not initialized yet.
    4. If `inputs` is given, the method named `call_method` is called with
        them while the init context is still active.

    Arguments:
        key: The seed to use for initialization.
        inputs: Optional inputs passed to `call_method`.
        call_method: Name of the method to call with `inputs`.
        inplace: If `True` the module itself is modified instead of a copy.
        _set_initialize: If `True`, every not-yet-initialized `Module` gets
            its `_initialized` flag set to `True` at the end.

    Returns:
        The new module with the fields initialized.
    """
    target = self if inplace else self.copy()
    key = utils.Key(key)

    def is_initializer(value) -> bool:
        return isinstance(value, types.Initializer)

    def run_initializer(leaf):
        # non-initializer leaves pass through untouched
        if is_initializer(leaf):
            return leaf(next_key())
        return leaf

    def run_rng_init(node: Module):
        if isinstance(node, Module) and not node._initialized:
            node.rng_init()

    def mark_initialized(node: Module):
        if isinstance(node, Module) and not node._initialized:
            node._initialized = True

    with _INIT_CONTEXT.update(key=key, initializing=True):
        target = target.map(
            run_initializer,
            is_leaf=is_initializer,
            inplace=True,
        )
        target = to.apply(run_rng_init, target, inplace=True)

        if inputs is not to.MISSING:
            call_inputs = types.Inputs.from_value(inputs)
            bound_method = getattr(target, call_method)
            bound_method(*call_inputs.args, **call_inputs.kwargs)

    if _set_initialize:
        target = to.apply(mark_initialized, target, inplace=True)

    return target
tabulate(self, inputs=to.MISSING, depth=-1, signature=False, param_types=True)
Returns a tabular representation of the module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
depth |
int |
The maximum depth of the representation in terms of nested Modules, -1 means no limit. |
-1 |
signature |
bool |
Whether to show the signature of the Module. |
False |
param_types |
bool |
Whether to show the types of the parameters. |
True |
Returns:
Type | Description |
---|---|
str |
A string containing the tabular representation. |
Source code in treex/module.py
def tabulate(
    self,
    inputs: tp.Union[types.InputLike, to.Missing] = to.MISSING,
    depth: int = -1,
    signature: bool = False,
    param_types: bool = True,
) -> str:
    """
    Returns a tabular representation of the module.

    Arguments:
        inputs: Optional inputs. When given, the module is called under
            `jax.eval_shape` and the recorded inputs / outputs of each module
            are shown in two extra columns.
        depth: The maximum depth of the representation in terms of nested Modules, -1 means no limit.
        signature: Whether to show the signature of the Module.
        param_types: Whether to show the types of the parameters.

    Returns:
        A string containing the tabular representation.

    Raises:
        TypeError: If `inputs` is given but the module is not callable.
    """
    self = to.copy(self)

    if inputs is not to.MISSING:
        inputs = types.Inputs.from_value(inputs)

        if not isinstance(self, tp.Callable):
            raise TypeError(
                "`inputs` can only be specified if the module is a callable."
            )

        with contexts._Context(call_info={}):

            # call using self to preserve references
            def eval_call(args, kwargs):
                assert isinstance(self, tp.Callable)
                return self(*args, **kwargs)

            jax.eval_shape(
                eval_call,
                inputs.args,
                inputs.kwargs,
            )
            # grab the recorded calls before the context is left
            call_info = contexts._CONTEXT.call_info

    else:
        call_info = None

    with to.add_field_info():
        flat: tp.List[to.FieldInfo]
        flat, _ = jax.tree_flatten(self)
        tree_part_types: tp.Tuple[tp.Type[types.TreePart], ...] = tuple(
            {
                field_info.kind
                for field_info in flat
                if utils._generic_issubclass(field_info.kind, types.TreePart)
            }
        )

    path = ()
    rows = list(
        utils._get_tabulate_rows(
            path, self, depth, tree_part_types, signature, param_types
        )
    )

    # the first entry of every row is the module object itself
    modules = [row[0] for row in rows]
    rows = [row[1:] for row in rows]

    if call_info is not None:
        for module, row in zip(modules, rows):
            if module in call_info:
                call_inputs, call_outputs = call_info[module]

                # Show the most compact form of the call arguments.
                # Positional-only calls show the args tuple; the empty kwargs
                # dict must not be shown in their place.
                if len(call_inputs.kwargs) == 0 and len(call_inputs.args) == 1:
                    simplified_inputs = call_inputs.args[0]
                elif len(call_inputs.args) == 0:
                    simplified_inputs = call_inputs.kwargs
                elif len(call_inputs.kwargs) == 0:
                    simplified_inputs = call_inputs.args
                else:
                    simplified_inputs = (call_inputs.args, call_inputs.kwargs)

                inputs_repr = utils._format_param_tree(simplified_inputs)
                outputs_repr = utils._format_param_tree(call_outputs)
            else:
                inputs_repr = ""
                outputs_repr = ""

            # inserting both at index 3 leaves `inputs` before `outputs`
            row.insert(3, outputs_repr)
            row.insert(3, inputs_repr)

    # index of the last column before the tree part columns; it holds the
    # "Total:" label of the footer
    n_non_treepart_cols = 2 if call_info is None else 4

    rows[0][0] = "*"
    rows.append(
        [""] * n_non_treepart_cols
        + ["Total:"]
        + [
            utils._format_obj_size(self.filter(kind), add_padding=True)
            for kind in tree_part_types
        ]
    )
    utils._add_padding(rows)

    table = Table(
        show_header=True,
        show_lines=True,
        show_footer=True,
        # box=rich.box.HORIZONTALS,
    )

    table.add_column("path")
    table.add_column("module")
    table.add_column("params")

    if call_info is not None:
        table.add_column("inputs")
        table.add_column("outputs")

    for tree_part_type in tree_part_types:
        type_name = tree_part_type.__name__
        if type_name.startswith("_"):
            type_name = type_name[1:]

        table.add_column(type_name)

    # the last row holds the totals and is rendered as the footer
    for row in rows[:-1]:
        table.add_row(*row)

    table.columns[n_non_treepart_cols].footer = Text.from_markup(
        rows[-1][n_non_treepart_cols], justify="right"
    )

    for i in range(len(tree_part_types)):
        table.columns[n_non_treepart_cols + 1 + i].footer = rows[-1][
            n_non_treepart_cols + 1 + i
        ]

    table.caption_style = "bold"
    table.caption = "\nTotal Parameters: " + utils._format_obj_size(
        self, add_padding=False
    )

    return utils._get_rich_repr(table)