treex.regularizers.L1L2
A regularizer that applies both L1 and L2 regularization penalties.
The L1 regularization penalty is computed as:
\ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|
The L2 regularization penalty is computed as:
\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2
Source code in treex/regularizers/l1l2.py
class L1L2(Loss):
    r"""
    A regularizer that applies both L1 and L2 regularization penalties.

    The L1 regularization penalty is computed as:

    $$
    \ell_1\,\,penalty =\ell_1\sum_{i=0}^n|x_i|
    $$

    The L2 regularization penalty is computed as:

    $$\ell_2\,\,penalty =\ell_2\sum_{i=0}^nx_i^2$$
    """

    def __init__(
        self,
        l1=0.0,
        l2=0.0,
        reduction: tp.Optional[Reduction] = None,
        weight: tp.Optional[float] = None,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
    ):  # pylint: disable=redefined-outer-name
        """
        Creates the regularizer.

        Arguments:
            l1: Coefficient multiplying the sum of absolute parameter values.
                A falsy value (e.g. `0.0`) disables the L1 term.
            l2: Coefficient multiplying the sum of squared parameter values.
                A falsy value (e.g. `0.0`) disables the L2 term.
            reduction: Forwarded unchanged to the `Loss` constructor.
            weight: Forwarded unchanged to the `Loss` constructor.
            on: Forwarded unchanged to the `Loss` constructor.
            name: Forwarded unchanged to the `Loss` constructor.
        """
        super().__init__(reduction=reduction, weight=weight, on=on, name=name)
        self.l1 = l1
        self.l2 = l2

    def call(self, parameters: tp.Any) -> jnp.ndarray:
        """
        Computes the L1 and L2 regularization penalty simultaneously.

        Arguments:
            parameters: A structure (pytree) with all the parameters of the model.

        Returns:
            A scalar array holding the sum of the enabled penalties; a zero
            scalar when both `l1` and `l2` are falsy.
        """
        regularization: jnp.ndarray = jnp.array(0.0)

        # Nothing to compute: skip flattening the parameter structure entirely.
        if not self.l1 and not self.l2:
            return regularization

        # `jax.tree_leaves` was a deprecated top-level alias that newer JAX
        # releases no longer provide; `jax.tree_util.tree_leaves` is the
        # stable spelling and is available in old and new versions alike.
        leaves = jax.tree_util.tree_leaves(parameters)

        if self.l1:
            regularization += self.l1 * sum(jnp.sum(jnp.abs(p)) for p in leaves)

        if self.l2:
            regularization += self.l2 * sum(jnp.sum(jnp.square(p)) for p in leaves)

        return regularization
call(self, parameters)
Computes the L1 and L2 regularization penalty simultaneously.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `parameters` | `tp.Any` | A structure with all the parameters of the model. | required |
Source code in treex/regularizers/l1l2.py
def call(self, parameters: tp.Any) -> jnp.ndarray:
    """
    Computes the L1 and L2 regularization penalty simultaneously.

    Arguments:
        parameters: A structure (pytree) with all the parameters of the model.

    Returns:
        A scalar array holding the sum of the enabled penalties; a zero
        scalar when both `self.l1` and `self.l2` are falsy.
    """
    regularization: jnp.ndarray = jnp.array(0.0)

    # Nothing to compute: skip flattening the parameter structure entirely.
    if not self.l1 and not self.l2:
        return regularization

    # `jax.tree_leaves` was a deprecated top-level alias that newer JAX
    # releases no longer provide; `jax.tree_util.tree_leaves` is the stable
    # spelling and is available in old and new versions alike.
    leaves = jax.tree_util.tree_leaves(parameters)

    if self.l1:
        regularization += self.l1 * sum(jnp.sum(jnp.abs(p)) for p in leaves)

    if self.l2:
        regularization += self.l2 * sum(jnp.sum(jnp.square(p)) for p in leaves)

    return regularization