treex.next_key
Returns the next key.
Returns:
| Type | Description |
|---|---|
| `ndarray` | The next key. |
Source code in treex/module.py
def next_key(*, axis_name: tp.Optional[tp.Any] = None) -> jnp.ndarray:
    """
    Returns the next key.

    The key currently stored in the module-level init context is passed to
    `utils.iter_split`. The first element of the result is handed back to the
    caller. The second element replaces the stored context key, so
    successive calls yield different keys.

    Arguments:
        axis_name: Optional name of a mapped axis. When given, the index along
            that axis (`jax.lax.axis_index(axis_name)`) is folded into the
            returned key with `jax.random.fold_in`. Each index along the axis
            therefore receives a distinct key. When `None`, the split key is
            returned unchanged.

    Returns:
        The next key.

    Raises:
        RuntimeError: If no RNG key has been set on the init context.
    """
    current_key = _INIT_CONTEXT.key

    # Guard: nothing to split if no key was ever installed in the context.
    if current_key is None:
        raise RuntimeError(
            "RNG key not set, you are either calling an uninitialized Module outside `.init` or forgot to call `rng_key` context manager."
        )

    # Hand out the first half and keep the second half in the context
    # for the next call.
    fresh_key, remaining_key = utils.iter_split(current_key)
    _INIT_CONTEXT.key = remaining_key

    if axis_name is None:
        return fresh_key

    return jax.random.fold_in(fresh_key, jax.lax.axis_index(axis_name))