treex.metrics.Accuracy
Computes Accuracy_:

.. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)

Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
For multi-class and multi-dimensional multi-class data with probability or logit predictions, the parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample, the top-K highest probability or logit score items are considered to find the correct label.
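The top-k rule can be written down directly. Below is a minimal ``jnp`` sketch, independent of this class, that reproduces the ``top_k=2`` example from the Examples section further down:

import jax.numpy as jnp

target = jnp.array([0, 1, 2])
logits = jnp.array([[0.1, 0.9, 0.0],
                    [0.3, 0.1, 0.6],
                    [0.2, 0.5, 0.3]])

k = 2
# Indices of the k highest scores per sample.
topk = jnp.argsort(logits, axis=-1)[:, -k:]
# A sample is correct if its target appears among the top-k predictions.
correct = (topk == target[:, None]).any(axis=-1)
print(correct.mean())  # 2/3 ~ 0.6667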
For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" accuracy by default, which counts all targets or sub-samples separately. This can be changed to subset accuracy (which requires all targets or sub-samples within a sample to be correctly predicted) by setting ``subset_accuracy=True``.
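As a minimal ``jnp`` sketch of the difference between the two conventions (using hypothetical multi-label data, independent of this class):

import jax.numpy as jnp

preds = jnp.array([[1, 0, 1], [0, 1, 1]])    # 2 samples, 3 labels each
target = jnp.array([[1, 0, 0], [0, 1, 1]])

# Default ("global") accuracy: every label is counted separately.
print((preds == target).mean())               # 5/6 ~ 0.8333

# Subset accuracy: a sample counts only if all of its labels match.
print((preds == target).all(axis=-1).mean())  # 1/2 = 0.5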
Accepts all input types listed in :ref:`references/modules:input types`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_classes | Optional[int] | Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. | None |
threshold | float | Threshold for transforming probability or logit predictions to binary (0, 1) predictions, in the case of binary or multi-label inputs. The default value of 0.5 corresponds to input being probabilities. | 0.5 |
average | Union[str, treex.metrics.utils.AverageMethod] | Defines the reduction that is applied. One of ``'micro'`` (calculate the metric globally, across all samples and classes), ``'macro'`` (calculate the metric for each class separately, then average with equal class weights), ``'weighted'`` (calculate the metric for each class separately, then average with each class weighted by its support, ``tp + fn``), ``'none'``/``None`` (return the metric for every class) or ``'samples'`` (calculate the metric for each sample, then average with equal sample weights). What counts as a sample in the multi-dimensional multi-class case depends on ``mdmc_average``; with ``'none'``, a class that occurs in neither ``preds`` nor ``target`` scores ``nan``. | AverageMethod.MICRO |
mdmc_average | Union[str, treex.metrics.utils.MDMCAverageMethod] | Defines how averaging is done for multi-dimensional multi-class inputs (on top of the ``average`` parameter). With ``'global'``, the ``N`` and extra ``...`` dimensions are flattened into a new sample axis before ``average`` applies as usual; with ``'samplewise'``, the statistics are computed separately for each sample on the ``N`` axis (treating the flattened extra axes as the sample's ``N`` dimension) and then averaged over samples. | MDMCAverageMethod.GLOBAL |
ignore_index | Optional[int] | Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. If an index is ignored and ``average=None`` or ``'none'``, the score for the ignored class is returned as ``nan``. | None |
top_k | Optional[int] | Number of highest probability or logit score predictions considered to find the correct label; relevant only for (multi-dimensional) multi-class inputs. The default (``None``) is interpreted as 1 for these inputs and should be left at ``None`` for all other input types. | None |
multiclass | Optional[bool] | Used only in certain special cases where you want to treat inputs as a different type than what they appear to be. See the parameter's :ref:`documentation section <references/modules:using the multiclass parameter>` for a detailed explanation and examples. | None |
subset_accuracy | bool | Whether to compute subset accuracy for multi-label and multi-dimensional multi-class inputs (no effect for other input types). If ``True``, all targets (or sub-samples on the extra axis) of a sample must be correctly predicted for the sample to count as correct; if ``False``, each target or sub-sample is counted separately, which is equivalent to flattening the inputs beforehand. ``top_k`` still applies in both cases, if set. | False |
compute_on_step | bool | Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. | True |
dist_sync_on_step | bool | Synchronize metric state across processes at each ``forward()`` before returning the value at the step. | False |
process_group | Optional[Any] | Specify the process group on which synchronization is called. Default: ``None`` (which selects the entire world). | None |
dist_sync_fn | Callable | Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the allgather. | None |
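For instance, per the constructor checks in the source below, macro averaging requires ``num_classes``. A minimal sketch, assuming ``treex`` is imported as ``tx``:

import treex as tx

# With macro averaging the constructor requires num_classes
# (it raises a ValueError otherwise, per the source below).
acc = tx.metrics.Accuracy(average="macro", num_classes=3)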
Examples:
>>> import torch
>>> from torchmetrics import Accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
>>> accuracy = Accuracy()
>>> accuracy(preds, target)
tensor(0.5000)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy = Accuracy(top_k=2)
>>> accuracy(preds, target)
tensor(0.6667)
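The doctest examples above are inherited from torchmetrics. A speculative treex equivalent, assuming ``update()``/``compute()`` accept ``jnp`` arrays as documented below and that the class is reachable as ``tx.metrics.Accuracy``:

import jax.numpy as jnp
import treex as tx

target = jnp.array([0, 1, 2, 3])
preds = jnp.array([0, 2, 1, 3])

acc = tx.metrics.Accuracy()
acc.update(preds=preds, target=target)  # accumulate tp/fp/tn/fn state
print(acc.compute())                    # expected: 0.5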
Source code in treex/metrics/accuracy.py
class Accuracy(Metric):
r"""
Computes Accuracy_:
.. math::
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a
tensor of predictions.
For multi-class and multi-dimensional multi-class data with probability or logits predictions, the
parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the
top-K highest probability or logit score items are considered to find the correct label.
For multi-label and multi-dimensional multi-class inputs, this metric computes the "global"
accuracy by default, which counts all target or sub-samples separately. This can be
changed to subset accuracy (which requires all target or sub-samples in the sample to
be correctly predicted) by setting ``subset_accuracy=True``.
Accepts all input types listed in :ref:`references/modules:input types`.
Args:
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
threshold:
Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case
of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities.
average:
Defines the reduction that is applied. Should be one of the following:
- ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
- ``'macro'``: Calculate the metric for each class separately, and average the
metrics across classes (with equal weights for each class).
- ``'weighted'``: Calculate the metric for each class separately, and average the
metrics across classes, weighting each class by its support (``tp + fn``).
- ``'none'`` or ``None``: Calculate the metric for each class separately, and return
the metric for every class.
- ``'samples'``: Calculate the metric for each sample, and average the metrics
across samples (with equal weights for each sample).
.. note:: What is considered a sample in the multi-dimensional multi-class case
depends on the value of ``mdmc_average``.
.. note:: If ``'none'`` and a given class doesn't occur in the `preds` or `target`,
the value for the class will be ``nan``.
mdmc_average:
Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
``average`` parameter). Should be one of the following:
- ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
multi-class.
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`references/modules:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
ignore_index:
Integer specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. If an index is ignored, and ``average=None``
or ``'none'``, the score for the ignored class will be returned as ``nan``.
top_k:
Number of highest probability or logit score predictions considered to find the correct label,
relevant only for (multi-dimensional) multi-class inputs. The
default value (``None``) will be interpreted as 1 for these inputs.
Should be left at default (``None``) for all other types of inputs.
multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <references/modules:using the multiclass parameter>`
for a more detailed explanation and examples.
subset_accuracy:
Whether to compute subset accuracy for multi-label and multi-dimensional
multi-class inputs (has no effect for other input types).
- For multi-label inputs, if the parameter is set to ``True``, then all target for
each sample must be correctly predicted for the sample to count as correct. If it
is set to ``False``, then all target are counted separately - this is equivalent to
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).
- For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
sub-sample (on the extra axis) must be correct for the sample to be counted as correct.
If it is set to ``False``, then all sub-samples are counted separately - this is equivalent,
in the case of label predictions, to flattening the inputs beforehand (i.e.
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.
compute_on_step:
Forward only calls ``update()`` and return ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
process_group:
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
Raises:
ValueError:
If ``top_k`` is not an ``integer`` larger than ``0``.
ValueError:
If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``.
ValueError:
If two different input modes are provided, eg. using ``multi-label`` with ``multi-class``.
ValueError:
If ``top_k`` parameter is set for ``multi-label`` inputs.
Example:
>>> import torch
>>> from torchmetrics import Accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
>>> accuracy = Accuracy()
>>> accuracy(preds, target)
tensor(0.5000)
>>> target = torch.tensor([0, 1, 2])
>>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
>>> accuracy = Accuracy(top_k=2)
>>> accuracy(preds, target)
tensor(0.6667)
"""
tp: jnp.ndarray = types.MetricState.node()
fp: jnp.ndarray = types.MetricState.node()
tn: jnp.ndarray = types.MetricState.node()
fn: jnp.ndarray = types.MetricState.node()
def __init__(
self,
threshold: float = 0.5,
num_classes: typing.Optional[int] = None,
average: typing.Union[str, AverageMethod] = AverageMethod.MICRO,
mdmc_average: typing.Union[str, MDMCAverageMethod] = MDMCAverageMethod.GLOBAL,
ignore_index: typing.Optional[int] = None,
top_k: typing.Optional[int] = None,
multiclass: typing.Optional[bool] = None,
subset_accuracy: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: typing.Optional[typing.Any] = None,
dist_sync_fn: typing.Callable = None,
mode: DataType = DataType.MULTICLASS,
on: typing.Optional[types.IndexLike] = None,
name: typing.Optional[str] = None,
dtype: typing.Optional[jnp.dtype] = None,
):
super().__init__(on=on, name=name, dtype=dtype)
if isinstance(average, str):
average = AverageMethod[average.upper()]
if isinstance(mdmc_average, str):
mdmc_average = MDMCAverageMethod[mdmc_average.upper()]
average = (
AverageMethod.MACRO
if average in [AverageMethod.WEIGHTED, AverageMethod.NONE]
else average
)
if average not in [
AverageMethod.MICRO,
AverageMethod.MACRO,
# AverageMethod.SAMPLES,
]:
raise ValueError(f"The `reduce` {average} is not valid.")
if average == AverageMethod.MACRO and (not num_classes or num_classes < 1):
raise ValueError(
"When you set `reduce` as 'macro', you have to provide the number of classes."
)
if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
raise ValueError(
f"The `top_k` should be an integer larger than 0, got {top_k}"
)
if (
num_classes
and ignore_index is not None
and (not 0 <= ignore_index < num_classes or num_classes == 1)
):
raise ValueError(
f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes"
)
# Update states
if average == AverageMethod.SAMPLES:
raise ValueError(f"The `average` method '{average}' is not yet supported.")
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
raise ValueError(
f"The `mdmc_average` method '{mdmc_average}' is not yet supported."
)
self.average = average
self.mdmc_average = mdmc_average
self.num_classes = num_classes
self.threshold = threshold
self.multiclass = multiclass
self.ignore_index = ignore_index
self.top_k = top_k
self.subset_accuracy = subset_accuracy
self.mode = mode
# nodes
if average == AverageMethod.MICRO:
zeros_shape = []
elif average == AverageMethod.MACRO:
zeros_shape = [num_classes]
else:
raise ValueError(f'Wrong reduce="{average}"')
initial_value = jnp.zeros(zeros_shape, dtype=jnp.uint32)
self.tp = initial_value
self.fp = initial_value
self.tn = initial_value
self.fn = initial_value
def update(self, preds: jnp.ndarray, target: jnp.ndarray) -> None: # type: ignore
"""Update state with predictions and targets. See
:ref:`references/modules:input types` for more information on input
types.
Args:
preds: Predictions from model (logits, probabilities, or target)
target: Ground truth target
"""
tp, fp, tn, fn = metric_utils._stat_scores_update(
preds,
target,
intended_mode=self.mode,
average_method=self.average,
mdmc_average_method=self.mdmc_average,
threshold=self.threshold,
num_classes=self.num_classes,
top_k=self.top_k,
multiclass=self.multiclass,
)
self.tp += tp
self.fp += fp
self.tn += tn
self.fn += fn
def compute(self) -> jnp.ndarray:
"""Computes accuracy based on inputs passed in to ``update`` previously."""
# if self.mode is None:
# raise RuntimeError("You have to have determined mode.")
return metric_utils._accuracy_compute(
self.tp,
self.fp,
self.tn,
self.fn,
self.average,
self.mdmc_average,
self.mode,
)
__call__(self, preds, target) (special)
Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preds | ndarray | Predictions from model (logits, probabilities, or target) | required |
target | ndarray | Ground truth target | required |
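A hypothetical direct call; with the default ``compute_on_step=True``, calling the metric is expected to update the state and return the value for the step:

import jax.numpy as jnp
import treex as tx

acc = tx.metrics.Accuracy()
value = acc(preds=jnp.array([0, 2, 1, 3]), target=jnp.array([0, 1, 2, 3]))
print(value)  # expected: 0.5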
Source code in treex/metrics/accuracy.py
def update(self, preds: jnp.ndarray, target: jnp.ndarray) -> None: # type: ignore
"""Update state with predictions and targets. See
:ref:`references/modules:input types` for more information on input
types.
Args:
preds: Predictions from model (logits, probabilities, or target)
target: Ground truth target
"""
tp, fp, tn, fn = metric_utils._stat_scores_update(
preds,
target,
intended_mode=self.mode,
average_method=self.average,
mdmc_average_method=self.mdmc_average,
threshold=self.threshold,
num_classes=self.num_classes,
top_k=self.top_k,
multiclass=self.multiclass,
)
self.tp += tp
self.fp += fp
self.tn += tn
self.fn += fn
compute(self)
Computes accuracy based on inputs passed in to ``update`` previously.
Source code in treex/metrics/accuracy.py
def compute(self) -> jnp.ndarray:
"""Computes accuracy based on inputs passed in to ``update`` previously."""
# if self.mode is None:
# raise RuntimeError("You have to have determined mode.")
return metric_utils._accuracy_compute(
self.tp,
self.fp,
self.tn,
self.fn,
self.average,
self.mdmc_average,
self.mode,
)
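Since ``update`` adds to the ``tp``/``fp``/``tn``/``fn`` counters, ``compute`` reflects everything seen so far. A minimal accumulation sketch, assuming in-place state updates as in the source above:

import jax.numpy as jnp
import treex as tx

acc = tx.metrics.Accuracy()
for preds, target in [
    (jnp.array([0, 1]), jnp.array([0, 1])),  # 2/2 correct
    (jnp.array([2, 2]), jnp.array([2, 3])),  # 1/2 correct
]:
    acc.update(preds=preds, target=target)

print(acc.compute())  # expected: 3/4 = 0.75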
update(self, preds, target)
Update state with predictions and targets. See :ref:`references/modules:input types` for more information on input types.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preds | ndarray | Predictions from model (logits, probabilities, or target) | required |
target | ndarray | Ground truth target | required |
Source code in treex/metrics/accuracy.py
def update(self, preds: jnp.ndarray, target: jnp.ndarray) -> None: # type: ignore
"""Update state with predictions and targets. See
:ref:`references/modules:input types` for more information on input
types.
Args:
preds: Predictions from model (logits, probabilities, or target)
target: Ground truth target
"""
tp, fp, tn, fn = metric_utils._stat_scores_update(
preds,
target,
intended_mode=self.mode,
average_method=self.average,
mdmc_average_method=self.mdmc_average,
threshold=self.threshold,
num_classes=self.num_classes,
top_k=self.top_k,
multiclass=self.multiclass,
)
self.tp += tp
self.fp += fp
self.tn += tn
self.fn += fn