Skip to content

treex.KeySeq

KeySeq is a simple module that can produce a sequence of PRNGKeys.

Examples:

class Dropout(Module):
    # Example module showing how a KeySeq attribute supplies PRNG keys.
    # NOTE(review): the annotation below names `rng`, but __init__ assigns the
    # KeySeq to `self.next_key` — confirm which attribute name is intended.
    rng: KeySeq()

    def __init__(self, rate: float):
        # Keep a KeySeq on the module; __call__ draws its keys from it.
        self.next_key = KeySeq()
        ...

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Calling the KeySeq yields the key used for this invocation.
        key = self.next_key()
        # `self.rate` is not assigned in the visible code — presumably it is
        # set in the elided part of __init__.
        mask = jax.random.bernoulli(key, 1.0 - self.rate)
        ...

__call__(self, *, axis_name=None) special

Returns a new PRNGKey and updates the internal RNG state.

Returns:

Type Description
ndarray

A PRNGKey.

Source code in treex/key_seq.py
def __call__(self, *, axis_name: tp.Optional[tp.Any] = None) -> jnp.ndarray:
    """
    Produce the next PRNGKey and advance the stored key state.

    The stored key is split with `utils.iter_split`: the first result is
    handed back to the caller and the second replaces `self.key`.

    Arguments:
        axis_name: Optional axis name; when `None`, `self.axis_name` is
            used instead. If the resulting name is not `None`, the index
            along that axis (`jax.lax.axis_index`) is folded into the
            returned key.

    Returns:
        A PRNGKey.
    """
    assert isinstance(self.key, jnp.ndarray)

    # Advance the internal state before anything else touches the new key.
    fresh_key, next_state = utils.iter_split(self.key)
    self.key = next_state

    # An explicit argument wins; otherwise use the instance-level axis name.
    effective_axis = self.axis_name if axis_name is None else axis_name
    if effective_axis is None:
        return fresh_key

    position = jax.lax.axis_index(effective_axis)
    return jax.random.fold_in(fresh_key, position)

__init__(self, key=None, *, axis_name=None) special

Parameters:

Name Type Description Default
key Union[jax._src.numpy.lax_numpy.ndarray, int]

An optional PRNGKey to initialize the KeySeq with.

None
Source code in treex/key_seq.py
def __init__(
    self,
    key: tp.Optional[tp.Union[jnp.ndarray, int]] = None,
    *,
    axis_name: tp.Optional[tp.Any] = None
):
    """
    Set up the key state and the default axis name.

    Arguments:
        key: An optional PRNGKey to initialize the KeySeq with. An `int` is
            passed through `utils.Key`; a jax or numpy array is stored as
            given; any other value (including `None`) results in a
            `types.Initializer` being stored instead.
        axis_name: Stored as `self.axis_name`.
    """
    if isinstance(key, int):
        # Integer input is converted by `utils.Key`.
        initial_state = utils.Key(key)
    elif isinstance(key, (jnp.ndarray, np.ndarray)):
        # Array input is stored unchanged.
        initial_state = key
    else:
        # Nothing usable was supplied: store an Initializer wrapping an
        # identity function of a key (presumably resolved later — confirm
        # against `types.Initializer`).
        initial_state = types.Initializer(lambda key: key)

    self.key = initial_state
    self.axis_name = axis_name