Skip to content

treex.next_key

Returns the next key.

Returns:

| Type | Description |
|------|-------------|
| `ndarray` | The next key. |

Source code in treex/module.py
def next_key(*, axis_name: tp.Optional[tp.Any] = None) -> jnp.ndarray:
    """
    Returns the next key.

    The RNG key currently stored in the module-level init context is passed to
    `utils.iter_split`. The first value it yields is the key handed back to the
    caller, and the second replaces the stored key for subsequent calls.

    Arguments:
        axis_name: Optional name of a mapped axis. When given, the current index
            along that axis (`jax.lax.axis_index`) is folded into the key
            (`jax.random.fold_in`) to derive a per-index key. Only valid inside a
            transformation that binds `axis_name`.

    Returns:
        The next key.

    Raises:
        RuntimeError: If no RNG key is currently set in the init context.
    """
    # Guard clause: without a stored key there is nothing to split.
    if _INIT_CONTEXT.key is None:
        raise RuntimeError(
            "RNG key not set, you are either calling an uninitialized Module outside `.init` or forgot to call `rng_key` context manager."
        )

    # Hand one key to the caller and advance the stored key so the next call
    # does not reuse the same value.
    fresh_key: jnp.ndarray
    fresh_key, _INIT_CONTEXT.key = utils.iter_split(_INIT_CONTEXT.key)

    if axis_name is None:
        return fresh_key

    # Differentiate keys across positions of the named mapped axis.
    return jax.random.fold_in(fresh_key, jax.lax.axis_index(axis_name))