treex.metrics.Mean
Computes the (weighted) mean of the given values.
For example, if `values` is `[1, 3, 5, 7]`, the mean is `4`.
If the weights are specified as `[1, 1, 0, 0]`, the mean is `2`.
This metric creates two variables, `total` and `count`, which are used to
compute the average of `values`. The average is ultimately returned as `mean`,
an idempotent operation that simply divides `total` by `count` without modifying either.
If `sample_weight` is `None`, weights default to 1.
Use a `sample_weight` of 0 to mask values.
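As a rough illustration of the arithmetic described above, the following pure-Python sketch computes a weighted mean for a single batch as `sum(w * v) / sum(w)`, with weights defaulting to 1. The function name `weighted_mean` is hypothetical and not part of treex; it uses plain lists instead of JAX arrays so the numbers are easy to follow.

```python
def weighted_mean(values, sample_weight=None):
    """Hypothetical helper: weighted mean of one batch, sum(w * v) / sum(w)."""
    if sample_weight is None:
        # No weights given: every value counts once.
        sample_weight = [1.0] * len(values)
    if len(sample_weight) != len(values):
        raise ValueError("values and sample_weight must have the same length")
    total = sum(w * v for v, w in zip(values, sample_weight))  # weighted sum
    count = sum(sample_weight)  # total weight, not the number of elements
    return total / count


plain = weighted_mean([1, 3, 5, 7])                 # 16 / 4
masked = weighted_mean([1, 3, 5, 7], [1, 1, 0, 0])  # (1 + 3) / 2, last two masked
```

This sketch raises `ZeroDivisionError` when every weight is 0; it does not show how treex handles that case.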
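Usage:
```python
import jax.numpy as jnp
import treex as tx

mean = tx.metrics.Mean()
result = mean(jnp.array([1, 3, 5, 7]))  # 16 / 4
assert result == 4.0
result = mean(jnp.array([4, 10]))  # (16 + 14) / (4 + 2) = 30 / 6, accumulated across calls
assert result == 5.0
```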
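The usage example relies on the metric keeping `total` and `count` between calls, which is why the second call returns 30 / 6 = 5 rather than the per-batch 14 / 2 = 7. A minimal stateful sketch of that behavior, assuming `__call__` simply runs `update` and then returns the running mean, might look like this. `RunningMean`, `reset`, and `compute` are hypothetical names used only for illustration, not treex's implementation.

```python
class RunningMean:
    """Hypothetical stand-in for a stateful weighted-mean metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Running sums shared across all batches seen so far.
        self.total = 0.0
        self.count = 0.0

    def update(self, values, sample_weight=None):
        if sample_weight is None:
            sample_weight = [1.0] * len(values)
        self.total += sum(w * v for v, w in zip(values, sample_weight))
        self.count += sum(sample_weight)

    def compute(self):
        # Idempotent: reads the state without changing it.
        return self.total / self.count

    def __call__(self, values, sample_weight=None):
        # Accumulate this batch, then report the cumulative mean.
        self.update(values, sample_weight)
        return self.compute()


mean = RunningMean()
first = mean([1, 3, 5, 7])  # total 16, count 4
second = mean([4, 10])      # total 30, count 6
```

Keeping two running sums rather than a running average is what makes the accumulation exact: combining batches only requires adding their totals and counts.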
Usage with elegy API:
```python
import elegy
import treex as tx

model = elegy.Model(
    module_fn,  # placeholder for your model definition
    loss=tx.losses.MeanSquaredError(),
    metrics=tx.metrics.Mean(),
)
```
Source code in treex/metrics/mean.py
````python
# Module-level imports assumed from the top of treex/metrics/mean.py
# (not shown in the original excerpt).
import typing as tp

import jax.numpy as jnp

from treex import types
from treex.metrics.reduce import Reduce, Reduction


class Mean(Reduce):
    """
    Computes the (weighted) mean of the given values.

    For example, if `values` is `[1, 3, 5, 7]`, the mean is `4`.
    If the weights are specified as `[1, 1, 0, 0]`, the mean is `2`.

    This metric creates two variables, `total` and `count`, which are used to
    compute the average of `values`. The average is ultimately returned as `mean`,
    an idempotent operation that simply divides `total` by `count`.

    If `sample_weight` is `None`, weights default to 1.
    Use a `sample_weight` of 0 to mask values.

    Usage:

    ```python
    mean = tx.metrics.Mean()
    result = mean(jnp.array([1, 3, 5, 7]))  # 16 / 4
    assert result == 4.0
    result = mean(jnp.array([4, 10]))  # 30 / 6
    assert result == 5.0
    ```

    Usage with elegy API:

    ```python
    model = elegy.Model(
        module_fn,
        loss=tx.losses.MeanSquaredError(),
        metrics=tx.metrics.Mean(),
    )
    ```
    """

    def __init__(
        self,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        """Creates a `Mean` instance.

        Arguments:
            on: A string or integer, or an iterable of strings or integers, that
                indicates how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example, if `on = "a"`,
                then `target = target["a"]`. If `on` is an iterable,
                the structures are indexed iteratively: if `on = ["a", 0, "b"]`,
                then `target = target["a"][0]["b"]`, and likewise for `preds`. For more
                information, check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            name: Optional name of the metric.
            dtype: Optional dtype of the metric's accumulated state.
        """
        super().__init__(
            reduction=Reduction.weighted_mean,
            on=on,
            name=name,
            dtype=dtype,
        )

    def update(
        self,
        values: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ):
        """
        Accumulates the mean statistic over various batches.

        Arguments:
            values: Per-example value.
            sample_weight: Optional weighting of each example.
        """
        super().update(
            values=values,
            sample_weight=sample_weight,
        )
````
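`Mean` adds no logic of its own: it passes `reduction=Reduction.weighted_mean` to `Reduce` and forwards `update` unchanged. The sketch below models that split, assuming `Reduce` follows the Keras `Reduce` metric semantics (add the weighted values to `total`, then grow `count` by either the batch size or the summed weights, depending on the reduction). `SimpleReduce`, `SimpleMean`, and this `Reduction` enum are hypothetical stand-ins, not treex's classes.

```python
from enum import Enum


class Reduction(Enum):
    # Hypothetical enum modeled on Keras-style reductions.
    sum = "sum"
    sum_over_batch_size = "sum_over_batch_size"
    weighted_mean = "weighted_mean"


class SimpleReduce:
    """Hypothetical base class: accumulates total/count according to a reduction."""

    def __init__(self, reduction):
        self.reduction = reduction
        self.total = 0.0
        self.count = 0.0

    def update(self, values, sample_weight=None):
        if sample_weight is not None:
            # Weight each value before summing.
            values = [v * w for v, w in zip(values, sample_weight)]
        self.total += sum(values)
        if self.reduction is Reduction.sum_over_batch_size:
            self.count += len(values)  # number of elements, ignoring weights
        elif self.reduction is Reduction.weighted_mean:
            # Summed weights, or the element count when no weights are given.
            self.count += sum(sample_weight) if sample_weight is not None else len(values)

    def compute(self):
        if self.reduction is Reduction.sum:
            return self.total
        return self.total / self.count


class SimpleMean(SimpleReduce):
    """Like `Mean`, this subclass only fixes the reduction."""

    def __init__(self):
        super().__init__(reduction=Reduction.weighted_mean)
```

Putting the accumulation logic in the base class lets sums, batch-size averages, and weighted means share one `update` path.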
__call__(self, values, sample_weight=None)
Accumulates the mean statistic over various batches.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `values` | `jnp.ndarray` | Per-example value. | required |
| `sample_weight` | `Optional[jnp.ndarray]` | Optional weighting of each example. | `None` |

Returns:

| Type | Description |
|---|---|
| `jnp.ndarray` | Array with the cumulative mean. |
__init__(self, on=None, name=None, dtype=None)
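Creates a `Mean` instance.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `on` | `Union[str, int, Sequence[Union[str, int]]]` | A string or integer, or an iterable of strings or integers, that indicates how to index/filter the `target` and `preds` arguments before passing them to `call`. For example, if `on = "a"`, then `target = target["a"]`. If `on` is an iterable, the structures are indexed iteratively: if `on = ["a", 0, "b"]`, then `target = target["a"][0]["b"]`, and likewise for `preds`. See [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior). | `None` |
| `name` | `Optional[str]` | Optional name of the metric. | `None` |
| `dtype` | `Optional[jnp.dtype]` | Optional dtype of the metric's accumulated state. | `None` |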
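The iterative indexing that `on` performs can be sketched with `functools.reduce` and `operator.getitem`. `apply_on` is a hypothetical helper for illustration only; treex applies `on` internally and may handle edge cases differently.

```python
import operator
from functools import reduce
from typing import Any, Iterable, Optional, Union

# Hypothetical alias mirroring the documented type of `on`.
IndexLike = Union[str, int, Iterable[Union[str, int]]]


def apply_on(structure: Any, on: Optional[IndexLike]) -> Any:
    """Index `structure` by `on`, one key at a time (hypothetical helper)."""
    if on is None:
        # No filtering: use the structure as-is.
        return structure
    if isinstance(on, (str, int)):
        # A single key: `on = "a"` means `structure["a"]`.
        return structure[on]
    # An iterable: `on = ["a", 0, "b"]` means `structure["a"][0]["b"]`.
    return reduce(operator.getitem, on, structure)


target = {"a": [{"b": 42}, {"b": 7}]}
nested = apply_on(target, ["a", 0, "b"])  # target["a"][0]["b"]
```

The string check comes before the iterable case because a Python `str` is itself iterable and would otherwise be split into characters.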
update(self, values, sample_weight=None)
Accumulates the mean statistic over various batches.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `values` | `jnp.ndarray` | Per-example value. | required |
| `sample_weight` | `Optional[jnp.ndarray]` | Optional weighting of each example. | `None` |

Returns:

| Type | Description |
|---|---|
| `None` | `update` has no `return` statement; it only accumulates `total` and `count`. The cumulative mean is returned by `__call__`. |