
treex.Metric

Encapsulates metric logic and state. Metrics accumulate state across calls, so their output value reflects the metric as if it had been calculated on all of the data given up to that point.
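To illustrate what "accumulating state" means, here is a minimal plain-Python sketch that needs neither treex nor JAX. `RunningMean`, `update` and `compute` are hypothetical names chosen for illustration and are not the treex API. The metric stores a running sum and count rather than per-call results, so its value always equals the metric computed over everything seen so far.

```python
from typing import Sequence


class RunningMean:
    """Hypothetical accumulating metric (not a treex class).

    Keeps a running sum and count so that compute() returns the mean of
    every value passed to update() so far.
    """

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, values: Sequence[float]) -> "RunningMean":
        # Add to the accumulated state instead of replacing it.
        self.total += sum(values)
        self.count += len(values)
        return self

    def compute(self) -> float:
        # Mean over all data seen so far; 0.0 before any update.
        return self.total / self.count if self.count else 0.0


metric = RunningMean()
metric.update([1.0, 2.0])
metric.update([3.0, 4.0, 5.0])
print(metric.compute())  # 3.0
```

Averaging the two per-batch means (1.5 and 4.0) would give 2.75 instead. That is why the state holds sums and counts rather than the result of each call.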

__init__(self, on=None, name=None, dtype=None)

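Parameters:

on (Union[str, int, Sequence[Union[str, int]]], default: None)

    A string or integer, or an iterable of strings or integers, that indicates how to index/filter the target and preds arguments before they are passed to call. For example, if on = "a" then target = target["a"]. If on is an iterable, the structures are indexed iteratively: for example, if on = ["a", 0, "b"] then target = target["a"][0]["b"]. The same applies to preds. For more information, see Keras-like behavior (https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).

name (Optional[str], default: None)

    Name stored as self.name. If None, it defaults to utils._get_name(self).

dtype (Optional[jnp.dtype], default: None)

    Data type stored as self.dtype. If None, it defaults to jnp.float32.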
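The iterative indexing described for `on` can be sketched in plain Python. `apply_on` is a hypothetical helper, and `IndexLike` is a local alias written to match the documented type; the source refers to `types.IndexLike`, whose definition is not shown here. The normalization of a single key into a one-element tuple mirrors `self._labels_filter` in `__init__`. The same lookup would be applied to `preds`.

```python
from typing import Any, Optional, Sequence, Union

# Local alias matching the documented type of `on` (assumed, not imported).
IndexLike = Union[str, int, Sequence[Union[str, int]]]


def apply_on(value: Any, on: Optional[IndexLike]) -> Any:
    """Hypothetical helper mirroring the documented `on` behavior."""
    if on is None:
        # No filter: the structure is passed through unchanged.
        return value
    # A single key is treated as a one-element path, as in __init__.
    keys = (on,) if isinstance(on, (str, int)) else on
    for key in keys:
        # Index one level deeper for each key in the path.
        value = value[key]
    return value


target = {"a": [{"b": 1}, {"b": 2}]}
print(apply_on(target, "a"))            # [{'b': 1}, {'b': 2}]
print(apply_on(target, ["a", 0, "b"]))  # 1
```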
Source code in treex/metrics/metric.py
def __init__(
    self,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
    dtype: tp.Optional[jnp.dtype] = None,
):
    """
    Arguments:
        on: A string or integer, or an iterable of strings or integers, that
            indicates how to index/filter the `target` and `preds`
            arguments before passing them to `call`. For example if `on = "a"` then
            `target = target["a"]`. If `on` is an iterable
            the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
            then `target = target["a"][0]["b"]`, same for `preds`. For more information
            check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
    """

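    # Normalize a single key into a 1-tuple so it can be indexed iteratively.
    self._labels_filter = (on,) if isinstance(on, (str, int)) else on
    # Fall back to a default name from utils._get_name when none is given.
    self.name = name if name is not None else utils._get_name(self)
    # Default to float32 when no dtype is given.
    self.dtype = dtype if dtype is not None else jnp.float32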

__init_subclass__() classmethod

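Called whenever a subclass of Metric is created. This implementation wraps the subclass's __call__ with functools.wraps(cls.update), so the wrapped __call__ takes its name, docstring and reported signature from update, and it raises a TypeError if any positional argument is passed. All arguments to a metric's __call__ must therefore be given as keyword arguments.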

Source code in treex/metrics/metric.py
def __init_subclass__(cls):
    super().__init_subclass__()

    # Wrap __call__ so it exposes update's signature and rejects positional arguments.
    old_call = cls.__call__

    @functools.wraps(cls.update)
    def new_call(self: M, *args, **kwargs) -> M:
        if len(args) > 0:
            raise TypeError(
                f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
            )

        return old_call(self, *args, **kwargs)

    cls.__call__ = new_call
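The same keyword-only wrapping can be reproduced in plain Python to see its effect. `Base` and `Accuracy` are hypothetical classes, and `Base.__call__` delegating to `update` is an assumption, since this excerpt does not show `Metric.__call__`. Because `functools.wraps` sets `__wrapped__`, `inspect.signature` reports the parameters of `update` for the wrapped `__call__`.

```python
import functools
import inspect


class Base:
    """Hypothetical stand-in for Metric that reuses its __init_subclass__ idea."""

    def __init_subclass__(cls, **cls_kwargs):
        super().__init_subclass__(**cls_kwargs)
        old_call = cls.__call__

        # Copy update's name, docstring and __wrapped__ onto the new __call__,
        # so inspect.signature reports update's parameters.
        @functools.wraps(cls.update)
        def new_call(self, *args, **kwargs):
            if args:
                raise TypeError(
                    f"All arguments to {cls.__name__}.__call__ should be "
                    "passed as keyword arguments."
                )
            return old_call(self, **kwargs)

        cls.__call__ = new_call

    def update(self, **kwargs):
        raise NotImplementedError

    def __call__(self, **kwargs):
        # Assumed behavior: calling the metric delegates to update.
        return self.update(**kwargs)


class Accuracy(Base):
    def update(self, target, preds):
        # Fraction of positions where the prediction equals the target.
        return sum(t == p for t, p in zip(target, preds)) / len(target)


acc = Accuracy()
print(acc(target=[1, 0, 1, 1], preds=[1, 1, 1, 1]))  # 0.75
print(inspect.signature(Accuracy.__call__))  # (self, target, preds)
try:
    acc([1, 0], [1, 1])
except TypeError as err:
    print(err)  # All arguments to Accuracy.__call__ should be passed as keyword arguments.
```

Because old_call is read from cls, a further subclass that does not define __call__ wraps its parent's already-wrapped __call__; the extra layer only repeats the same check.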