
treex.Loss

Loss base class.

To be implemented by subclasses:

  • call(): Contains the logic for loss calculation.

Example subclass implementation:

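import jax.numpy as jnp
from treex import Loss

class MeanSquaredError(Loss):
    def call(self, target, preds):
        # Per-sample loss: mean of the squared errors over the last axis.
        return jnp.mean(jnp.square(preds - target), axis=-1)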
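The example call() above reduces only the last axis, so it returns one loss value per sample; the batch dimension is then reduced according to the reduction setting. The standalone sketch below shows that split using plain jax.numpy rather than treex. The function name mean_squared_error and the explicit sum-over-batch-size step are assumptions chosen for illustration.

```python
import jax.numpy as jnp


def mean_squared_error(target, preds):
    # Same body as MeanSquaredError.call above: average the squared error over
    # the last (feature) axis, leaving one value per sample.
    return jnp.mean(jnp.square(preds - target), axis=-1)


target = jnp.array([[0.0, 1.0], [1.0, 1.0]])
preds = jnp.array([[1.0, 1.0], [1.0, 3.0]])

per_sample = mean_squared_error(target, preds)  # shape (2,): [0.5, 2.0]
# A SUM_OVER_BATCH_SIZE-style reduction then divides the sum by the batch size.
batch_loss = jnp.sum(per_sample) / per_sample.size  # 1.25
print(per_sample.shape, float(batch_loss))
```

Keeping call() per-sample leaves the choice of batch reduction to the reduction parameter instead of hard-coding it in every subclass.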

Please see the Modules, Losses, and Metrics Guide (https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#losses) for more details.

__init__(self, reduction=None, weight=None, on=None, name=None)

Initializes the Loss class.

Parameters:

  • reduction (Optional[treex.losses.loss.Reduction], default: None): Type of tx.losses.Reduction to apply to the loss. If None, SUM_OVER_BATCH_SIZE is used, which is the appropriate choice in almost all cases.
  • weight (Union[float, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], default: None): Optional weight of this loss's contribution to the total loss. If None, a weight of 1.0 is used.
  • on (Union[str, int, Sequence[Union[str, int]]], default: None): A string or integer, or an iterable of strings or integers, indicating how to index/filter the target and preds arguments before they are passed to call. For example, if on = "a" then target = target["a"]. If on is an iterable, the structures are indexed iteratively: for example, if on = ["a", 0, "b"] then target = target["a"][0]["b"]. The same indexing is applied to preds. For more information, see Keras-like behavior (https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
  • name (Optional[str], default: None): Optional name for the instance. If not provided, the lower snake_case version of the class name is used.
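The interplay of reduction and weight can be illustrated with a standalone sketch. It assumes, in the Keras style these losses follow, that the reduction is applied to the per-sample values returned by call() and that the reduced result is then multiplied by weight. The Reduction enum below is a hypothetical stand-in for tx.losses.Reduction (only SUM_OVER_BATCH_SIZE is named in these docs; NONE and SUM are assumed from the Keras convention), and reduce_loss is a hypothetical helper, not part of treex. The defaults mirror the __init__ source shown below.

```python
import enum

import jax.numpy as jnp


class Reduction(enum.Enum):
    # Hypothetical stand-in for tx.losses.Reduction.
    NONE = "none"
    SUM = "sum"
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"


def reduce_loss(values, reduction=None, weight=None):
    # Same defaults as Loss.__init__: SUM_OVER_BATCH_SIZE and a float32 weight of 1.0.
    if reduction is None:
        reduction = Reduction.SUM_OVER_BATCH_SIZE
    weight = (
        jnp.asarray(weight, dtype=jnp.float32)
        if weight is not None
        else jnp.array(1.0, dtype=jnp.float32)
    )
    if reduction == Reduction.NONE:
        reduced = values  # keep one value per element
    elif reduction == Reduction.SUM:
        reduced = jnp.sum(values)
    else:
        # SUM_OVER_BATCH_SIZE: sum divided by the number of elements.
        reduced = jnp.sum(values) / values.size
    # Assumption: the weight scales the reduced loss.
    return reduced * weight


per_sample = jnp.array([0.5, 2.0])
print(float(reduce_loss(per_sample)))  # 1.25
print(float(reduce_loss(per_sample, Reduction.SUM)))  # 2.5
print(float(reduce_loss(per_sample, weight=0.5)))  # 0.625
```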
Source code in treex/losses/loss.py
def __init__(
    self,
    reduction: tp.Optional[Reduction] = None,
    weight: tp.Optional[types.ScalarLike] = None,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
):
    """
    Initializes `Loss` class.

    Arguments:
        reduction: (Optional) Type of `tx.losses.Reduction` to apply to
            loss. Default value is `SUM_OVER_BATCH_SIZE`. For almost all cases
            this defaults to `SUM_OVER_BATCH_SIZE`.
        weight: Optional weight contribution for the total loss. Defaults to `1`.
        on: A string or integer, or iterable of string or integers, that
            indicate how to index/filter the `target` and `preds`
            arguments before passing them to `call`. For example if `on = "a"` then
            `target = target["a"]`. If `on` is an iterable
            the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
            then `target = target["a"][0]["b"]`, same for `preds`. For more information
            check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
        name: Optional name for the instance, if not provided lower snake_case version
            of the name of the class is used instead.
    """
    self.name = name if name is not None else utils._get_name(self)
    self.weight = (
        jnp.asarray(weight, dtype=jnp.float32)
        if weight is not None
        else jnp.array(1.0, dtype=jnp.float32)
    )
    self._reduction = (
        reduction if reduction is not None else Reduction.SUM_OVER_BATCH_SIZE
    )
    self._labels_filter = (on,) if isinstance(on, (str, int)) else on
    self._signature_f = self.call
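The _labels_filter line above normalizes on so that a single key becomes a one-element tuple, while None and sequences pass through unchanged; per the docstring, those keys are then applied to target and preds one at a time. The plain-Python sketch below reproduces that documented behavior. normalize_on and apply_on are hypothetical names written for illustration; treex's actual filtering code is not shown on this page.

```python
from typing import Any, Optional, Sequence, Union

Key = Union[str, int]


def normalize_on(on: Optional[Union[Key, Sequence[Key]]]) -> Optional[Sequence[Key]]:
    # Mirrors self._labels_filter: a single key becomes a 1-tuple,
    # while None and sequences pass through unchanged.
    return (on,) if isinstance(on, (str, int)) else on


def apply_on(structure: Any, on: Optional[Union[Key, Sequence[Key]]]) -> Any:
    # Hypothetical helper: index the structure one key at a time.
    keys = normalize_on(on)
    if keys is None:
        return structure  # no filtering requested
    for key in keys:
        structure = structure[key]
    return structure


target = {"a": [{"b": 1.0}, {"b": 2.0}]}
print(apply_on(target, ["a", 0, "b"]))  # 1.0, i.e. target["a"][0]["b"]
print(apply_on(target, "a"))  # [{'b': 1.0}, {'b': 2.0}]
```

The same apply_on call would be made on preds, so both arguments reach call() with matching structure.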