treex.metrics.Reduce
Encapsulates metrics that perform a reduce operation on the values.
Source code in treex/metrics/reduce.py
class Reduce(Metric):
    """Encapsulates metrics that perform a reduce operation on the values.

    Arguments:
        reduction: a `Reduction` member, or the name of one as a string
            (e.g. `"sum"`, `"sum_over_batch_size"`, `"weighted_mean"`).
        on: forwarded to `Metric.__init__`.
        name: forwarded to `Metric.__init__`.
        dtype: forwarded to `Metric.__init__`; `self.dtype` is used as the
            dtype of `total` (and of `count` for `weighted_mean`).
    """

    # Running (optionally weighted) sum of every value passed to `update`.
    total: jnp.ndarray = types.MetricState.node()
    # Running denominator used by `compute`; `None` for reductions that do not
    # need one (e.g. `sum`).
    count: tp.Optional[jnp.ndarray] = types.MetricState.node()

    def __init__(
        self,
        reduction: tp.Union[Reduction, str],
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        super().__init__(on=on, name=name, dtype=dtype)

        self.reduction = (
            reduction if isinstance(reduction, Reduction) else Reduction[reduction]
        )

        # initialize states
        self.total = jnp.array(0.0, dtype=self.dtype)

        if self.reduction == Reduction.sum_over_batch_size:
            # Only whole elements are counted, so an exact integer counter
            # is sufficient.
            self.count = jnp.array(0, dtype=jnp.uint32)
        elif self.reduction == Reduction.weighted_mean:
            # The denominator is a sum of sample weights, which may be
            # fractional. An integer counter would truncate them (weights of
            # 0.3 would add 0 to the count), so use the metric's float dtype.
            self.count = jnp.array(0.0, dtype=self.dtype)
        else:
            self.count = None

    def update(
        self,
        values: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ):
        """
        Accumulates statistics for computing the reduction metric.

        For example, if `values` is [1, 3, 5, 7] and
        `reduction="sum_over_batch_size"`, then the value of `compute()` is 4.
        If `sample_weight` is specified as [1, 1, 0, 0], then `compute()` would
        be 1 for `"sum_over_batch_size"` (zero-weighted entries still count
        toward the batch size) and 2 for `"weighted_mean"`.

        Arguments:
            values: Per-example value.
            sample_weight: Optional weighting of each example. Defaults to 1.
                If it cannot be broadcast to `values.shape`, the trailing axes
                of `values` beyond `sample_weight.ndim` are summed (for `sum`)
                or averaged (otherwise) before weighting.

        Returns:
            None. `total` (and `count`, for reductions with a denominator) are
            updated in place; call `compute()` to read the result.

        Raises:
            ValueError: if `sample_weight` has more dimensions than `values`.
        """
        if sample_weight is not None:
            if sample_weight.ndim > values.ndim:
                raise ValueError(
                    "sample_weight dimension is higher than values, when masking "
                    "values sample_weight dimension needs to be equal or lower "
                    "than values dimension, currently values have shape equal to "
                    f"{values.shape} and sample_weight has shape equal to "
                    f"{sample_weight.shape}"
                )

            try:
                # Broadcast weights if possible.
                # NOTE(review): numpy-style broadcasting aligns trailing axes,
                # so a weight vector whose length matches the LAST axis of
                # `values` is applied along that axis, not per example —
                # confirm this is intended.
                sample_weight = jnp.broadcast_to(sample_weight, values.shape)
            except ValueError:
                # Collapse the trailing axes of `values` so it has the same
                # ndim as the weights: `sum` keeps summed semantics, the other
                # reductions average the collapsed axes.
                reduce_axes = tuple(range(sample_weight.ndim, values.ndim))
                if self.reduction == Reduction.sum:
                    values = jnp.sum(values, axis=reduce_axes)
                else:
                    values = jnp.mean(values, axis=reduce_axes)

            values = values * sample_weight

        value_sum = jnp.sum(values)
        self.total = (self.total + value_sum).astype(self.total.dtype)

        # `count` only exists for reductions that need a denominator (see
        # `__init__`); e.g. `sum` stops here.
        if self.count is None:
            return

        if self.reduction == Reduction.sum_over_batch_size or sample_weight is None:
            # Every (possibly reduced) element counts once.
            num_values = np.prod(values.shape)
        else:
            # Weighted mean: the denominator is the total (broadcast) weight.
            num_values = jnp.sum(sample_weight)

        self.count = (self.count + num_values).astype(self.count.dtype)

    def compute(self) -> tp.Any:
        """Return `total` for `sum`, otherwise `total / count`.

        Before any `update`, `count` is 0, so the division yields nan (0/0).
        """
        if self.reduction == Reduction.sum:
            return self.total
        return self.total / self.count
__call__(self, values, sample_weight=None)
special
Accumulates statistics for computing the reduction metric. For example, if `values` is [1, 3, 5, 7] and `reduction="sum_over_batch_size"`, then the value of `compute()` is 4. If `sample_weight` is specified as [1, 1, 0, 0], then `compute()` would be 1 for `"sum_over_batch_size"` (zero-weighted entries still count toward the batch size) and 2 for `"weighted_mean"`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| values | ndarray | Per-example value. | required |
| sample_weight | Optional[jax._src.numpy.lax_numpy.ndarray] | Optional weighting of each example. Defaults to 1. | None |

Returns:

| Type | Description |
|---|---|
| | Array with the cumulative reduce. |
Source code in treex/metrics/reduce.py
def update(
    self,
    values: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
):
    """
    Accumulates statistics for computing the reduction metric.

    For example, if `values` is [1, 3, 5, 7] and
    `reduction="sum_over_batch_size"`, then the value of `compute()` is 4.
    If `sample_weight` is specified as [1, 1, 0, 0], then `compute()` would
    be 1 for `"sum_over_batch_size"` (zero-weighted entries still count
    toward the batch size) and 2 for `"weighted_mean"`.

    Arguments:
        values: Per-example value.
        sample_weight: Optional weighting of each example. Defaults to 1.
            If it cannot be broadcast to `values.shape`, the trailing axes
            of `values` beyond `sample_weight.ndim` are summed (for `sum`)
            or averaged (otherwise) before weighting.

    Returns:
        None. `total` (and `count`, for reductions with a denominator) are
        updated in place; call `compute()` to read the result.

    Raises:
        ValueError: if `sample_weight` has more dimensions than `values`.
    """
    if sample_weight is not None:
        if sample_weight.ndim > values.ndim:
            raise ValueError(
                "sample_weight dimension is higher than values, when masking "
                "values sample_weight dimension needs to be equal or lower "
                "than values dimension, currently values have shape equal to "
                f"{values.shape} and sample_weight has shape equal to "
                f"{sample_weight.shape}"
            )

        try:
            # Broadcast weights if possible.
            # NOTE(review): numpy-style broadcasting aligns trailing axes,
            # so a weight vector whose length matches the LAST axis of
            # `values` is applied along that axis, not per example —
            # confirm this is intended.
            sample_weight = jnp.broadcast_to(sample_weight, values.shape)
        except ValueError:
            # Collapse the trailing axes of `values` so it has the same
            # ndim as the weights: `sum` keeps summed semantics, the other
            # reductions average the collapsed axes.
            reduce_axes = tuple(range(sample_weight.ndim, values.ndim))
            if self.reduction == Reduction.sum:
                values = jnp.sum(values, axis=reduce_axes)
            else:
                values = jnp.mean(values, axis=reduce_axes)

        values = values * sample_weight

    value_sum = jnp.sum(values)
    self.total = (self.total + value_sum).astype(self.total.dtype)

    # `count` only exists for reductions that need a denominator; e.g. `sum`
    # stops here.
    if self.count is None:
        return

    if self.reduction == Reduction.sum_over_batch_size or sample_weight is None:
        # Every (possibly reduced) element counts once.
        num_values = np.prod(values.shape)
    else:
        # Weighted mean: the denominator is the total (broadcast) weight.
        num_values = jnp.sum(sample_weight)

    self.count = (self.count + num_values).astype(self.count.dtype)
update(self, values, sample_weight=None)
Accumulates statistics for computing the reduction metric. For example, if `values` is [1, 3, 5, 7] and `reduction="sum_over_batch_size"`, then the value of `compute()` is 4. If `sample_weight` is specified as [1, 1, 0, 0], then `compute()` would be 1 for `"sum_over_batch_size"` (zero-weighted entries still count toward the batch size) and 2 for `"weighted_mean"`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| values | ndarray | Per-example value. | required |
| sample_weight | Optional[jax._src.numpy.lax_numpy.ndarray] | Optional weighting of each example. Defaults to 1. | None |

Returns:

| Type | Description |
|---|---|
| None | Nothing is returned; `total` (and `count`, for reductions with a denominator) are updated in place. Call `compute()` to read the cumulative reduce. |
Source code in treex/metrics/reduce.py
def update(
    self,
    values: jnp.ndarray,
    sample_weight: tp.Optional[jnp.ndarray] = None,
):
    """
    Accumulates statistics for computing the reduction metric.

    For example, if `values` is [1, 3, 5, 7] and
    `reduction="sum_over_batch_size"`, then the value of `compute()` is 4.
    If `sample_weight` is specified as [1, 1, 0, 0], then `compute()` would
    be 1 for `"sum_over_batch_size"` (zero-weighted entries still count
    toward the batch size) and 2 for `"weighted_mean"`.

    Arguments:
        values: Per-example value.
        sample_weight: Optional weighting of each example. Defaults to 1.
            If it cannot be broadcast to `values.shape`, the trailing axes
            of `values` beyond `sample_weight.ndim` are summed (for `sum`)
            or averaged (otherwise) before weighting.

    Returns:
        None. `total` (and `count`, for reductions with a denominator) are
        updated in place; call `compute()` to read the result.

    Raises:
        ValueError: if `sample_weight` has more dimensions than `values`.
    """
    if sample_weight is not None:
        if sample_weight.ndim > values.ndim:
            raise ValueError(
                "sample_weight dimension is higher than values, when masking "
                "values sample_weight dimension needs to be equal or lower "
                "than values dimension, currently values have shape equal to "
                f"{values.shape} and sample_weight has shape equal to "
                f"{sample_weight.shape}"
            )

        try:
            # Broadcast weights if possible.
            # NOTE(review): numpy-style broadcasting aligns trailing axes,
            # so a weight vector whose length matches the LAST axis of
            # `values` is applied along that axis, not per example —
            # confirm this is intended.
            sample_weight = jnp.broadcast_to(sample_weight, values.shape)
        except ValueError:
            # Collapse the trailing axes of `values` so it has the same
            # ndim as the weights: `sum` keeps summed semantics, the other
            # reductions average the collapsed axes.
            reduce_axes = tuple(range(sample_weight.ndim, values.ndim))
            if self.reduction == Reduction.sum:
                values = jnp.sum(values, axis=reduce_axes)
            else:
                values = jnp.mean(values, axis=reduce_axes)

        values = values * sample_weight

    value_sum = jnp.sum(values)
    self.total = (self.total + value_sum).astype(self.total.dtype)

    # `count` only exists for reductions that need a denominator; e.g. `sum`
    # stops here.
    if self.count is None:
        return

    if self.reduction == Reduction.sum_over_batch_size or sample_weight is None:
        # Every (possibly reduced) element counts once.
        num_values = np.prod(values.shape)
    else:
        # Weighted mean: the denominator is the total (broadcast) weight.
        num_values = jnp.sum(sample_weight)

    self.count = (self.count + num_values).astype(self.count.dtype)