
treex.nn.sequence

Creates a function that applies the given callables to an input in order, passing each callable's output to the next one, and returns the final result.

Examples:

import jax
import jax.numpy as jnp
import treex as tx

class Block(tx.Module):
    linear: tx.Linear
    batch_norm: tx.BatchNorm
    dropout: tx.Dropout
    ...

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return tx.sequence(
            self.linear,
            self.batch_norm,
            self.dropout,
            jax.nn.relu,
        )(x)
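The example above chains Treex modules, but the function only requires plain callables. The sketch below copies the loop from `sequence`'s source code into a standalone function (a local re-implementation for illustration, not an import from treex) and uses two hypothetical helpers, `double` and `add_one`, to show that the callables run left to right: `sequence(f, g)(x)` computes `g(f(x))`.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def sequence(*layers: Callable[[T], T]) -> Callable[[T], T]:
    # Local re-implementation of the loop in treex's `sequence`, for illustration only.
    def _sequence(x: T) -> T:
        for layer in layers:
            x = layer(x)  # each layer receives the previous layer's output
        return x

    return _sequence


# Hypothetical layers used only for this demonstration.
def double(x: float) -> float:
    return 2 * x


def add_one(x: float) -> float:
    return x + 1


print(sequence(double, add_one)(3))  # add_one(double(3)) → 7
print(sequence(add_one, double)(3))  # double(add_one(3)) → 8
```

Swapping the order of the arguments changes the result, so list the layers in the order the data should flow through them.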

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*layers` | `Callable[..., jax._src.numpy.lax_numpy.ndarray]` | A sequence of callables to apply. | `()` |

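Because `*layers` defaults to an empty tuple, calling `sequence` with no arguments is allowed. In the loop from the source code, the body then never runs, so the returned function is the identity. The sketch below checks that edge case and the single-layer case against a local copy of the loop (again a re-implementation for illustration, not the treex import).

```python
from typing import Any, Callable


def sequence(*layers: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # Local copy of the loop from treex's `sequence`; not imported from treex.
    def _sequence(x: Any) -> Any:
        for layer in layers:
            x = layer(x)
        return x

    return _sequence


identity = sequence()       # no layers: the loop body never executes
print(identity([1, 2, 3]))  # [1, 2, 3]

only_abs = sequence(abs)    # one layer: behaves like that layer
print(only_abs(-4))         # 4
```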
Source code in treex/nn/sequential.py
def sequence(*layers: CallableModule) -> CallableModule:
    """
    Creates a function that applies a sequence of callables to an input.

    Example:
    ```python
    class Block(tx.Module):
        linear: tx.Linear
        batch_norm: tx.BatchNorm
        dropout: tx.Dropout
        ...

        def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
            return tx.sequence(
                self.linear,
                self.batch_norm,
                self.dropout,
                jax.nn.relu,
            )(x)

    ```

    Arguments:
        *layers: A sequence of callables to apply.
    """

    def _sequence(x: jnp.ndarray) -> jnp.ndarray:
        for layer in layers:
            x = layer(x)  # feed each callable the previous callable's output
        return x

    return _sequence
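The loop in `_sequence` is a left fold over `layers`. As an alternative formulation (a sketch, not how treex writes it), the same behaviour can be expressed with `functools.reduce`; the snippet below checks that both versions agree on a small hypothetical pipeline. In either version, creating the returned function is cheap: the closure only stores a reference to the `layers` tuple and does no other work until it is called, which is why building it inside `__call__`, as the `Block` example does, is reasonable.

```python
import functools
from typing import Any, Callable


def sequence_loop(*layers: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # Mirrors the loop in treex's `sequence` (local copy for comparison).
    def _sequence(x: Any) -> Any:
        for layer in layers:
            x = layer(x)
        return x

    return _sequence


def sequence_reduce(*layers: Callable[[Any], Any]) -> Callable[[Any], Any]:
    # Hypothetical alternative: fold the input through the layers left to right.
    def _sequence(x: Any) -> Any:
        return functools.reduce(lambda acc, layer: layer(acc), layers, x)

    return _sequence


# Hypothetical pipeline: multiply by 3, subtract 2, convert to a string.
layers = (lambda x: x * 3, lambda x: x - 2, str)
print(repr(sequence_loop(*layers)(5)))    # '13'
print(repr(sequence_reduce(*layers)(5)))  # '13'
```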