# treex.metrics.Mean

Computes the (weighted) mean of the given values.

For example, if values is [1, 3, 5, 7] then the mean is 4. If the weights were specified as [1, 1, 0, 0] then the mean would be 2. This metric creates two variables, total and count that are used to compute the average of values. This average is ultimately returned as mean which is an idempotent operation that simply divides total by count. If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values.

Usage:

mean = tx.metrics.Mean()
result = mean([1, 3, 5, 7])  # 16 / 4
assert result == 4.0

result = mean([4, 10])  # 30 / 6
assert result == 5.0


Usage with elegy API:

model = elegy.Model(
    module_fn,
    loss=tx.losses.MeanSquaredError(),
    metrics=tx.metrics.Mean(),
)

Source code in treex/metrics/mean.py
class Mean(Reduce):
    """
    Computes the (weighted) mean of the given values.

    For example, if values is `[1, 3, 5, 7]` then the mean is `4`.
    If the weights were specified as `[1, 1, 0, 0]` then the mean would be `2`.
    This metric creates two variables, `total` and `count`, that are used to
    compute the average of `values`. This average is ultimately returned as
    `mean`, which is an idempotent operation that simply divides `total` by
    `count`. If `sample_weight` is `None`, weights default to 1.
    Use a `sample_weight` of 0 to mask values.

    The class itself only configures its parent `Reduce` with
    `Reduction.weighted_mean`; the accumulation logic lives in the parent.

    Usage:

    ```python
    mean = tx.metrics.Mean()
    result = mean([1, 3, 5, 7])  # 16 / 4
    assert result == 4.0

    result = mean([4, 10])  # 30 / 6
    assert result == 5.0
    ```

    Usage with elegy API:

    ```python
    model = elegy.Model(
        module_fn,
        loss=tx.losses.MeanSquaredError(),
        metrics=tx.metrics.Mean(),
    )
    ```
    """

    def __init__(
        self,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        """Creates a `Mean` instance.

        Arguments:
            on: A string or integer, or iterable of strings or integers, that
                indicates how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example if
                `on = "a"` then `target = target["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if
                `on = ["a", 0, "b"]` then `target = target["a"][0]["b"]`, same
                for `preds`. For more information check out
                [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            name: Optional name, forwarded unchanged to the parent constructor.
            dtype: Optional dtype, forwarded unchanged to the parent
                constructor.
        """
        # The reduction mode is fixed here; callers cannot override it.
        super().__init__(
            reduction=Reduction.weighted_mean,
            on=on,
            name=name,
            dtype=dtype,
        )

    def update(
        self,
        values: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ):
        """
        Accumulates the mean statistic over various batches.

        Forwards both arguments, by keyword, to the parent class's `update`.

        Arguments:
            values: Per-example value.
            sample_weight: Optional weighting of each example.

        Returns:
            None. This override does not return the result of the parent call.
            NOTE(review): earlier documentation promised the cumulative mean
            here; confirm whether `return super().update(...)` was intended.
        """

        super().update(
            values=values,
            sample_weight=sample_weight,
        )


## __call__(self, values, sample_weight=None) special

Accumulates the mean statistic over various batches.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| values | ndarray | Per-example value. | required |
| sample_weight | Optional[jax._src.numpy.lax_numpy.ndarray] | Optional weighting of each example. | None |

Returns:

Type Description

Array with the cumulative mean.

Source code in treex/metrics/mean.py
def update(
    self,
    values: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
):
    """
    Accumulates the mean statistic over various batches.

    Forwards both arguments, by keyword, to the parent class's `update`.

    Arguments:
        values: Per-example value.
        sample_weight: Optional weighting of each example.

    Returns:
        None. This override does not return the result of the parent call.
        NOTE(review): earlier documentation promised the cumulative mean
        here; confirm whether `return super().update(...)` was intended.
    """

    super().update(
        values=values,
        sample_weight=sample_weight,
    )


## __init__(self, on=None, name=None, dtype=None) special

Creates a Mean instance.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| on | Union[str, int, Sequence[Union[str, int]]] | A string or integer, or iterable of strings or integers, that indicates how to index/filter the target and preds arguments before passing them to call. For example if on = "a" then target = target["a"]. If on is an iterable the structures will be indexed iteratively, for example if on = ["a", 0, "b"] then target = target["a"][0]["b"], same for preds. For more information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior). | None |
| kwargs | | Additional keyword arguments passed to Module. (Note: no `kwargs` parameter appears in the signature above, which takes `on`, `name` and `dtype`.) | required |
Source code in treex/metrics/mean.py
def __init__(
    self,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
    dtype: tp.Optional[jnp.dtype] = None,
):
    """Creates a `Mean` instance.

    Arguments:
        on: A string or integer, or iterable of strings or integers, that
            indicates how to index/filter the `target` and `preds`
            arguments before passing them to `call`. For example if
            `on = "a"` then `target = target["a"]`. If `on` is an iterable
            the structures will be indexed iteratively, for example if
            `on = ["a", 0, "b"]` then `target = target["a"][0]["b"]`, same
            for `preds`. For more information check out
            [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
        name: Optional name, forwarded unchanged to the parent constructor.
        dtype: Optional dtype, forwarded unchanged to the parent constructor.
    """
    # The reduction mode is fixed here; callers cannot override it.
    super().__init__(
        reduction=Reduction.weighted_mean,
        on=on,
        name=name,
        dtype=dtype,
    )


## update(self, values, sample_weight=None)

Accumulates the mean statistic over various batches.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| values | ndarray | Per-example value. | required |
| sample_weight | Optional[jax._src.numpy.lax_numpy.ndarray] | Optional weighting of each example. | None |

Returns:

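Type Description

Array with the cumulative mean. (Note: the source listing below contains no `return` statement, so `update` itself returns `None`.)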

Source code in treex/metrics/mean.py
def update(
    self,
    values: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
):
    """
    Accumulates the mean statistic over various batches.

    Forwards both arguments, by keyword, to the parent class's `update`.

    Arguments:
        values: Per-example value.
        sample_weight: Optional weighting of each example.

    Returns:
        None. This override does not return the result of the parent call.
        NOTE(review): earlier documentation promised the cumulative mean
        here; confirm whether `return super().update(...)` was intended.
    """

    super().update(
        values=values,
        sample_weight=sample_weight,
    )