treex.nn.Sequential
A Module that applies a sequence of Modules or functions in order.
Examples:

```python
mlp = tx.Sequential(
    tx.Linear(2, 32),
    jax.nn.relu,
    tx.Linear(32, 8),
    jax.nn.relu,
    tx.Linear(8, 4),
).init(42)

x = np.random.uniform(size=(10, 2))
y = mlp(x)

assert y.shape == (10, 4)
```
Source code in treex/nn/sequential.py
class Sequential(Module):
    """
    Apply a chain of Modules and/or plain callables, one after another.

    The first element receives the input given to `__call__`; every later
    element receives the output of the element before it. Callables that are
    not `Module` instances are wrapped in `Lambda` when the `Sequential` is
    constructed.

    Example:

    ```python
    mlp = tx.Sequential(
        tx.Linear(2, 32),
        jax.nn.relu,
        tx.Linear(32, 8),
        jax.nn.relu,
        tx.Linear(8, 4),
    ).init(42)

    x = np.random.uniform(size=(10, 2))
    y = mlp(x)

    assert y.shape == (10, 4)
    ```
    """

    # Declared through `to.node()` so the stored layers are registered as
    # child nodes of this Module (presumably so their state is tracked by the
    # tree machinery — confirm against `to.node` docs).
    layers: tp.List[CallableModule] = to.node()

    def __init__(
        self, *layers: tp.Union[CallableModule, tp.Callable[[jnp.ndarray], jnp.ndarray]]
    ):
        """
        Arguments:
            *layers: Modules or callables to apply, in the order given.
        """
        wrapped = []
        for item in layers:
            # Non-Module callables (e.g. activation functions) are wrapped in
            # `Lambda` so the stored list matches the `CallableModule` annotation.
            wrapped.append(item if isinstance(item, Module) else Lambda(item))
        self.layers = wrapped

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Feed `x` through every stored layer in order and return the result.

        With no layers, `x` is returned unchanged.
        """
        output = x
        for stage in self.layers:
            output = stage(output)
        return output
`__init__(self, *layers)` *(special)*

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*layers` | `Union[Callable[..., jax._src.numpy.lax_numpy.ndarray], Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]]` | A list of layers or callables to apply in sequence. | `()` |
Source code in treex/nn/sequential.py
def __init__(
    self, *layers: tp.Union[CallableModule, tp.Callable[[jnp.ndarray], jnp.ndarray]]
):
    """
    Arguments:
        *layers: Modules or callables to apply, in the order given.
    """
    wrapped = []
    for item in layers:
        # Non-Module callables (e.g. activation functions) are wrapped in
        # `Lambda` so every stored entry can be called like a layer.
        wrapped.append(item if isinstance(item, Module) else Lambda(item))
    self.layers = wrapped