# treex.Dropout

Create a dropout layer.

Dropout is implemented as a wrapper over `flax.linen.Dropout`; it accepts almost the same constructor arguments, including any Flax artifacts such as initializers. Main differences (see the sketch after this list):

- `deterministic` is not a constructor argument, but remains a `__call__` argument.
- The `self.training` state indicates how `Dropout` should behave: internally `deterministic = not self.training or self.frozen` is used unless `deterministic` is explicitly passed via `__call__`.
- `Dropout` maintains a `next_key: KeySeq` state which is used to generate random masks unless `rng` is passed via `__call__`.
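A minimal usage sketch, assuming the treex `Module` API (`init`, `eval`) from the treex README; the seed and shapes are illustrative, not from this page:

```python
import jax.numpy as jnp
import treex as tx

x = jnp.ones((2, 4))

# hypothetical setup: .init(42) seeds the layer's internal KeySeq state
dropout = tx.Dropout(rate=0.5).init(42)

y_train = dropout(x)        # training mode: random mask, kept values scaled by 1 / (1 - 0.5)
y_eval = dropout.eval()(x)  # self.training is False, so inputs are returned as is

# an explicit `deterministic` overrides self.training
y_forced = dropout(x, deterministic=True)
```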

## __call__(self, x, deterministic=None, rng=None)

Applies a random dropout mask to the input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `ndarray` | the inputs that should be randomly masked. | *required* |
| `deterministic` | `Optional[bool]` | if false, the inputs are scaled by `1 / (1 - rate)` and masked; if true, no mask is applied and the inputs are returned as is. | `None` |
| `rng` | | an optional `jax.random.PRNGKey`. By default `self.next_key` will be used. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | The masked inputs reweighted to preserve mean: surviving entries are scaled by `1 / (1 - rate)`, so the expected value of the output matches the input. |
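A quick numeric check of the mean-preservation claim (a sketch; the layer, seed, and shape are illustrative):

```python
import jax.numpy as jnp
import treex as tx

dropout = tx.Dropout(rate=0.5).init(0)
x = jnp.ones((1_000_000,))

# about half the entries are zeroed and the rest are scaled by 2,
# so the mean stays close to 1.0
y = dropout(x)
print(y.mean())  # ~1.0
```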

Source code in treex/nn/dropout.py
```python
def __call__(
    self, x: jnp.ndarray, deterministic: tp.Optional[bool] = None, rng=None
) -> jnp.ndarray:
    """Applies a random dropout mask to the input.

    Arguments:
        x: the inputs that should be randomly masked.
        deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
            masked, whereas if true, no mask is applied and the inputs are returned
            as is.
        rng: an optional `jax.random.PRNGKey`. By default `self.next_key` will
            be used.

    Returns:
        The masked inputs reweighted to preserve mean.
    """
    variables = dict()

    # an explicit `deterministic` argument overrides the training/frozen state
    training = (
        not deterministic
        if deterministic is not None
        else self.training and not self.frozen
    )

    # draw a fresh key when training; otherwise reuse the current key
    if rng is None:
        rng = self.next_key() if training else self.next_key.key

    # delegate to the wrapped flax.linen.Dropout module
    output = self.module.apply(
        variables,
        x,
        deterministic=not training,
        rng=rng,
    )

    return tp.cast(jnp.ndarray, output)
```
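A sketch of the `__call__` overrides, assuming the layer setup from the examples above (the key value is illustrative):

```python
import jax
import jax.numpy as jnp
import treex as tx

dropout = tx.Dropout(rate=0.5).init(0)
x = jnp.ones((2, 4))

# passing the same explicit key produces the same mask on both calls
key = jax.random.PRNGKey(7)
y1 = dropout(x, rng=key)
y2 = dropout(x, rng=key)
assert (y1 == y2).all()

# deterministic=True disables masking even while self.training is True
assert (dropout(x, deterministic=True) == x).all()
```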

## __init__(self, rate, broadcast_dims=())

Create a dropout layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rate` | `float` | the dropout probability (*not* the keep rate!). | *required* |
| `broadcast_dims` | `Iterable[int]` | dimensions that will share the same dropout mask. | `()` |

Source code in treex/nn/dropout.py
```python
def __init__(
    self,
    rate: float,
    broadcast_dims: tp.Iterable[int] = (),
):
    """
    Create a dropout layer.

    Arguments:
        rate: the dropout probability.  (_not_ the keep rate!)
        broadcast_dims: dimensions that will share the same dropout mask
    """

    self.rate = rate
    self.broadcast_dims = broadcast_dims
    # per-layer RNG state used to draw dropout masks
    self.next_key = KeySeq()
```
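A sketch of `broadcast_dims` (shapes and seed are illustrative): with `broadcast_dims=(0,)` the mask has size 1 along dimension 0 and is broadcast, so every batch row is masked at the same feature positions.

```python
import jax.numpy as jnp
import treex as tx

x = jnp.ones((8, 16))
layer = tx.Dropout(rate=0.2, broadcast_dims=(0,)).init(3)
y = layer(x)

# all rows share the same zero pattern
assert ((y == 0) == (y[:1] == 0)).all()
```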