
treex.Sequential

A Module that applies a sequence of Modules or functions in order.
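Calling a Sequential feeds its input through each layer in turn, passing every output to the next layer. The sketch below shows that control flow in plain Python. It is a minimal sketch: `apply_in_sequence` is a hypothetical helper standing in for the real `__call__`, and plain functions stand in for Modules. It is not treex code.

```python
from functools import reduce
from typing import Any, Callable, Sequence


def apply_in_sequence(layers: Sequence[Callable[[Any], Any]], x: Any) -> Any:
    """Hypothetical helper: feed x through each layer, left to right."""
    # reduce evaluates layers[-1](...layers[1](layers[0](x))...)
    return reduce(lambda acc, layer: layer(acc), layers, x)


def double(v: float) -> float:
    return v * 2


def add_one(v: float) -> float:
    return v + 1


print(apply_in_sequence([double, add_one], 5))  # (5 * 2) + 1 = 11
print(apply_in_sequence([add_one, double], 5))  # (5 + 1) * 2 = 12
```

Order matters. The same two functions produce different results when their positions are swapped.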

Examples:

import jax
import numpy as np
import treex as tx

mlp = tx.Sequential(
    tx.Linear(2, 32),
    jax.nn.relu,
    tx.Linear(32, 8),
    jax.nn.relu,
    tx.Linear(8, 4),
).init(42)

x = np.random.uniform(size=(10, 2))
y = mlp(x)

assert y.shape == (10, 4)
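The final assertion follows from how shapes flow through the stack. Each `Linear(din, dout)` maps a `(batch, din)` array to `(batch, dout)`, and `relu` is elementwise, so an input of shape `(10, 2)` ends up as `(10, 4)`. The plain-Python sketch below traces that bookkeeping over a hypothetical layer spec that mirrors the example. It also counts parameters, assuming each Linear holds a din-by-dout kernel and a dout-sized bias. That layout is standard for dense layers but is not stated on this page.

```python
from typing import List, Tuple, Union

# Hypothetical spec mirroring the example: ("linear", din, dout) or "relu"
Spec = Union[Tuple[str, int, int], str]

spec: List[Spec] = [
    ("linear", 2, 32),
    "relu",
    ("linear", 32, 8),
    "relu",
    ("linear", 8, 4),
]


def trace_shapes(spec: List[Spec], input_shape: Tuple[int, ...]):
    """Return the activation shape after each layer and a parameter count."""
    shape = input_shape
    shapes = []
    n_params = 0
    for layer in spec:
        if layer != "relu":  # relu is elementwise: same shape, no parameters
            _, din, dout = layer
            assert shape[-1] == din, f"expected {din} features, got {shape[-1]}"
            shape = shape[:-1] + (dout,)
            n_params += din * dout + dout  # assumed kernel (din, dout) + bias (dout,)
        shapes.append(shape)
    return shapes, n_params


shapes, n_params = trace_shapes(spec, (10, 2))
print(shapes)    # [(10, 32), (10, 32), (10, 8), (10, 8), (10, 4)]
print(n_params)  # 96 + 264 + 36 = 396
```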

__init__(self, *layers)

Parameters:

Name: *layers
Type: Union[CallableModule, Callable[[jnp.ndarray], jnp.ndarray]]
Description: Modules or callables to apply in sequence. Any layer that is not a Module is wrapped in Lambda.
Default: ()

Source code in treex/nn/sequential.py
def __init__(
    self, *layers: tp.Union[CallableModule, tp.Callable[[jnp.ndarray], jnp.ndarray]]
):
    """
    Arguments:
        *layers: Modules or callables to apply in sequence.
    """

    # Keep Modules as-is; wrap plain functions in Lambda so every entry is a Module.
    self.layers = [
        layer if isinstance(layer, Module) else Lambda(layer) for layer in layers
    ]
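The list comprehension at the end of `__init__` normalizes its inputs. Objects that are already Modules are stored unchanged, while plain callables such as `jax.nn.relu` are wrapped in `Lambda`, so every entry of `self.layers` is a Module. The pure-Python sketch below reproduces that rule using hypothetical stand-in classes `Module`, `Lambda`, and `Scale`, which are not treex's own. Its `__call__` is an assumption, since the forward pass is not shown in this excerpt.

```python
from typing import Any, Callable, List


class Module:
    """Hypothetical stand-in for treex's Module base class."""

    def __call__(self, x: Any) -> Any:
        raise NotImplementedError


class Lambda(Module):
    """Hypothetical stand-in: wraps a plain function so it behaves as a Module."""

    def __init__(self, f: Callable[[Any], Any]):
        self.f = f

    def __call__(self, x: Any) -> Any:
        return self.f(x)


class Scale(Module):
    """Toy Module with state, used only for this illustration."""

    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, x: Any) -> Any:
        return x * self.factor


class Sequential(Module):
    """Sketch: the normalization from __init__ plus an assumed in-order __call__."""

    def __init__(self, *layers: Any):
        # Same rule as the source: keep Modules, wrap everything else in Lambda.
        self.layers: List[Module] = [
            layer if isinstance(layer, Module) else Lambda(layer) for layer in layers
        ]

    def __call__(self, x: Any) -> Any:
        # Assumed forward pass: apply each layer to the previous output.
        for layer in self.layers:
            x = layer(x)
        return x


def relu(v: float) -> float:
    return max(v, 0.0)


model = Sequential(Scale(-2.0), relu, abs)
print([type(layer).__name__ for layer in model.layers])  # ['Scale', 'Lambda', 'Lambda']
print(model(3.0))   # -6.0 -> relu -> 0.0 -> abs -> 0.0
print(model(-3.0))  # 6.0 -> relu -> 6.0 -> abs -> 6.0
```

Because this Sequential is itself a Module, nesting one Sequential inside another stores the inner one unwrapped instead of wrapping it in Lambda.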