Skip to content

treex.nn.MLP

A Multi-Layer Perceptron (MLP) that applies a sequence of linear layers, each followed by a given activation (ReLU by default) except the last layer, which is linear.

Source code in treex/nn/mlp.py
class MLP(Module):
    """A Multi-Layer Perceptron (MLP) that applies a sequence of linear layers.

    Every layer except the last one is followed by ``activation`` (ReLU by
    default); the last layer is linear, i.e. no activation is applied to it.
    ``features`` lists only the output size of each layer; no input size is
    passed to the individual ``Linear`` layers.
    """

    # pytree
    layers: tp.List[Linear]

    # props
    features: tp.Sequence[int]
    # NOTE(review): `module` is never assigned anywhere in this class; it may be
    # a leftover annotation — confirm before removing.
    module: flax_module.Dense

    def __init__(
        self,
        features: tp.Sequence[int],
        activation: tp.Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
        use_bias: bool = True,
        dtype: tp.Any = jnp.float32,
        precision: tp.Any = None,
        kernel_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.default_kernel_init,
        bias_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.zeros,
    ):
        """
        Arguments:
            features: a non-empty sequence of integers with one entry per layer;
                the i-th integer is the number of output features of the i-th
                linear layer, so the last entry is the output size of the MLP.
                The number of input features is not part of this sequence.
            activation: the activation function applied after every layer except
                the last one.
            use_bias: whether each linear layer adds a bias to its output
                (default: True).
            dtype: the dtype of the computation (default: float32).
            precision: numerical precision of the computation; see
                `jax.lax.Precision` for details.
            kernel_init: initializer function for each layer's weight matrix.
            bias_init: initializer function for each layer's bias.

        Raises:
            ValueError: if `features` is empty.
        """

        # At least one layer is required; an empty MLP would be the identity.
        if len(features) == 0:
            raise ValueError("features must have at least 1 element")

        self.features = features
        self.activation = activation
        self.use_bias = use_bias
        self.dtype = dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init

    @to.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Applies the MLP to the input.

        One `Linear` layer is created and applied per entry of `self.features`,
        in order, using that entry as the layer's output size.
        `self.activation` is applied after every layer except the last.

        Arguments:
            x: input array.

        Returns:
            The output of the last linear layer (no activation applied to it).
        """
        last_layer_idx = len(self.features) - 1

        for i, features_out in enumerate(self.features):
            x = Linear(
                features_out=features_out,
                use_bias=self.use_bias,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )(x)

            # The final layer stays linear: skip the activation on it.
            if i < last_layer_idx:
                x = self.activation(x)

        return x

__call__(self, x) special

Applies the MLP to the input.

Parameters:

Name Type Description Default
x ndarray

input array.

required

Returns:

Type Description
ndarray

The output of the MLP.

Source code in treex/nn/mlp.py
@to.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """
    Run the input through every layer of the MLP.

    A `Linear` layer is built for each entry of `self.features` (that entry is
    the layer's output size) and applied in sequence. `self.activation`
    follows every layer except the final one, which is left linear.

    Arguments:
        x: input array.

    Returns:
        The output of the final linear layer.
    """
    num_layers = len(self.features)

    for position, width in enumerate(self.features, start=1):
        layer = Linear(
            features_out=width,
            use_bias=self.use_bias,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        x = layer(x)

        # Only intermediate layers get the non-linearity.
        is_final_layer = position == num_layers
        if not is_final_layer:
            x = self.activation(x)

    return x

__init__(self, features, activation=jax.nn.relu, use_bias=True, dtype=jnp.float32, precision=None, kernel_init=flax_module.default_kernel_init, bias_init=flax_module.zeros) special

Parameters:

Name Type Description Default
features Sequence[int]

a non-empty sequence of integers with one entry per layer; the i-th integer is the number of output features of the i-th linear layer, so the last entry is the output size of the MLP. The number of input features is not part of this sequence.

required
activation Callable[[jax._src.numpy.lax_numpy.ndarray], jax._src.numpy.lax_numpy.ndarray]

the activation function applied after every layer except the last one.

jax.nn.relu
use_bias bool

whether to add a bias to the output (default: True).

True
dtype Any

the dtype of the computation (default: float32).

jnp.float32
precision Any

numerical precision of the computation; see jax.lax.Precision for details.

None
kernel_init Callable[[Any, Iterable[int], Any], Any]

initializer function for the weight matrix.

flax_module.default_kernel_init
bias_init Callable[[Any, Iterable[int], Any], Any]

initializer function for the bias.

flax_module.zeros
Source code in treex/nn/mlp.py
def __init__(
    self,
    features: tp.Sequence[int],
    activation: tp.Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
    use_bias: bool = True,
    dtype: tp.Any = jnp.float32,
    precision: tp.Any = None,
    kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.default_kernel_init,
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.zeros,
):
    """
    Arguments:
        features: a non-empty sequence of integers with one entry per layer;
            the i-th integer is the number of output features of the i-th
            linear layer, so the last entry is the output size of the MLP.
            The number of input features is not part of this sequence.
        activation: the activation function applied after every layer except
            the last one.
        use_bias: whether each linear layer adds a bias to its output
            (default: True).
        dtype: the dtype of the computation (default: float32).
        precision: numerical precision of the computation; see
            `jax.lax.Precision` for details.
        kernel_init: initializer function for each layer's weight matrix.
        bias_init: initializer function for each layer's bias.

    Raises:
        ValueError: if `features` is empty.
    """

    # At least one layer is required; an empty MLP would be the identity.
    if len(features) == 0:
        raise ValueError("features must have at least 1 element")

    self.features = features
    self.activation = activation
    self.use_bias = use_bias
    self.dtype = dtype
    self.precision = precision
    self.kernel_init = kernel_init
    self.bias_init = bias_init