Skip to content

treex.KeySeq

KeySeq is a simple module that produces a sequence of PRNGKeys.

Examples:

class Dropout(Module):
    # Declare the KeySeq field under the same name that `__init__` assigns
    # and `__call__` reads; the annotation is the type, not an instance.
    next_key: KeySeq

    def __init__(self, rate: float):
        # `__call__` reads `self.rate`, so it has to be stored here.
        self.rate = rate
        self.next_key = KeySeq()
        ...

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Each call to the KeySeq returns a new PRNGKey and updates its
        # internal state, so no key has to be passed in by the caller.
        key = self.next_key()
        mask = jax.random.bernoulli(key, 1.0 - self.rate)
        ...
Source code in treex/key_seq.py
class KeySeq(Module):
    """KeySeq is a simple module that produces a sequence of PRNGKeys.

    Each call obtains a key from the stored one through `utils.iter_split`,
    returns it, and stores the second value returned by `utils.iter_split`
    as the new internal state.

    Example:
    ```python
    class Dropout(Module):
        next_key: KeySeq

        def __init__(self, rate: float):
            self.rate = rate
            self.next_key = KeySeq()
            ...

        def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
            key = self.next_key()
            mask = jax.random.bernoulli(key, 1.0 - self.rate)
            ...
    ```
    """

    # Current state of the sequence: either a PRNGKey array, or a
    # `types.Initializer` placeholder when no key was given at construction.
    key: tp.Union[types.Initializer, jnp.ndarray] = types.Rng.node()

    def __init__(
        self,
        key: tp.Optional[tp.Union[jnp.ndarray, int]] = None,
        *,
        axis_name: tp.Optional[tp.Any] = None
    ):
        """
        Arguments:
            key: An optional PRNGKey to initialize the KeySeq with. An `int`
                is converted with `utils.Key`; a `jnp.ndarray` / `np.ndarray`
                is stored as-is. For any other value (including `None`) a
                `types.Initializer` placeholder is stored instead.
            axis_name: Optional default axis name used by `__call__` when the
                call itself does not pass an `axis_name`.
        """
        if isinstance(key, int):
            # Integer seed: build the key from it.
            self.key = utils.Key(key)
        elif isinstance(key, (jnp.ndarray, np.ndarray)):
            # Already an array: keep it unchanged.
            self.key = key
        else:
            # No usable key: store a placeholder whose function returns the
            # key it is called with.
            self.key = types.Initializer(lambda key: key)

        self.axis_name = axis_name

    def __call__(self, *, axis_name: tp.Optional[tp.Any] = None) -> jnp.ndarray:
        """
        Return a new PRNGKey and update the internal rng state.

        Arguments:
            axis_name: Optional axis name. When it is `None` the `axis_name`
                given to the constructor is used. If the resulting name is not
                `None`, `jax.lax.axis_index(axis_name)` is folded into the
                returned key.

        Returns:
            A PRNGKey.

        Raises:
            AssertionError: If `self.key` is not an array, e.g. it is still
                the `types.Initializer` placeholder set by `__init__`.
        """
        # Explicit raise instead of a bare `assert`: the check also runs under
        # `python -O` and reports what was found. The exception type is kept.
        if not isinstance(self.key, jnp.ndarray):
            raise AssertionError(
                "KeySeq.key must be a PRNGKey array before the KeySeq is "
                f"called, got {type(self.key).__name__}."
            )

        key: jnp.ndarray
        key, self.key = utils.iter_split(self.key)

        # A per-call `axis_name` takes precedence over the instance default.
        if axis_name is None:
            axis_name = self.axis_name

        if axis_name is not None:
            # Presumably so that each index along the named axis receives a
            # different key.
            axis_index = jax.lax.axis_index(axis_name)
            key = jax.random.fold_in(key, axis_index)

        return key

__call__(self, *, axis_name=None) special

Returns a new PRNGKey and updates the internal RNG state.

Returns:

Type Description
ndarray

A PRNGKey.

Source code in treex/key_seq.py
def __call__(self, *, axis_name: tp.Optional[tp.Any] = None) -> jnp.ndarray:
    """
    Return a new PRNGKey and update the internal rng state.

    Arguments:
        axis_name: Optional axis name. When it is `None`, `self.axis_name`
            is used. If the resulting name is not `None`,
            `jax.lax.axis_index(axis_name)` is folded into the returned key.

    Returns:
        A PRNGKey.

    Raises:
        AssertionError: If `self.key` is not a `jnp.ndarray`.
    """
    # Explicit raise instead of a bare `assert`: the check also runs under
    # `python -O` and reports what was found. The exception type is kept.
    if not isinstance(self.key, jnp.ndarray):
        raise AssertionError(
            "KeySeq.key must be a PRNGKey array before the KeySeq is "
            f"called, got {type(self.key).__name__}."
        )

    key: jnp.ndarray
    key, self.key = utils.iter_split(self.key)

    # A per-call `axis_name` takes precedence over the instance default.
    if axis_name is None:
        axis_name = self.axis_name

    if axis_name is not None:
        # Presumably so that each index along the named axis receives a
        # different key.
        axis_index = jax.lax.axis_index(axis_name)
        key = jax.random.fold_in(key, axis_index)

    return key

__init__(self, key=None, *, axis_name=None) special

Parameters:

Name Type Description Default
key Union[jax._src.numpy.lax_numpy.ndarray, int]

An optional PRNGKey to initialize the KeySeq with.

None
Source code in treex/key_seq.py
def __init__(
    self,
    key: tp.Optional[tp.Union[jnp.ndarray, int]] = None,
    *,
    axis_name: tp.Optional[tp.Any] = None
):
    """
    Arguments:
        key: An optional PRNGKey to initialize the KeySeq with. An `int` is
            converted with `utils.Key`; a `jnp.ndarray` / `np.ndarray` is
            stored as-is. For any other value (including `None`) a
            `types.Initializer` placeholder is stored instead.
        axis_name: Stored unchanged on the instance as `self.axis_name`.
    """
    if isinstance(key, int):
        # Integer seed: build the key from it.
        initial_state = utils.Key(key)
    elif isinstance(key, (jnp.ndarray, np.ndarray)):
        # Already an array: keep it unchanged.
        initial_state = key
    else:
        # No usable key: store a placeholder whose function returns the key
        # it is called with.
        initial_state = types.Initializer(lambda key: key)

    self.key = initial_state
    self.axis_name = axis_name