treex.metrics.MeanAbsoluteError
Source code in treex/metrics/mean_absolute_error.py
class MeanAbsoluteError(Mean):
    """Mean Absolute Error (MAE) metric built on top of the `Mean` metric."""

    def __init__(
        self,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        # NOTE: the docstring is a raw string so the LaTeX below keeps its
        # backslashes (`\t` in `\text` and `\f` in `\frac` would otherwise be
        # turned into TAB / form-feed characters).
        r"""
        Computes the Mean Absolute Error (MAE):

        .. math:: \text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} |

        Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

        Args:
            on:
                A string or integer, or iterable of string or integers, that
                indicate how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example if `on = "a"` then
                `target = target["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `target = target["a"][0]["b"]`, same for `preds`. For more information
                check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            name:
                Module name
            dtype:
                Metrics states initialization dtype

        Example:
            >>> import jax.numpy as jnp
            >>> from treex.metrics.mean_absolute_error import MeanAbsoluteError
            >>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
            >>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])
            >>> mae = MeanAbsoluteError()
            >>> mae(target=target, preds=preds)
        """
        super().__init__(on=on, name=name, dtype=dtype)

    def update(
        self,
        target: jnp.ndarray,
        preds: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> tp.Any:
        """
        Accumulates metric statistics. `target` and `preds` should have the same shape.

        Args:
            target:
                Ground truth values. shape = `[batch_size, d0, .. dN]`.
            preds:
                The predicted values. shape = `[batch_size, d0, .. dN]`
            sample_weight:
                Optional weighting of each example. Defaults to 1. shape = `[batch_size, d0, .. dN]`

        Returns:
            Array with the cumulative mean absolute error.
        """
        # The absolute-error values come from the module helper; they are then
        # passed, together with `sample_weight`, to `Mean.update` for accumulation.
        values = _mean_absolute_error(preds, target)
        return super().update(values, sample_weight)
__call__(self, target, preds, sample_weight=None)
special
Accumulates metric statistics. `target` and `preds` should have the same shape.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `target` | `ndarray` | Ground truth values. shape = `[batch_size, d0, .. dN]`. | *required* |
| `preds` | `ndarray` | The predicted values. shape = `[batch_size, d0, .. dN]` | *required* |
| `sample_weight` | `ndarray` | Optional weighting of each example. Defaults to 1. shape = `[batch_size, d0, .. dN]` | `None` |
Returns:

| Type | Description |
|---|---|
| `Any` | Array with the cumulative mean absolute error. |
Source code in treex/metrics/mean_absolute_error.py
def update(
    self,
    target: jnp.ndarray,
    preds: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
) -> tp.Any:
    """
    Accumulates metric statistics. `target` and `preds` should have the same shape.

    Args:
        target:
            Ground truth values. shape = `[batch_size, d0, .. dN]`.
        preds:
            The predicted values. shape = `[batch_size, d0, .. dN]`
        sample_weight:
            Optional weighting of each example. Defaults to 1. shape = `[batch_size, d0, .. dN]`

    Returns:
        Array with the cumulative mean absolute error.
    """
    # The absolute-error values come from the module helper; they are then
    # passed, together with `sample_weight`, to `Mean.update` for accumulation.
    values = _mean_absolute_error(preds, target)
    return super().update(values, sample_weight)
__init__(self, on=None, name=None, dtype=None)
special
Computes the Mean Absolute Error (MAE):

$$\text{MAE} = \frac{1}{N}\sum_{i=1}^{N} \lvert y_i - \hat{y}_i \rvert$$

where $y$ is a tensor of target values and $\hat{y}$ is a tensor of predictions.
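Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `on` | `Union[str, int, Sequence[Union[str, int]]]` | A string or integer, or an iterable of strings or integers, that indicates how to index/filter the `target` and `preds` arguments before passing them to `call`. For example, if `on = "a"` then `target = target["a"]`. If `on` is an iterable, the structures are indexed iteratively: for example, if `on = ["a", 0, "b"]` then `target = target["a"][0]["b"]`, and the same for `preds`. For more information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior). | `None` |
| `name` | `Optional[str]` | Module name | `None` |
| `dtype` | `Optional[numpy.dtype]` | Metrics states initialization dtype | `None` |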
Examples:

```python
>>> import jax.numpy as jnp
>>> from treex.metrics.mean_absolute_error import MeanAbsoluteError
>>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
>>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])
>>> mae = MeanAbsoluteError()
>>> mae(target=target, preds=preds)
```
Source code in treex/metrics/mean_absolute_error.py
def __init__(
    self,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
    dtype: tp.Optional[jnp.dtype] = None,
):
    # NOTE: the docstring is a raw string so the LaTeX below keeps its
    # backslashes (`\t` in `\text` and `\f` in `\frac` would otherwise be
    # turned into TAB / form-feed characters).
    r"""
    Computes the Mean Absolute Error (MAE):

    .. math:: \text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} |

    Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

    Args:
        on:
            A string or integer, or iterable of string or integers, that
            indicate how to index/filter the `target` and `preds`
            arguments before passing them to `call`. For example if `on = "a"` then
            `target = target["a"]`. If `on` is an iterable
            the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
            then `target = target["a"][0]["b"]`, same for `preds`. For more information
            check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
        name:
            Module name
        dtype:
            Metrics states initialization dtype

    Example:
        >>> import jax.numpy as jnp
        >>> from treex.metrics.mean_absolute_error import MeanAbsoluteError
        >>> target = jnp.array([3.0, -0.5, 2.0, 7.0])
        >>> preds = jnp.array([3.0, -0.5, 2.0, 7.0])
        >>> mae = MeanAbsoluteError()
        >>> mae(target=target, preds=preds)
    """
    super().__init__(on=on, name=name, dtype=dtype)
update(self, target, preds, sample_weight=None)
Accumulates metric statistics. `target` and `preds` should have the same shape.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `target` | `ndarray` | Ground truth values. shape = `[batch_size, d0, .. dN]`. | *required* |
| `preds` | `ndarray` | The predicted values. shape = `[batch_size, d0, .. dN]` | *required* |
| `sample_weight` | `ndarray` | Optional weighting of each example. Defaults to 1. shape = `[batch_size, d0, .. dN]` | `None` |
Returns:

| Type | Description |
|---|---|
| `Any` | Array with the cumulative mean absolute error. |
Source code in treex/metrics/mean_absolute_error.py
def update(
    self,
    target: jnp.ndarray,
    preds: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
) -> tp.Any:
    """
    Accumulates metric statistics. `target` and `preds` should have the same shape.

    Args:
        target:
            Ground truth values. shape = `[batch_size, d0, .. dN]`.
        preds:
            The predicted values. shape = `[batch_size, d0, .. dN]`
        sample_weight:
            Optional weighting of each example. Defaults to 1. shape = `[batch_size, d0, .. dN]`

    Returns:
        Array with the cumulative mean absolute error.
    """
    # The absolute-error values come from the module helper; they are then
    # passed, together with `sample_weight`, to `Mean.update` for accumulation.
    values = _mean_absolute_error(preds, target)
    return super().update(values, sample_weight)