Skip to content

treex.metrics.MeanSquareError

Source code in treex/metrics/mean_square_error.py
class MeanSquareError(Mean):
    def __init__(
        self,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        # NOTE: raw docstring so the LaTeX backslashes survive; in a normal
        # string `\t` (in \text) becomes a TAB and `\f` (in \frac) a form feed.
        r"""
        Computes Mean Square Error (MSE):

        .. math:: \text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y}_i)^2

        Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

        Args:
            on:
                A string or integer, or iterable of string or integers, that
                indicate how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example if `on = "a"` then
                `target = target["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `target = target["a"][0]["b"]`, same for `preds`. For more information
                check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            name:
                Module name
            dtype:
                Metrics states initialization dtype


        Example:
        >>> import jax.numpy as jnp
        >>> from treex.metrics.mean_square_error import MeanSquareError

        >>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
        >>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])

        >>> mse = MeanSquareError()
        >>> result = mse(target, preds)

        """
        super().__init__(on=on, name=name, dtype=dtype)

    def update(
        self,
        target: jnp.ndarray,
        preds: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Any:
        """
        Accumulates metric statistics. `target` and `preds` should have the same shape.

        Arguments:
            target:
                Ground truth values. shape = `[batch_size, d0, .. dN]`.
            preds:
                The predicted values. shape = `[batch_size, d0, .. dN]`
            sample_weight:
                Optional weighting of each example. Defaults to 1. shape = `[batch_size, d0, .. dN]`
        Returns:
            Array with the cumulative mean squared error.
        """
        # Squared-error values come from the module-level helper; accumulation
        # (and optional weighting by `sample_weight`) is delegated to `Mean.update`.
        values = _mean_square_error(preds, target)
        return super().update(values, sample_weight)

__call__(self, target, preds, sample_weight=None) special

Accumulates metric statistics. target and preds should have the same shape.

Parameters:

Name Type Description Default
target ndarray

Ground truth values. shape = [batch_size, d0, .. dN].

required
preds ndarray

The predicted values. shape = [batch_size, d0, .. dN]

required
sample_weight ndarray

Optional weighting of each example. Defaults to 1. shape = [batch_size, d0, .. dN]

None

Returns:

Type Description
Any

Array with the cumulative mean squared error.

Source code in treex/metrics/mean_square_error.py
def update(
    self,
    target: jnp.ndarray,
    preds: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
) -> tp.Any:
    """
    Accumulates metric statistics. `target` and `preds` should have the same shape.

    Arguments:
        target:
            Ground truth values. shape = `[batch_size, d0, .. dN]`.
        preds:
            The predicted values. shape = `[batch_size, d0, .. dN]`
        sample_weight:
            Optional weighting of each example. Defaults to 1. shape = `[batch_size, d0, .. dN]`
    Returns:
        Array with the cumulative mean squared error.
    """
    # Squared-error values come from the module-level helper; accumulation
    # (and optional weighting by `sample_weight`) is delegated to `Mean.update`.
    values = _mean_square_error(preds, target)
    return super().update(values, sample_weight)

__init__(self, on=None, name=None, dtype=None) special

Computes Mean Square Error (MSE): $\text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y}_i)^2$, where $y$ is a tensor of target values and $\hat{y}$ is a tensor of predictions.

Parameters:

Name Type Description Default
on Union[str, int, Sequence[Union[str, int]]]

A string or integer, or iterable of string or integers, that indicate how to index/filter the target and preds arguments before passing them to call. For example if on = "a" then target = target["a"]. If on is an iterable the structures will be indexed iteratively, for example if on = ["a", 0, "b"] then target = target["a"][0]["b"], same for preds. For more information check out Keras-like behavior.

None
name Optional[str]

Module name

None
dtype Optional[numpy.dtype]

Metrics states initialization dtype

None

Examples:

>>> import jax.numpy as jnp
>>> from treex.metrics.mean_square_error import MeanSquareError

>>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
>>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])

>>> mse = MeanSquareError()
>>> result = mse(target, preds)

Source code in treex/metrics/mean_square_error.py
def __init__(
    self,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
    dtype: tp.Optional[jnp.dtype] = None,
):
    # NOTE: raw docstring so the LaTeX backslashes survive; in a normal
    # string `\t` (in \text) becomes a TAB and `\f` (in \frac) a form feed.
    r"""
    Computes Mean Square Error (MSE):

    .. math:: \text{MSE} = \frac{1}{N}\sum_i^N(y_i - \hat{y}_i)^2

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    Args:
        on:
            A string or integer, or iterable of string or integers, that
            indicate how to index/filter the `target` and `preds`
            arguments before passing them to `call`. For example if `on = "a"` then
            `target = target["a"]`. If `on` is an iterable
            the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
            then `target = target["a"][0]["b"]`, same for `preds`. For more information
            check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
        name:
            Module name
        dtype:
            Metrics states initialization dtype


    Example:
    >>> import jax.numpy as jnp
    >>> from treex.metrics.mean_square_error import MeanSquareError

    >>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
    >>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])

    >>> mse = MeanSquareError()
    >>> result = mse(target, preds)

    """
    super().__init__(on=on, name=name, dtype=dtype)

update(self, target, preds, sample_weight=None)

Accumulates metric statistics. target and preds should have the same shape.

Parameters:

Name Type Description Default
target ndarray

Ground truth values. shape = [batch_size, d0, .. dN].

required
preds ndarray

The predicted values. shape = [batch_size, d0, .. dN]

required
sample_weight ndarray

Optional weighting of each example. Defaults to 1. shape = [batch_size, d0, .. dN]

None

Returns:

Type Description
Any

Array with the cumulative mean squared error.

Source code in treex/metrics/mean_square_error.py
def update(
    self,
    target: jnp.ndarray,
    preds: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
) -> tp.Any:
    """
    Accumulates metric statistics. `target` and `preds` should have the same shape.

    Arguments:
        target:
            Ground truth values. shape = `[batch_size, d0, .. dN]`.
        preds:
            The predicted values. shape = `[batch_size, d0, .. dN]`
        sample_weight:
            Optional weighting of each example. Defaults to 1. shape = `[batch_size, d0, .. dN]`
    Returns:
        Array with the cumulative mean squared error.
    """
    # Squared-error values come from the module-level helper; accumulation
    # (and optional weighting by `sample_weight`) is delegated to `Mean.update`.
    values = _mean_square_error(preds, target)
    return super().update(values, sample_weight)