treex.nn.sequence
Creates a function that applies a sequence of callables to an input.
Examples:
```python
class Block(tx.Module):
    linear: tx.Linear
    batch_norm: tx.BatchNorm
    dropout: tx.Dropout
    ...

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return tx.sequence(
            self.linear,
            self.batch_norm,
            self.dropout,
            jax.nn.relu,
        )(x)
```
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `*layers` | `Callable[..., jax._src.numpy.lax_numpy.ndarray]` | A sequence of callables to apply. | `()` |
Source code in treex/nn/sequential.py
def sequence(*layers: CallableModule) -> CallableModule:
    """
    Build a callable that threads an input through ``layers`` in order.

    The returned function passes its argument to the first layer. It then
    feeds that layer's output to the next one, and so on, returning the output
    of the last layer. When no layers are given, it returns its input
    unchanged.

    Example:

    ```python
    class Block(tx.Module):
        linear: tx.Linear
        batch_norm: tx.BatchNorm
        dropout: tx.Dropout
        ...

        def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
            return tx.sequence(
                self.linear,
                self.batch_norm,
                self.dropout,
                jax.nn.relu,
            )(x)
    ```

    Arguments:
        *layers: The callables to apply, first to last. Each one receives
            the previous one's output.

    Returns:
        A single-argument function equivalent to composing ``layers`` from
        left to right.
    """

    def _sequence(x: jnp.ndarray) -> jnp.ndarray:
        # Thread the running value through every layer; the tuple ``layers``
        # is captured once when ``sequence`` is called.
        out = x
        for step in layers:
            out = step(out)
        return out

    return _sequence