Skip to content

treex.Sequential

A Module that applies a sequence of Modules or functions in order.

Example:

# Modules (tx.Linear) and plain functions (jax.nn.relu) can be mixed freely;
# Sequential wraps the non-Module callables so every entry acts as a layer.
mlp = tx.Sequential(
    tx.Linear(2, 32),
    jax.nn.relu,
    tx.Linear(32, 8),
    jax.nn.relu,
    tx.Linear(8, 4),
).init(42)  # presumably 42 seeds parameter initialization — see treex `init` docs

x = np.random.uniform(size=(10, 2))  # input of shape (10, 2)
y = mlp(x)  # each layer is applied to the previous layer's output

assert y.shape == (10, 4)
Source code in treex/nn/sequential.py
class Sequential(Module):
    """
    A Module that applies a sequence of Modules or functions in order.

    Each layer receives the output of the previous one. Callables that are not
    `Module` instances are wrapped in a `Lambda` module on construction. An
    empty `Sequential` returns its input unchanged.

    Example:

    ```python
    mlp = tx.Sequential(
        tx.Linear(2, 32),
        jax.nn.relu,
        tx.Linear(32, 8),
        jax.nn.relu,
        tx.Linear(8, 4),
    ).init(42)

    x = np.random.uniform(size=(10, 2))
    y = mlp(x)

    assert y.shape == (10, 4)
    ```
    """

    # The layers applied in order; non-Module callables are stored wrapped in Lambda.
    layers: tp.List[CallableModule] = to.node()

    def __init__(
        self, *layers: tp.Union[CallableModule, tp.Callable[[jnp.ndarray], jnp.ndarray]]
    ):
        """
        Arguments:
            *layers: Layers (Modules) or plain callables to apply in sequence.
                Callables that are not Modules are wrapped in `Lambda`.

        Raises:
            TypeError: If an element that is not a Module is also not callable.
        """
        for index, layer in enumerate(layers):
            # Fail at construction rather than deferring the error to the first
            # forward pass, where it would surface from inside the layer loop.
            if not isinstance(layer, Module) and not callable(layer):
                raise TypeError(
                    f"Sequential layer at position {index} is not callable: {layer!r}"
                )

        self.layers = [
            layer if isinstance(layer, Module) else Lambda(layer) for layer in layers
        ]

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies each layer to the output of the previous one.

        Arguments:
            x: Input passed to the first layer.

        Returns:
            The output of the last layer, or `x` itself when there are no layers.
        """
        for layer in self.layers:
            x = layer(x)
        return x

`__init__(self, *layers)` (special method)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `*layers` | `Union[Callable[..., jax._src.numpy.lax_numpy.ndarray], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]]` | A list of layers or callables to apply in sequence. | `()` |
Source code in treex/nn/sequential.py
def __init__(
    self, *layers: tp.Union[CallableModule, tp.Callable[[jnp.ndarray], jnp.ndarray]]
):
    """
    Arguments:
        *layers: Layers (Modules) or plain callables to apply in sequence.
            Callables that are not Modules are wrapped in `Lambda`.

    Raises:
        TypeError: If an element that is not a Module is also not callable.
    """
    for index, layer in enumerate(layers):
        # Fail at construction rather than deferring the error to the first
        # forward pass, where it would surface from inside the layer loop.
        if not isinstance(layer, Module) and not callable(layer):
            raise TypeError(
                f"Sequential layer at position {index} is not callable: {layer!r}"
            )

    self.layers = [
        layer if isinstance(layer, Module) else Lambda(layer) for layer in layers
    ]