treex.losses.Loss
Loss base class.
To be implemented by subclasses:
* `call()`: Contains the logic for loss calculation.
Example subclass implementation:
class MeanSquaredError(Loss):
    """Example loss: mean of the squared differences over the last axis."""

    def call(self, target, preds):
        squared_error = jnp.square(preds - target)
        return jnp.mean(squared_error, axis=-1)
Please see the [Modules, Losses, and Metrics Guide](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#losses) for more details on this.
Source code in treex/losses/loss.py
class Loss:
    """
    Loss base class.

    To be implemented by subclasses:

    * `call()`: Contains the logic for loss calculation.

    Example subclass implementation:

    ```python
    class MeanSquaredError(Loss):
        def call(self, target, preds):
            return jnp.mean(jnp.square(preds - target), axis=-1)
    ```

    Please see the [Modules, Losses, and Metrics Guide](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#losses)
    for more details on this.
    """

    def __init__(
        self,
        reduction: tp.Optional[Reduction] = None,
        weight: tp.Optional[types.ScalarLike] = None,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
    ):
        """
        Initializes `Loss` class.

        Arguments:
            reduction: (Optional) Type of `tx.losses.Reduction` to apply to
                loss. Defaults to `Reduction.SUM_OVER_BATCH_SIZE`.
            weight: Optional weight contribution for the total loss. It is
                stored as a `float32` array. Defaults to `1`.
            on: A string or integer, or iterable of strings or integers, that
                indicate how to index/filter the `target` and `preds`
                arguments before passing them to `call`. For example if `on = "a"`
                then `target = target["a"]`. If `on` is an iterable the structures
                will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `target = target["a"][0]["b"]`, same for `preds`. For more
                information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            name: Optional name for the instance. If not provided, the lower
                snake_case version of the class name is used instead.
        """
        self.name = name if name is not None else utils._get_name(self)
        self.weight = (
            jnp.asarray(weight, dtype=jnp.float32)
            if weight is not None
            else jnp.array(1.0, dtype=jnp.float32)
        )
        self._reduction = (
            reduction if reduction is not None else Reduction.SUM_OVER_BATCH_SIZE
        )
        # Normalize `on` to a tuple of indices (or None). Materializing it once
        # ensures a one-shot iterable is not exhausted by filtering `target`,
        # which would otherwise leave `preds` unfiltered.
        if on is None:
            self._labels_filter = None
        elif isinstance(on, (str, int)):
            self._labels_filter = (on,)
        else:
            self._labels_filter = tuple(on)
        # NOTE(review): presumably consumed elsewhere to inspect `call`'s
        # signature — confirm against callers.
        self._signature_f = self.call

    def __call__(
        self,
        **kwargs,
    ) -> jnp.ndarray:
        """
        Filters the inputs, computes `call`, and reduces the result.

        Arguments:
            **kwargs: Keyword arguments forwarded to `call`. `target` and
                `preds` (when present and not `None`) are first indexed with
                every entry of the `on` filter, in order. `sample_weight` is
                removed from the arguments and passed to `reduce_loss` instead.

        Returns:
            The value returned by `reduce_loss` for the computed loss values,
            `sample_weight`, `self.weight` and the configured reduction.
        """
        if self._labels_filter is not None:
            # Same filter, same order, applied to `target` then `preds`.
            for key in ("target", "preds"):
                if kwargs.get(key) is not None:
                    for index in self._labels_filter:
                        kwargs[key] = kwargs[key][index]

        sample_weight: tp.Optional[jnp.ndarray] = kwargs.pop("sample_weight", None)
        values = self.call(**kwargs)
        return reduce_loss(values, sample_weight, self.weight, self._reduction)

    @abstractmethod
    def call(self, *args, **kwargs) -> jnp.ndarray:
        """
        Computes the unreduced loss values. Must be overridden by subclasses.

        Raises:
            NotImplementedError: If a subclass does not override this method.
                `Loss` does not use `abc.ABCMeta`, so `@abstractmethod` alone does
                not prevent instantiation. Without this error, the base method
                would return `None` and `__call__` would hand it to `reduce_loss`.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement `call()`."
        )
`__init__(self, reduction=None, weight=None, on=None, name=None)` *special*

Initializes `Loss` class.
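Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `reduction` | `Optional[treex.losses.loss.Reduction]` | (Optional) Type of `tx.losses.Reduction` to apply to loss. Defaults to `SUM_OVER_BATCH_SIZE`. | `None` |
| `weight` | `Union[float, numpy.ndarray, jax._src.numpy.lax_numpy.ndarray]` | Optional weight contribution for the total loss. Defaults to `1`. | `None` |
| `on` | `Union[str, int, Sequence[Union[str, int]]]` | A string or integer, or iterable of strings or integers, that indicate how to index/filter the `target` and `preds` arguments before passing them to `call`. For example if `on = "a"` then `target = target["a"]`. If `on` is an iterable the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]` then `target = target["a"][0]["b"]`, same for `preds`. For more information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior). | `None` |
| `name` | `Optional[str]` | Optional name for the instance. If not provided, the lower snake_case version of the class name is used instead. | `None` |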
Source code in treex/losses/loss.py
def __init__(
    self,
    reduction: tp.Optional[Reduction] = None,
    weight: tp.Optional[types.ScalarLike] = None,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
):
    """
    Initializes `Loss` class.

    Arguments:
        reduction: (Optional) Type of `tx.losses.Reduction` to apply to
            loss. Defaults to `Reduction.SUM_OVER_BATCH_SIZE`.
        weight: Optional weight contribution for the total loss. It is
            stored as a `float32` array. Defaults to `1`.
        on: A string or integer, or iterable of strings or integers, that
            indicate how to index/filter the `target` and `preds`
            arguments before passing them to `call`. For example if `on = "a"`
            then `target = target["a"]`. If `on` is an iterable the structures
            will be indexed iteratively, for example if `on = ["a", 0, "b"]`
            then `target = target["a"][0]["b"]`, same for `preds`. For more
            information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
        name: Optional name for the instance. If not provided, the lower
            snake_case version of the class name is used instead.
    """
    self.name = name if name is not None else utils._get_name(self)
    self.weight = (
        jnp.asarray(weight, dtype=jnp.float32)
        if weight is not None
        else jnp.array(1.0, dtype=jnp.float32)
    )
    self._reduction = (
        reduction if reduction is not None else Reduction.SUM_OVER_BATCH_SIZE
    )
    # Normalize `on` to a tuple of indices (or None). Materializing it once
    # ensures a one-shot iterable is not exhausted after its first traversal.
    if on is None:
        self._labels_filter = None
    elif isinstance(on, (str, int)):
        self._labels_filter = (on,)
    else:
        self._labels_filter = tuple(on)
    # NOTE(review): presumably consumed elsewhere to inspect `call`'s
    # signature — confirm against callers.
    self._signature_f = self.call