Skip to content

treex.metrics.Metrics

Source code in treex/metrics/metrics.py
class Metrics(Metric):
    metrics: tp.Dict[str, Metric]

    def __init__(
        self,
        metrics: tp.Any,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        super().__init__(on=on, name=name, dtype=dtype)

        names: tp.Set[str] = set()

        def get_name(path, metric):
            name = utils._get_name(metric)
            return f"{path}/{name}" if path else name

        self.metrics = {
            utils._unique_name(names, get_name(path, metric)): metric
            for path, metric in utils._flatten_names(metrics)
        }

    def update(self, **kwargs) -> None:
        for name, metric in self.metrics.items():
            arg_names = utils._function_argument_names(metric.update)

            if arg_names is None:
                metric_kwargs = kwargs
            else:
                metric_kwargs = {arg: kwargs[arg] for arg in arg_names if arg in kwargs}

            metric.update(**metric_kwargs)

    def compute(self) -> tp.Dict[str, jnp.ndarray]:
        outputs = {}
        names = set()

        for name, metric in self.metrics.items():

            value = metric.compute()

            for path, value in utils._flatten_names(value):
                name = f"{name}/{path}" if path else name
                name = utils._unique_name(names, name)

                outputs[name] = value

        return outputs

    def __call__(self, **kwargs) -> tp.Dict[str, jnp.ndarray]:
        return super().__call__(**kwargs)