# treex.metrics.MSE

Source code in treex/metrics/mean_square_error.py
class MeanSquareError(Mean):
    """Metric that feeds per-example squared errors into the `Mean` base class."""

    def __init__(
        self,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        r"""
        Computes Mean Square Error (MSE):

        .. math:: \text{MSE} = \frac{1}{N}\sum_{i=1}^{N}(y_i - \hat{y}_i)^2

        where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
        tensor of predictions.

        Args:
            on:
                A string or integer, or iterable of string or integers, that
                indicate how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example if `on = "a"` then
                `target = target["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `target = target["a"][0]["b"]`, same for `preds`. For more information
                check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            name:
                Module name
            dtype:
                Metrics states initialization dtype

        Example:
            >>> import jax.numpy as jnp
            >>> from treex.metrics.mean_square_error import MeanSquareError

            >>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
            >>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])

            >>> mse = MeanSquareError()
            >>> result = mse(target, preds)  # identical inputs, so the MSE is 0

        """
        # Raw docstring (r"""...""") above: without it "\t" in \text and "\f" in
        # \frac are interpreted as TAB / FORM FEED and the formula is corrupted.
        # All options are forwarded unchanged to the `Mean` base initializer.
        super().__init__(on=on, name=name, dtype=dtype)

    def update(
        self,
        target: jnp.ndarray,
        preds: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Any:
        """
        Accumulates metric statistics. `target` and `preds` should have the same shape.

        Args:
            target:
                Ground truth values. shape = [batch_size, d0, .. dN].
            preds:
                The predicted values. shape = [batch_size, d0, .. dN].
            sample_weight:
                Optional weighting of each example. Defaults to 1. shape = [batch_size, d0, .. dN]

        Returns:
            Array with the cumulative mean square error.
        """
        # NOTE(review): the helper is called as (preds, target) while this method
        # takes (target, preds); harmless if it computes a symmetric squared
        # difference — confirm against `_mean_square_error`'s signature.
        values = _mean_square_error(preds, target)
        # `Mean.update` performs the (optionally weighted) accumulation.
        return super().update(values, sample_weight)


## __call__(self, target, preds, sample_weight=None) special

Accumulates metric statistics. target and preds should have the same shape.

Parameters:

Name Type Description Default
target ndarray

Ground truth values. shape = [batch_size, d0, .. dN].

required
preds ndarray

The predicted values. shape = [batch_size, d0, .. dN]

required
sample_weight ndarray

Optional weighting of each example. Defaults to 1. shape = [batch_size, d0, .. dN]

None

Returns:

Type Description
Any

Array with the cumulative mean square error.

Source code in treex/metrics/mean_square_error.py
def update(
    self,
    target: jnp.ndarray,
    preds: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
) -> tp.Any:
    """
    Accumulates metric statistics. `target` and `preds` should have the same shape.

    Args:
        target:
            Ground truth values. shape = [batch_size, d0, .. dN].
        preds:
            The predicted values. shape = [batch_size, d0, .. dN].
        sample_weight:
            Optional weighting of each example. Defaults to 1. shape = [batch_size, d0, .. dN]

    Returns:
        Array with the cumulative mean square error.
    """
    # NOTE(review): the helper is called as (preds, target) while this method
    # takes (target, preds); harmless if it computes a symmetric squared
    # difference — confirm against `_mean_square_error`'s signature.
    values = _mean_square_error(preds, target)
    # The parent class's `update` performs the (optionally weighted) accumulation.
    return super().update(values, sample_weight)


## __init__(self, on=None, name=None, dtype=None) special

Computes Mean Square Error (MSE):

$$
\text{MSE} = \frac{1}{N}\sum_{i=1}^{N}(y_i - \hat{y}_i)^2
$$

where $y$ is a tensor of target values and $\hat{y}$ is a tensor of predictions.

Parameters:

Name Type Description Default
on Union[str, int, Sequence[Union[str, int]]]

A string or integer, or iterable of string or integers, that indicate how to index/filter the target and preds arguments before passing them to call. For example if on = "a" then target = target["a"]. If on is an iterable the structures will be indexed iteratively, for example if on = ["a", 0, "b"] then target = target["a"][0]["b"], same for preds. For more information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).

None
name Optional[str]

Module name

None
dtype Optional[numpy.dtype]

Metrics states initialization dtype

None

Examples:

```python
>>> import jax.numpy as jnp
>>> from treex.metrics.mean_square_error import MeanSquareError

>>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
>>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])

>>> mse = MeanSquareError()
>>> result = mse(target, preds)  # identical inputs, so the MSE is 0
```

Source code in treex/metrics/mean_square_error.py
def __init__(
    self,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
    dtype: tp.Optional[jnp.dtype] = None,
):
    r"""
    Computes Mean Square Error (MSE):

    .. math:: \text{MSE} = \frac{1}{N}\sum_{i=1}^{N}(y_i - \hat{y}_i)^2

    where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
    tensor of predictions.

    Args:
        on:
            A string or integer, or iterable of string or integers, that
            indicate how to index/filter the `target` and `preds`
            arguments before passing them to `call`. For example if `on = "a"` then
            `target = target["a"]`. If `on` is an iterable
            the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
            then `target = target["a"][0]["b"]`, same for `preds`. For more information
            check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
        name:
            Module name
        dtype:
            Metrics states initialization dtype

    Example:
        >>> import jax.numpy as jnp
        >>> from treex.metrics.mean_square_error import MeanSquareError

        >>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
        >>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])

        >>> mse = MeanSquareError()
        >>> result = mse(target, preds)  # identical inputs, so the MSE is 0

    """
    # Raw docstring (r"""...""") above: without it "\t" in \text and "\f" in
    # \frac are interpreted as TAB / FORM FEED and the formula is corrupted.
    # All options are forwarded unchanged to the parent initializer.
    super().__init__(on=on, name=name, dtype=dtype)


## update(self, target, preds, sample_weight=None)

Accumulates metric statistics. target and preds should have the same shape.

Parameters:

Name Type Description Default
target ndarray

Ground truth values. shape = [batch_size, d0, .. dN].

required
preds ndarray

The predicted values. shape = [batch_size, d0, .. dN]

required
sample_weight ndarray

Optional weighting of each example. Defaults to 1. shape = [batch_size, d0, .. dN]

None

Returns:

Type Description
Any

Array with the cumulative mean square error.

Source code in treex/metrics/mean_square_error.py
def update(
    self,
    target: jnp.ndarray,
    preds: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
) -> tp.Any:
    """
    Accumulates metric statistics. `target` and `preds` should have the same shape.

    Args:
        target:
            Ground truth values. shape = [batch_size, d0, .. dN].
        preds:
            The predicted values. shape = [batch_size, d0, .. dN].
        sample_weight:
            Optional weighting of each example. Defaults to 1. shape = [batch_size, d0, .. dN]

    Returns:
        Array with the cumulative mean square error.
    """
    # NOTE(review): the helper is called as (preds, target) while this method
    # takes (target, preds); harmless if it computes a symmetric squared
    # difference — confirm against `_mean_square_error`'s signature.
    values = _mean_square_error(preds, target)
    # The parent class's `update` performs the (optionally weighted) accumulation.
    return super().update(values, sample_weight)