Skip to content

treex.MLP

A Multi-Layer Perceptron (MLP) that applies a sequence of linear layers, each followed by a given activation (relu by default), except for the last layer, whose output is left linear (no activation is applied).

__call__(self, x) special

Applies the MLP to the input.

Parameters:

Name Type Description Default
x ndarray

input array.

required

Returns:

Type Description
ndarray

The output of the MLP.

Source code in treex/nn/mlp.py
@to.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """
    Runs the input through every layer of the MLP in order.

    One `Linear` layer is created per entry of `self.features`; the
    activation is applied after each layer except the final one.

    Arguments:
        x: input array.

    Returns:
        The output of the MLP.
    """
    # Options shared by every layer; only the output width varies.
    layer_options = dict(
        use_bias=self.use_bias,
        dtype=self.dtype,
        precision=self.precision,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
    )
    num_layers = len(self.features)

    for position, width in enumerate(self.features, start=1):
        layer = Linear(features_out=width, **layer_options)
        x = layer(x)

        # The final layer stays purely linear.
        if position != num_layers:
            x = self.activation(x)

    return x

__init__(self, features, activation=jax.nn.relu, use_bias=True, dtype=jnp.float32, precision=None, kernel_init=flax_module.default_kernel_init, bias_init=flax_module.zeros) special

Parameters:

Name Type Description Default
features Sequence[int]

a non-empty sequence of L integers, where L is the number of layers; each integer is the number of output features of the respective layer. The number of input features is not part of this sequence.

required
activation Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]

the activation function to use.

jax.nn.relu
use_bias bool

whether to add a bias to the output (default: True).

True
dtype Any

the dtype of the computation (default: float32).

<class 'jax._src.numpy.lax_numpy.float32'>
precision Any

numerical precision of the computation see jax.lax.Precision for details.

None
kernel_init Callable[[Any, Iterable[int], Any], Any]

initializer function for the weight matrix.

flax_module.default_kernel_init
bias_init Callable[[Any, Iterable[int], Any], Any]

initializer function for the bias.

flax_module.zeros
Source code in treex/nn/mlp.py
def __init__(
    self,
    features: tp.Sequence[int],
    activation: tp.Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
    use_bias: bool = True,
    dtype: tp.Any = jnp.float32,
    precision: tp.Any = None,
    kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.default_kernel_init,
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.zeros,
):
    """
    Stores the configuration used to build the layers of the MLP.

    Arguments:
        features: a non-empty sequence of L integers, where L is the number
            of layers; each integer is the number of output features of the
            respective layer. The number of input features is not part of
            this sequence: `__call__` creates one `Linear` per entry and
            passes it only `features_out`.
        activation: the activation function to use; `__call__` applies it
            after every layer except the last one.
        use_bias: whether to add a bias to the output (default: True).
        dtype: the dtype of the computation (default: float32).
        precision: numerical precision of the computation see `jax.lax.Precision`
            for details.
        kernel_init: initializer function for the weight matrix.
        bias_init: initializer function for the bias.

    Raises:
        ValueError: if `features` is empty.
    """

    # An MLP needs at least one layer, i.e. at least one output size.
    if len(features) == 0:
        raise ValueError("features must have at least 1 element")

    self.features = features
    self.activation = activation
    self.use_bias = use_bias
    self.dtype = dtype
    self.precision = precision
    self.kernel_init = kernel_init
    self.bias_init = bias_init