

Loss base class.

To be implemented by subclasses:

  • call(): Contains the logic for loss calculation.

Example subclass implementation:

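import jax.numpy as jnp
from treex.losses import Loss


class MeanSquaredError(Loss):
    def call(self, target, preds):
        # Per-sample mean squared error, averaged over the last axis.
        return jnp.mean(jnp.square(preds - target), axis=-1)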

Please see the Modules, Losses, and Metrics Guide for more details on this.
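To make the subclass contract concrete without depending on JAX, here is a minimal pure-Python sketch of what `MeanSquaredError.call` computes: `jnp.mean(jnp.square(preds - target), axis=-1)` averages the squared error over the last axis, giving one loss value per sample. The helper name `mse_per_sample` is hypothetical, and nested lists stand in for arrays of shape `(batch, features)`.

```python
def mse_per_sample(target, preds):
    """Mean of squared errors over the last axis of a 2-D nested list.

    Hypothetical helper mirroring jnp.mean(jnp.square(preds - target), axis=-1)
    for inputs of shape (batch, features): one value per row (sample).
    """
    return [
        sum((p - t) ** 2 for t, p in zip(t_row, p_row)) / len(t_row)
        for t_row, p_row in zip(target, preds)
    ]


target = [[0.0, 1.0], [1.0, 1.0]]
preds = [[1.0, 1.0], [1.0, 3.0]]
print(mse_per_sample(target, preds))  # [0.5, 2.0]
```

`call` only returns these per-sample values; the base class then reduces them, by default with `SUM_OVER_BATCH_SIZE`.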

Source code in treex/losses/
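import typing as tp

import jax.numpy as jnp

# `Reduction`, `reduce_loss`, `types` and `utils` are defined or imported
# elsewhere in this treex module (not shown in this excerpt).


class Loss:
    """
    Loss base class.

    To be implemented by subclasses:

    * `call()`: Contains the logic for loss calculation.

    Example subclass implementation:

    class MeanSquaredError(Loss):
        def call(self, target, preds):
            return jnp.mean(jnp.square(preds - target), axis=-1)

    Please see the Modules, Losses, and Metrics Guide for more
    details on this.
    """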

    def __init__(
        self,
        reduction: tp.Optional[Reduction] = None,
        weight: tp.Optional[types.ScalarLike] = None,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
    ):
        """
        Initializes `Loss` class.

        Args:
            reduction: (Optional) Type of `tx.losses.Reduction` to apply to
                the loss. Defaults to `SUM_OVER_BATCH_SIZE`.
            weight: Optional weight contribution for the total loss. Defaults to `1`.
            on: A string or integer, or an iterable of strings or integers, that
                indicates how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example, if `on = "a"`,
                then `target = target["a"]`. If `on` is an iterable, the structures
                are indexed iteratively; for example, if `on = ["a", 0, "b"]`,
                then `target = target["a"][0]["b"]`, and the same applies to `preds`.
                For more information, check out Keras-like behavior.
            name: Optional name for the instance; if not provided, the lower
                snake_case version of the class name is used instead.
        """
        self.name = name if name is not None else utils._get_name(self)
        # Store the loss weight as a float32 array; defaults to 1.0.
        self.weight = (
            jnp.asarray(weight, dtype=jnp.float32)
            if weight is not None
            else jnp.array(1.0, dtype=jnp.float32)
        )
        self._reduction = (
            reduction if reduction is not None else Reduction.SUM_OVER_BATCH_SIZE
        )
        # Wrap a single str/int index in a tuple so it can always be iterated.
        self._labels_filter = (on,) if isinstance(on, (str, int)) else on
        self._signature_f = self.call

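    def __call__(
        self,
        **kwargs,
    ) -> jnp.ndarray:

        # Apply the `on` filter: index `target` and `preds` step by step.
        if self._labels_filter is not None:
            if "target" in kwargs and kwargs["target"] is not None:
                for index in self._labels_filter:
                    kwargs["target"] = kwargs["target"][index]

            if "preds" in kwargs and kwargs["preds"] is not None:
                for index in self._labels_filter:
                    kwargs["preds"] = kwargs["preds"][index]

        # `sample_weight` is consumed here and is not forwarded to `call`.
        sample_weight: tp.Optional[jnp.ndarray] = kwargs.pop("sample_weight", None)

        # Loss values computed by the subclass.
        values = self.call(**kwargs)

        # Combine with `sample_weight` and `self.weight`, then apply the reduction.
        return reduce_loss(values, sample_weight, self.weight, self._reduction)

    def call(self, *args, **kwargs) -> jnp.ndarray:
        # To be implemented by subclasses.
        ...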
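As a rough, JAX-free sketch of the flow in `__call__` (pop `sample_weight`, delegate to `call`, then reduce), the code below uses plain Python lists. `SketchLoss`, `AbsoluteError`, `Reduction` and `reduce_values` are hypothetical stand-ins. The reduction rules are an assumption modeled on Keras (`SUM_OVER_BATCH_SIZE` divides the weighted sum by the number of values, and the loss `weight` scales the result); they are not a copy of treex's `reduce_loss`. The `on` filtering step is omitted for brevity.

```python
import enum


class Reduction(enum.Enum):
    # Hypothetical stand-in for tx.losses.Reduction.
    NONE = "none"
    SUM = "sum"
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"


def reduce_values(values, sample_weight, weight, reduction):
    # Assumed Keras-like semantics, NOT treex's actual `reduce_loss`.
    if sample_weight is not None:
        values = [v * w for v, w in zip(values, sample_weight)]
    if reduction is Reduction.NONE:
        return [v * weight for v in values]
    total = sum(values)
    if reduction is Reduction.SUM_OVER_BATCH_SIZE:
        total /= len(values)
    return total * weight


class SketchLoss:
    def __init__(self, reduction=None, weight=None):
        # Same defaults as Loss: weight 1.0, SUM_OVER_BATCH_SIZE reduction.
        self.weight = float(weight) if weight is not None else 1.0
        self._reduction = (
            reduction if reduction is not None else Reduction.SUM_OVER_BATCH_SIZE
        )

    def __call__(self, **kwargs):
        # `sample_weight` is consumed here and never reaches `call`.
        sample_weight = kwargs.pop("sample_weight", None)
        values = self.call(**kwargs)  # per-sample values from the subclass
        return reduce_values(values, sample_weight, self.weight, self._reduction)

    def call(self, target, preds):
        raise NotImplementedError("subclasses implement call()")


class AbsoluteError(SketchLoss):
    def call(self, target, preds):
        return [abs(p - t) for t, p in zip(target, preds)]


loss = AbsoluteError(weight=2.0)
print(loss(target=[1.0, 2.0], preds=[2.0, 4.0]))  # 3.0
```

Keeping weighting and reduction out of `call` means a subclass only has to return per-sample values, which matches the contract described at the top of this page.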

__init__(self, reduction=None, weight=None, on=None, name=None)

Initializes Loss class.


Parameters:

  • reduction (Optional[treex.losses.loss.Reduction], default: None): (Optional) Type of tx.losses.Reduction to apply to the loss. Defaults to SUM_OVER_BATCH_SIZE.

  • weight (Union[float, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray], default: None): Optional weight contribution for the total loss. Defaults to 1.

  • on (Union[str, int, Sequence[Union[str, int]]], default: None): A string or integer, or an iterable of strings or integers, that indicates how to index/filter the target and preds arguments before passing them to call. For example, if on = "a", then target = target["a"]. If on is an iterable, the structures are indexed iteratively; for example, if on = ["a", 0, "b"], then target = target["a"][0]["b"], and the same applies to preds. For more information, check out Keras-like behavior.

  • name (Optional[str], default: None): Optional name for the instance; if not provided, the lower snake_case version of the class name is used instead.
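The iterative indexing behind `on` can be sketched in plain Python. `apply_on` is a hypothetical helper, not part of treex: it reproduces the normalization from `__init__` (a single `str` or `int` becomes a 1-tuple) and the indexing loop from `__call__`.

```python
def apply_on(structure, on):
    """Index `structure` the way Loss does for its `on` argument.

    Hypothetical helper for illustration; not part of the treex API.
    """
    if on is None:
        return structure  # no filtering requested
    # A single key/index is wrapped in a tuple so it can be iterated.
    indices = (on,) if isinstance(on, (str, int)) else on
    for index in indices:
        structure = structure[index]
    return structure


target = {"a": [{"b": 1.0}, {"b": 2.0}]}
print(apply_on(target, "a"))            # [{'b': 1.0}, {'b': 2.0}]
print(apply_on(target, ["a", 0, "b"]))  # 1.0
```

Since the same indices are applied to both `target` and `preds`, the two structures must share the same nesting along that path.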
