treex.losses.Crossentropy

Computes the crossentropy loss between the target and predictions.

Use this crossentropy loss function when there are two or more label classes. We expect target to be provided as integers. If you want to provide target using a one-hot representation, please use the CategoricalCrossentropy loss instead. There should be # classes floating point values per example for preds and a single integer class index per example for target. In the snippet below, the shape of target is [batch_size] and the shape of preds is [batch_size, num_classes].
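For reference, here is the relationship between the two target encodings (a small sketch using standard JAX utilities; `num_classes=3` matches the snippet below):

```python
import jax
import jax.numpy as jnp

# Integer ("sparse") targets, as expected by Crossentropy.
target = jnp.array([1, 2])

# Equivalent one-hot targets, as expected by CategoricalCrossentropy.
target_one_hot = jax.nn.one_hot(target, num_classes=3)  # shape [2, 3]

# Integer labels can be recovered from the one-hot encoding.
assert (jnp.argmax(target_one_hot, axis=-1) == target).all()
```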

Usage:

import jax.numpy as jnp
import numpy as np
import treex as tx

target = jnp.array([1, 2])
preds = jnp.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])

# Using 'auto'/'sum_over_batch_size' reduction type.
# `preds` holds probabilities here, so pass `from_logits=False`.
scce = tx.losses.Crossentropy(from_logits=False)
result = scce(target, preds)  # 1.177
assert np.isclose(result, 1.177, rtol=0.01)

# Calling with 'sample_weight'.
result = scce(target, preds, sample_weight=jnp.array([0.3, 0.7]))  # 0.814
assert np.isclose(result, 0.814, rtol=0.01)

# Using 'sum' reduction type.
scce = tx.losses.Crossentropy(
    from_logits=False,
    reduction=tx.losses.Reduction.SUM
)
result = scce(target, preds)  # 2.354
assert np.isclose(result, 2.354, rtol=0.01)

# Using 'none' reduction type.
scce = tx.losses.Crossentropy(
    from_logits=False,
    reduction=tx.losses.Reduction.NONE
)
result = scce(target, preds)  # [0.0513, 2.303]
assert jnp.all(np.isclose(result, [0.0513, 2.303], rtol=0.01))
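For intuition, the per-sample values above can be reproduced by hand: since `preds` holds probabilities, each sample's loss is the negative log-probability assigned to its true class. A minimal sketch of that arithmetic (not treex's actual implementation):

```python
import jax.numpy as jnp

target = jnp.array([1, 2])
preds = jnp.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])

# Pick out the probability of the true class for each sample,
# then take the negative log.
per_sample = -jnp.log(preds[jnp.arange(len(target)), target])
# [-log(0.95), -log(0.1)] ~= [0.0513, 2.303]
```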

Usage with the Elegy API:

model = elegy.Model(
    module_fn,
    loss=tx.losses.Crossentropy(),
    metrics=elegy.metrics.Accuracy(),
    optimizer=optax.adam(1e-3),
)
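The loss can also be used outside `elegy.Model`, in a hand-written training step. A minimal sketch, assuming a hypothetical `apply_fn(params, x)` that returns logits (only the `tx.losses.Crossentropy`, `jax`, and `optax` calls here are real APIs):

```python
import jax
import optax
import treex as tx

optimizer = optax.adam(1e-3)
loss_fn = tx.losses.Crossentropy()  # from_logits=True by default: pass raw logits

def train_step(params, opt_state, x, y):
    def objective(p):
        logits = apply_fn(p, x)   # hypothetical forward pass returning logits
        return loss_fn(y, logits)  # (target, preds), reduced to a scalar

    loss, grads = jax.value_and_grad(objective)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```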
Source code in treex/losses/crossentropy.py
class Crossentropy(Loss):
    """
    Computes the crossentropy loss between the target and predictions.

    Use this crossentropy loss function when there are two or more label classes.
    We expect target to be provided as integers. If you want to provide target
    using a `one-hot` representation, please use the `CategoricalCrossentropy`
    loss instead. There should be `# classes` floating point values per example
    for `preds` and a single integer class index per example for `target`.
    In the snippet below, the shape of `target` is `[batch_size]` and the
    shape of `preds` is `[batch_size, num_classes]`.

    Usage:
    ```python
    target = jnp.array([1, 2])
    preds = jnp.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])

    # Using 'auto'/'sum_over_batch_size' reduction type.
    # `preds` holds probabilities here, so pass `from_logits=False`.
    scce = tx.losses.Crossentropy(from_logits=False)
    result = scce(target, preds)  # 1.177
    assert np.isclose(result, 1.177, rtol=0.01)

    # Calling with 'sample_weight'.
    result = scce(target, preds, sample_weight=jnp.array([0.3, 0.7]))  # 0.814
    assert np.isclose(result, 0.814, rtol=0.01)

    # Using 'sum' reduction type.
    scce = tx.losses.Crossentropy(
        from_logits=False,
        reduction=tx.losses.Reduction.SUM
    )
    result = scce(target, preds)  # 2.354
    assert np.isclose(result, 2.354, rtol=0.01)

    # Using 'none' reduction type.
    scce = tx.losses.Crossentropy(
        from_logits=False,
        reduction=tx.losses.Reduction.NONE
    )
    result = scce(target, preds)  # [0.0513, 2.303]
    assert jnp.all(np.isclose(result, [0.0513, 2.303], rtol=0.01))
    ```

    Usage with the `Elegy` API:

    ```python
    model = elegy.Model(
        module_fn,
        loss=tx.losses.Crossentropy(),
        metrics=elegy.metrics.Accuracy(),
        optimizer=optax.adam(1e-3),
    )
    ```
    """

    def __init__(
        self,
        *,
        from_logits: bool = True,
        binary: bool = False,
        label_smoothing: tp.Optional[float] = None,
        reduction: tp.Optional[Reduction] = None,
        check_bounds: bool = True,
        weight: tp.Optional[float] = None,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
    ):
        """
        Initializes a `Crossentropy` instance.

        Arguments:
            from_logits: Whether `preds` is expected to be a logits tensor. Defaults
                to `True`, i.e. `preds` is assumed to contain unnormalized logits;
                pass `from_logits=False` if it contains probabilities.
                **Note - Using from_logits=True is more numerically stable.**
            binary: If `True`, the loss is computed as binary crossentropy, treating
                each output as an independent binary prediction.
            label_smoothing: Float in `[0, 1]`. When greater than `0`, the target
                labels are smoothed towards the uniform distribution before the
                loss is computed.
            reduction: (Optional) Type of `tx.losses.Reduction` to apply to the
                loss. Defaults to `SUM_OVER_BATCH_SIZE`.
            weight: Optional weight contribution for the total loss. Defaults to `1`.
            on: A string or integer, or iterable of string or integers, that
                indicate how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example if `on = "a"` then
                `target = target["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `target = target["a"][0]["b"]`, same for `preds`. For more information
                check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            check_bounds: If `True` (default), checks `target` for negative values and
                values larger than or equal to the number of channels in `preds`, and
                sets the loss to NaN if any are found. If `False`, the check is disabled
                and the loss may contain incorrect values.
            name: Optional name for this loss instance.
        """
        super().__init__(reduction=reduction, weight=weight, on=on, name=name)

        self._from_logits = from_logits
        self._check_bounds = check_bounds
        self._binary = binary
        self._label_smoothing = label_smoothing

    def call(
        self, target, preds, sample_weight: tp.Optional[jnp.ndarray] = None
    ) -> jnp.ndarray:
        """
        Invokes the `Crossentropy` instance.

        Arguments:
            target: Ground truth values.
            preds: The predicted values.
            sample_weight: Acts as a
                coefficient for the loss. If a scalar is provided, then the loss is
                simply scaled by the given value. If `sample_weight` is a tensor of size
                `[batch_size]`, then the total loss for each sample of the batch is
                rescaled by the corresponding element in the `sample_weight` vector. If
                the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
                broadcasted to this shape), then each loss element of `preds` is scaled
                by the corresponding value of `sample_weight`. (Note on `dN-1`: all loss
                functions reduce by 1 dimension, usually `axis=-1`.)

        Returns:
            Loss values per sample.
        """

        return crossentropy(
            target,
            preds,
            binary=self._binary,
            from_logits=self._from_logits,
            label_smoothing=self._label_smoothing,
            check_bounds=self._check_bounds,
        )

__init__(self, *, from_logits=True, binary=False, label_smoothing=None, reduction=None, check_bounds=True, weight=None, on=None, name=None) special

Initializes a `Crossentropy` instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `from_logits` | `bool` | Whether `preds` is expected to be a logits tensor. Defaults to `True`, i.e. `preds` is assumed to contain unnormalized logits; pass `from_logits=False` if it contains probabilities. **Note - Using from_logits=True is more numerically stable.** | `True` |
| `binary` | `bool` | If `True`, the loss is computed as binary crossentropy, treating each output as an independent binary prediction. | `False` |
| `label_smoothing` | `Optional[float]` | Float in `[0, 1]`. When greater than `0`, the target labels are smoothed towards the uniform distribution before the loss is computed. | `None` |
| `reduction` | `Optional[treex.losses.loss.Reduction]` | (Optional) Type of `tx.losses.Reduction` to apply to the loss. Defaults to `SUM_OVER_BATCH_SIZE`. | `None` |
| `weight` | `Optional[float]` | Optional weight contribution for the total loss. Defaults to `1`. | `None` |
| `on` | `Union[str, int, Sequence[Union[str, int]]]` | A string or integer, or an iterable of strings or integers, indicating how to index/filter the `target` and `preds` arguments before passing them to `call`. For example, if `on = "a"` then `target = target["a"]`. If `on` is an iterable, the structures are indexed iteratively: if `on = ["a", 0, "b"]` then `target = target["a"][0]["b"]`, and likewise for `preds`. For more information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior). | `None` |
| `check_bounds` | `bool` | If `True` (default), checks `target` for negative values and values larger than or equal to the number of channels in `preds`, and sets the loss to NaN if any are found. If `False`, the check is disabled and the loss may contain incorrect values. | `True` |
| `name` | `Optional[str]` | Optional name for this loss instance. | `None` |
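The indexing rule described by `on` can be sketched in a few lines of plain Python. This mirrors the documented behavior (index iteratively through the structure); it is not treex's actual implementation:

```python
def index_structure(structure, on):
    """Apply `on` iteratively: on=["a", 0, "b"] -> structure["a"][0]["b"]."""
    if isinstance(on, (str, int)):
        on = [on]  # a single key behaves like a one-element path
    for key in on:
        structure = structure[key]
    return structure

target = {"a": [{"b": 1.0}]}
assert index_structure(target, ["a", 0, "b"]) == 1.0
assert index_structure(target, "a") == [{"b": 1.0}]
```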
Source code in treex/losses/crossentropy.py
def __init__(
    self,
    *,
    from_logits: bool = True,
    binary: bool = False,
    label_smoothing: tp.Optional[float] = None,
    reduction: tp.Optional[Reduction] = None,
    check_bounds: bool = True,
    weight: tp.Optional[float] = None,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
):
    """
    Initializes a `Crossentropy` instance.

    Arguments:
        from_logits: Whether `preds` is expected to be a logits tensor. Defaults
            to `True`, i.e. `preds` is assumed to contain unnormalized logits;
            pass `from_logits=False` if it contains probabilities.
            **Note - Using from_logits=True is more numerically stable.**
        binary: If `True`, the loss is computed as binary crossentropy, treating
            each output as an independent binary prediction.
        label_smoothing: Float in `[0, 1]`. When greater than `0`, the target
            labels are smoothed towards the uniform distribution before the
            loss is computed.
        reduction: (Optional) Type of `tx.losses.Reduction` to apply to the
            loss. Defaults to `SUM_OVER_BATCH_SIZE`.
        weight: Optional weight contribution for the total loss. Defaults to `1`.
        on: A string or integer, or iterable of string or integers, that
            indicate how to index/filter the `target` and `preds`
            arguments before passing them to `call`. For example if `on = "a"` then
            `target = target["a"]`. If `on` is an iterable
            the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
            then `target = target["a"][0]["b"]`, same for `preds`. For more information
            check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
        check_bounds: If `True` (default), checks `target` for negative values and
            values larger than or equal to the number of channels in `preds`, and
            sets the loss to NaN if any are found. If `False`, the check is disabled
            and the loss may contain incorrect values.
        name: Optional name for this loss instance.
    """
    super().__init__(reduction=reduction, weight=weight, on=on, name=name)

    self._from_logits = from_logits
    self._check_bounds = check_bounds
    self._binary = binary
    self._label_smoothing = label_smoothing

call(self, target, preds, sample_weight=None)

Invokes the `Crossentropy` instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `target` | | Ground truth values. | required |
| `preds` | | The predicted values. | required |
| `sample_weight` | `Optional[jnp.ndarray]` | Acts as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If `sample_weight` is a tensor of size `[batch_size]`, the total loss for each sample of the batch is rescaled by the corresponding element in the `sample_weight` vector. If the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcast to this shape), each loss element of `preds` is scaled by the corresponding value of `sample_weight`. (Note on `dN-1`: all loss functions reduce by 1 dimension, usually `axis=-1`.) | `None` |

Returns:

| Type | Description |
| --- | --- |
| `jnp.ndarray` | Loss values per sample. |
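As a sanity check on the numbers used earlier: with the default `SUM_OVER_BATCH_SIZE` reduction, the weighted result is the sum of the weighted per-sample losses divided by the batch size. A small sketch in plain numpy, with values taken from the usage example above:

```python
import numpy as np

per_sample = np.array([0.0513, 2.303])  # per-sample losses ('none' reduction)
weights = np.array([0.3, 0.7])          # sample_weight

weighted = per_sample * weights              # [0.0154, 1.612]
result = weighted.sum() / len(per_sample)    # 1.627 / 2 ~= 0.814
assert np.isclose(result, 0.814, rtol=0.01)
```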

Source code in treex/losses/crossentropy.py
def call(
    self, target, preds, sample_weight: tp.Optional[jnp.ndarray] = None
) -> jnp.ndarray:
    """
    Invokes the `Crossentropy` instance.

    Arguments:
        target: Ground truth values.
        preds: The predicted values.
        sample_weight: Acts as a
            coefficient for the loss. If a scalar is provided, then the loss is
            simply scaled by the given value. If `sample_weight` is a tensor of size
            `[batch_size]`, then the total loss for each sample of the batch is
            rescaled by the corresponding element in the `sample_weight` vector. If
            the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
            broadcasted to this shape), then each loss element of `preds` is scaled
            by the corresponding value of `sample_weight`. (Note on `dN-1`: all loss
            functions reduce by 1 dimension, usually `axis=-1`.)

    Returns:
        Loss values per sample.
    """

    return crossentropy(
        target,
        preds,
        binary=self._binary,
        from_logits=self._from_logits,
        label_smoothing=self._label_smoothing,
        check_bounds=self._check_bounds,
    )