treex.metrics.Mean

Computes the (weighted) mean of the given values.

For example, if values is [1, 3, 5, 7], the mean is 4. If the weights are [1, 1, 0, 0], the mean is 2, because only the first two values contribute: (1 + 3) / 2. The metric creates two variables, total and count, which accumulate the weighted sum of values and the sum of the weights. The result is returned as mean, an idempotent operation that simply divides total by count. If sample_weight is None, every weight defaults to 1. Use a sample_weight of 0 to mask values.
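
The arithmetic above can be checked with a small pure-Python sketch. It is not the library's implementation: the helper name `weighted_sums` is hypothetical, and it assumes the weighted reduction is the sum of `value * weight` divided by the sum of the weights, which is what the numbers in the example imply.

```python
from typing import Optional, Sequence, Tuple


def weighted_sums(
    values: Sequence[float],
    sample_weight: Optional[Sequence[float]] = None,
) -> Tuple[float, float]:
    """Return (total, count) for one batch (hypothetical helper)."""
    if sample_weight is None:
        # Weights default to 1, so count is just the number of values.
        sample_weight = [1.0] * len(values)
    if len(sample_weight) != len(values):
        raise ValueError("values and sample_weight must have the same length")
    total = sum(v * w for v, w in zip(values, sample_weight))
    count = sum(sample_weight)
    return float(total), float(count)


total, count = weighted_sums([1, 3, 5, 7])
print(total / count)  # 16 / 4 -> 4.0

# A weight of 0 masks the last two values.
total, count = weighted_sums([1, 3, 5, 7], [1, 1, 0, 0])
print(total / count)  # 4 / 2 -> 2.0
```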

Usage:

mean = elegy.metrics.Mean()
result = mean([1, 3, 5, 7])  # 16 / 4
assert result == 4.0


result = mean([4, 10])  # 30 / 6
assert result == 5.0
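
The second result is 5.0 rather than 7.0 because the metric keeps its state between calls. The sketch below reproduces those numbers in pure Python under that assumption. `RunningMean` is a hypothetical stand-in written for illustration, not the `Mean` class itself, and it works on plain lists rather than JAX arrays.

```python
class RunningMean:
    """Hypothetical stand-in that accumulates total and count across calls."""

    def __init__(self):
        self.total = 0.0
        self.count = 0.0

    def update(self, values, sample_weight=None):
        # Fold one batch into the running state.
        if sample_weight is None:
            sample_weight = [1.0] * len(values)
        self.total += sum(v * w for v, w in zip(values, sample_weight))
        self.count += sum(sample_weight)

    def compute(self):
        # Idempotent: reads the state without changing it.
        # Raises ZeroDivisionError if nothing has been accumulated yet.
        return self.total / self.count

    def __call__(self, values, sample_weight=None):
        self.update(values, sample_weight)
        return self.compute()


mean = RunningMean()
first = mean([1, 3, 5, 7])  # total=16, count=4 -> 4.0
second = mean([4, 10])      # total=30, count=6 -> 5.0
print(first, second)
```

Calling `compute()` again after the second batch returns the same value, since only `update` changes `total` and `count`.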

Usage with elegy API:

model = elegy.Model(
    module_fn,
    loss=tx.losses.MeanSquaredError(),
    metrics=elegy.metrics.Mean(),
)
Source code in treex/metrics/mean.py
class Mean(Reduce):
    """
    Computes the (weighted) mean of the given values.

    For example, if values is `[1, 3, 5, 7]` then the mean is `4`.
    If the weights were specified as `[1, 1, 0, 0]` then the mean would be `2`.
    This metric creates two variables, `total` and `count` that are used to
    compute the average of `values`. This average is ultimately returned as `mean`
    which is an idempotent operation that simply divides `total` by `count`.
    If `sample_weight` is `None`, weights default to 1.
    Use `sample_weight` of 0 to mask values.

    Usage:

    ```python
    mean = elegy.metrics.Mean()
    result = mean([1, 3, 5, 7])  # 16 / 4
    assert result == 4.0


    result = mean([4, 10])  # 30 / 6
    assert result == 5.0
    ```

    Usage with elegy API:

    ```python
    model = elegy.Model(
        module_fn,
        loss=tx.losses.MeanSquaredError(),
        metrics=elegy.metrics.Mean(),
    )
    ```
    """

    def __init__(
        self,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        """Creates a `Mean` instance.
        Arguments:
            on: A string or integer, or iterable of string or integers, that
                indicate how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example if `on = "a"` then
                `target = target["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `target = target["a"][0]["b"]`, same for `preds`. For more information
                check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            kwargs: Additional keyword arguments passed to Module.
        """
        super().__init__(
            reduction=Reduction.weighted_mean,
            on=on,
            name=name,
            dtype=dtype,
        )

    def update(
        self,
        values: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ):
        """
        Accumulates the mean statistic over various batches.

        Arguments:
            values: Per-example value.
            sample_weight: Optional weighting of each example.

        Returns:
            Array with the cumulative mean.
        """

        super().update(
            values=values,
            sample_weight=sample_weight,
        )

__call__(self, values, sample_weight=None) special

Accumulates the mean statistic over various batches.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| values | ndarray | Per-example value. | required |
| sample_weight | Optional[jax._src.numpy.lax_numpy.ndarray] | Optional weighting of each example. | None |

Returns:

| Type | Description |
|---|---|
| | Array with the cumulative mean. |

__init__(self, on=None, name=None, dtype=None) special

Creates a Mean instance.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| on | Union[str, int, Sequence[Union[str, int]]] | A string or integer, or an iterable of strings or integers, indicating how to index/filter the target and preds arguments before passing them to call. For example, if on = "a" then target = target["a"]. If on is an iterable, the structures are indexed iteratively: if on = ["a", 0, "b"] then target = target["a"][0]["b"], and the same for preds. For more information, check out Keras-like behavior. | None |
| kwargs | | Additional keyword arguments passed to Module. | required |
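
The indexing rule described for on can be illustrated with a short sketch. The helper `index_structure` is hypothetical and only mirrors the behavior stated in the description (a single key, or a sequence of keys applied one after another); how the library performs this step internally is not shown here.

```python
from typing import Any, Sequence, Union

IndexLike = Union[str, int, Sequence[Union[str, int]]]


def index_structure(structure: Any, on: Union[IndexLike, None] = None) -> Any:
    """Apply `on` to a nested structure (hypothetical helper)."""
    if on is None:
        return structure  # nothing to filter
    if isinstance(on, (str, int)):
        on = [on]  # treat a single key as a one-element path
    for key in on:
        structure = structure[key]
    return structure


target = {"a": [{"b": 1.5}, {"b": 2.5}]}

print(index_structure(target, "a"))            # target["a"]
print(index_structure(target, ["a", 0, "b"]))  # target["a"][0]["b"] -> 1.5
```

The same path would be applied to preds, so both arguments end up pointing at matching sub-structures.
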

update(self, values, sample_weight=None)

Accumulates the mean statistic over various batches.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| values | ndarray | Per-example value. | required |
| sample_weight | Optional[jax._src.numpy.lax_numpy.ndarray] | Optional weighting of each example. | None |

Returns:

| Type | Description |
|---|---|
| | Array with the cumulative mean. |
