Skip to content

treex.Metric

Encapsulates metric logic and state. Metrics accumulate state between calls so that their output values reflect the metric as if it had been calculated on all the data given up to that point.

Source code in treex/metrics/metric.py
class Metric(Treex, metaclass=MetricMeta):
    """
    Encapsulates metric logic and state. Metrics accumulate state between calls such
    that their output values reflect the metric as if calculated on the whole data
    given up to that point.

    Subclasses implement `update` (fold new inputs into the accumulated state) and
    `compute` (derive the metric value from the current state). Calling the metric
    updates the accumulated state and returns the value for the current batch only.
    """

    # Snapshot of state fields restored by `reset`.
    # NOTE(review): presumably populated by `MetricMeta`; not visible here -- confirm.
    _initial_state: tp.Dict[str, tp.Any] = types.MetricState.node()

    def __init__(
        self,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        """
        Arguments:
            on: A string or integer, or iterable of string or integers, that
                indicate how to index/filter the `target` and `preds`
                arguments before passing them to `update`. For example if `on = "a"` then
                `target = target["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `target = target["a"][0]["b"]`, same for `preds`. For more information
                check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            name: Name of the metric. If `None`, it is obtained from
                `utils._get_name(self)`.
            dtype: dtype stored on the metric. If `None`, `jnp.float32` is used.

        Raises:
            TypeError: If `on` is neither `None`, a string, an integer, nor an iterable.
        """

        if on is None:
            self._labels_filter = None
        elif isinstance(on, (str, int)):
            self._labels_filter = (on,)
        else:
            # Materialize the indices: the filter is iterated once for `target` and
            # once for `preds` on every call, so a one-shot iterable (e.g. a
            # generator) would otherwise be exhausted after its first use.
            self._labels_filter = tuple(on)

        self.name = name if name is not None else utils._get_name(self)
        self.dtype = dtype if dtype is not None else jnp.float32

    def __call__(self, **kwargs) -> tp.Any:
        """
        Update the accumulated state with `kwargs` and return the batch-level value.

        If the metric was created with `on`, the `target` and `preds` keyword
        arguments (when present and not `None`) are indexed with it first.

        Arguments:
            **kwargs: Keyword arguments forwarded to `update`.

        Returns:
            The result of `compute()` on a reset copy of this metric that was
            updated with the current `kwargs` only.
        """
        if self._labels_filter is not None:
            for key in ("target", "preds"):
                value = kwargs.get(key)
                if value is None:
                    # Absent or explicitly `None`: nothing to index.
                    continue

                for index in self._labels_filter:
                    value = value[index]
                kwargs[key] = value

        # update cumulative state
        self.update(**kwargs)

        # compute batch metrics on a fresh copy so the cumulative state is untouched
        module = to.copy(self)
        module.reset()
        module.update(**kwargs)
        return module.compute()

    def reset(self):
        """
        Restore the state of this metric, and of every `Metric` visited by
        `self.apply`, from a copy of its `_initial_state`.
        """

        def do_reset(metric):
            if isinstance(metric, Metric):
                # Copy so that later updates cannot mutate the stored initial state.
                metric.__dict__.update(to.copy(metric._initial_state))

        self.apply(do_reset, inplace=True)

    @abstractmethod
    def update(self, **kwargs) -> None:
        """Fold the given keyword arguments into the accumulated state."""
        ...

    @abstractmethod
    def compute(self) -> tp.Any:
        """Return the metric value derived from the accumulated state."""
        ...

    def __init_subclass__(cls, **subclass_kwargs):
        """
        Wrap the subclass's `__call__` so that it only accepts keyword arguments.

        The wrapper takes its metadata (name, docstring, signature) from the
        subclass's `update` via `functools.wraps`.

        Arguments:
            **subclass_kwargs: Class keyword arguments, forwarded unchanged to the
                parent `__init_subclass__`.
        """
        super().__init_subclass__(**subclass_kwargs)

        # add call signature
        old_call = cls.__call__

        @functools.wraps(cls.update)
        def new_call(self: M, *args, **kwargs) -> tp.Any:
            if len(args) > 0:
                raise TypeError(
                    f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
                )

            # `args` is empty past this point, so only keywords are forwarded.
            return old_call(self, **kwargs)

        cls.__call__ = new_call

__init__(self, on=None, name=None, dtype=None) special

Parameters:

Name Type Description Default
on Union[str, int, Sequence[Union[str, int]]]

A string or integer, or an iterable of strings or integers, that indicates how to index/filter the target and preds arguments before they are passed to update. For example, if on = "a" then target = target["a"]. If on is an iterable, the structures are indexed iteratively: for example, if on = ["a", 0, "b"] then target = target["a"][0]["b"], and the same is done for preds. For more information check out Keras-like behavior.

None
Source code in treex/metrics/metric.py
def __init__(
    self,
    on: tp.Optional[types.IndexLike] = None,
    name: tp.Optional[str] = None,
    dtype: tp.Optional[jnp.dtype] = None,
):
    """
    Arguments:
        on: A string or integer, or iterable of string or integers, that
            indicate how to index/filter the `target` and `preds`
            arguments before passing them to `update`. For example if `on = "a"` then
            `target = target["a"]`. If `on` is an iterable
            the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
            then `target = target["a"][0]["b"]`, same for `preds`. For more information
            check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
        name: Name of the metric. If `None`, it is obtained from
            `utils._get_name(self)`.
        dtype: dtype stored on the metric. If `None`, `jnp.float32` is used.

    Raises:
        TypeError: If `on` is neither `None`, a string, an integer, nor an iterable.
    """

    if on is None:
        self._labels_filter = None
    elif isinstance(on, (str, int)):
        self._labels_filter = (on,)
    else:
        # Materialize the indices: the filter is iterated more than once, so a
        # one-shot iterable (e.g. a generator) would otherwise be exhausted
        # after its first use.
        self._labels_filter = tuple(on)

    self.name = name if name is not None else utils._get_name(self)
    self.dtype = dtype if dtype is not None else jnp.float32

__init_subclass__() classmethod special

This method is called when a class is subclassed.

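The default implementation does nothing. It may be overridden to extend subclasses. Metric overrides it to wrap each subclass's __call__ so that positional arguments are rejected with a TypeError; the wrapper takes its metadata from the subclass's update method.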

Source code in treex/metrics/metric.py
def __init_subclass__(cls, **subclass_kwargs):
    """
    Wrap the subclass's `__call__` so that it only accepts keyword arguments.

    The wrapper takes its metadata (name, docstring, signature) from the
    subclass's `update` via `functools.wraps`.

    Arguments:
        **subclass_kwargs: Class keyword arguments, forwarded unchanged to the
            parent `__init_subclass__`.
    """
    super().__init_subclass__(**subclass_kwargs)

    # add call signature
    old_call = cls.__call__

    @functools.wraps(cls.update)
    def new_call(self: M, *args, **kwargs) -> tp.Any:
        if len(args) > 0:
            raise TypeError(
                f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
            )

        # `args` is empty past this point, so only keywords are forwarded.
        return old_call(self, **kwargs)

    cls.__call__ = new_call