Skip to content

treex.regularizers.L1L2

A regularizer that applies both L1 and L2 regularization penalties.

The L1 regularization penalty is computed as:

\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|

The L2 regularization penalty is computed as:

\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2
Source code in treex/regularizers/l1l2.py
class L1L2(Loss):
    r"""
    A regularizer that applies both L1 and L2 regularization penalties.

    The L1 regularization penalty is computed as:

    $$
    \ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|
    $$

    The L2 regularization penalty is computed as:

    $$
    \ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2
    $$

    In both formulas the sum runs over every element of every leaf of the
    parameter structure passed to `call`.
    """

    def __init__(
        self,
        l1=0.0,
        l2=0.0,
        reduction: tp.Optional[Reduction] = None,
        weight: tp.Optional[float] = None,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
    ):  # pylint: disable=redefined-outer-name
        """
        Creates the regularizer.

        Arguments:
            l1: Coefficient multiplying the L1 penalty; a falsy value
                (e.g. `0.0`) disables the L1 term.
            l2: Coefficient multiplying the L2 penalty; a falsy value
                (e.g. `0.0`) disables the L2 term.
            reduction: Forwarded unchanged to `Loss.__init__`.
            weight: Forwarded unchanged to `Loss.__init__`.
            on: Forwarded unchanged to `Loss.__init__`.
            name: Forwarded unchanged to `Loss.__init__`.
        """
        super().__init__(reduction=reduction, weight=weight, on=on, name=name)

        self.l1 = l1
        self.l2 = l2

    def call(self, parameters: tp.Any) -> jnp.ndarray:
        """
        Computes the L1 and L2 regularization penalty simultaneously.

        Arguments:
            parameters: A structure (pytree) with all the parameters of the model.

        Returns:
            A scalar array holding `l1 * sum(|x|) + l2 * sum(x^2)` over all
            leaves of `parameters`, or `0.0` when both coefficients are falsy.
        """
        regularization: jnp.ndarray = jnp.array(0.0)

        # Nothing to compute: skip flattening the parameter tree entirely.
        if not self.l1 and not self.l2:
            return regularization

        # `jax.tree_util.tree_leaves` is the stable API; the top-level
        # `jax.tree_leaves` alias was deprecated and removed from JAX.
        leaves = jax.tree_util.tree_leaves(parameters)

        if self.l1:
            regularization += self.l1 * sum(jnp.sum(jnp.abs(p)) for p in leaves)

        if self.l2:
            regularization += self.l2 * sum(jnp.sum(jnp.square(p)) for p in leaves)

        return regularization

call(self, parameters)

Computes the L1 and L2 regularization penalty simultaneously.

Parameters:

Name Type Description Default
parameters

A structure with all the parameters of the model.

required
Source code in treex/regularizers/l1l2.py
def call(self, parameters: tp.Any) -> jnp.ndarray:
    """
    Computes the L1 and L2 regularization penalty simultaneously.

    Arguments:
        parameters: A structure (pytree) with all the parameters of the model.

    Returns:
        A scalar array holding `l1 * sum(|x|) + l2 * sum(x^2)` over all
        leaves of `parameters`, or `0.0` when both coefficients are falsy.
    """
    regularization: jnp.ndarray = jnp.array(0.0)

    # Nothing to compute: skip flattening the parameter tree entirely.
    if not self.l1 and not self.l2:
        return regularization

    # `jax.tree_util.tree_leaves` is the stable API; the top-level
    # `jax.tree_leaves` alias was deprecated and removed from JAX.
    leaves = jax.tree_util.tree_leaves(parameters)

    if self.l1:
        regularization += self.l1 * sum(jnp.sum(jnp.abs(p)) for p in leaves)

    if self.l2:
        regularization += self.l2 * sum(jnp.sum(jnp.square(p)) for p in leaves)

    return regularization