Skip to content

treex.metrics.AuxMetrics

Source code in treex/metrics/metrics.py
class AuxMetrics(Metric):
    """Running mean of auxiliary metric values, tracked per metric name.

    Every `update` flattens the given tree into a ``{name: value}`` mapping
    (see `as_logs`), adds the summed value to a per-name running total and the
    number of elements of the value to a per-name count. `compute` returns
    ``total / count`` for every tracked name.
    """

    # Per-name running sum of all elements seen so far.
    totals: tp.Dict[str, jnp.ndarray] = types.MetricState.node()
    # Per-name number of elements accumulated into `totals`.
    counts: tp.Dict[str, jnp.ndarray] = types.MetricState.node()

    def __init__(
        self,
        aux_metrics: tp.Any,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        """Create zeroed accumulators for every metric found in `aux_metrics`.

        Args:
            aux_metrics: Example tree used only to discover the metric names
                (via `as_logs`); its values are not accumulated.
            on: Forwarded unchanged to `Metric.__init__`.
            name: Forwarded unchanged to `Metric.__init__`.
            dtype: Forwarded unchanged to `Metric.__init__`.
                NOTE(review): the accumulators below are always float32 /
                uint32 regardless of `dtype` -- confirm this is intended.
        """
        super().__init__(on=on, name=name, dtype=dtype)
        logs = self.as_logs(aux_metrics)
        # `key` (not `name`) so the loop variable does not read as the
        # constructor's `name` parameter.
        self.totals = {key: jnp.array(0.0, dtype=jnp.float32) for key in logs}
        self.counts = {key: jnp.array(0, dtype=jnp.uint32) for key in logs}

    def update(self, aux_metrics: tp.Any) -> None:
        """Accumulate the values found in `aux_metrics`.

        Args:
            aux_metrics: Tree that `as_logs` flattens into the same metric
                names that were discovered in `__init__`.

        Raises:
            KeyError: If a tracked name is missing from the flattened tree.
        """
        logs = self.as_logs(aux_metrics)

        # Reduce each value to a scalar before accumulating. The count grows by
        # the number of elements, so the total must grow by their sum;
        # otherwise a non-scalar value would broadcast the total to its shape
        # and `compute` would no longer return a mean.
        self.totals = {
            key: (total + jnp.sum(logs[key])).astype(total.dtype)
            for key, total in self.totals.items()
        }
        self.counts = {
            key: (count + jnp.size(logs[key])).astype(count.dtype)
            for key, count in self.counts.items()
        }

    def compute(self) -> tp.Dict[str, jnp.ndarray]:
        """Return ``total / count`` for every tracked metric name.

        A name that has not been updated yet still has a zero count, so its
        entry is the result of dividing by zero.
        """
        return {key: total / self.counts[key] for key, total in self.totals.items()}

    def __call__(self, aux_metrics: tp.Any) -> tp.Dict[str, jnp.ndarray]:
        """Forward to `Metric.__call__`, passing `aux_metrics` by keyword."""
        return super().__call__(aux_metrics=aux_metrics)

    @staticmethod
    def metric_name(field_info: to.FieldInfo) -> str:
        """Pick the log name for a flattened field.

        Precedence: the name of a `types.Named` value, then the field's own
        name, then the literal fallback ``"aux_metric"``.
        """
        if isinstance(field_info.value, types.Named):
            return field_info.value.name
        if field_info.name is not None:
            return field_info.name
        return "aux_metric"

    def as_logs(self, tree: tp.Any) -> tp.Dict[str, jnp.ndarray]:
        """Flatten `tree` into a ``{unique_name: value}`` mapping.

        Args:
            tree: Pytree of auxiliary metric values, flattened while
                `to.add_field_info()` is active.

        Returns:
            One entry per flattened leaf, in flattening order. Keys come from
            `metric_name` and are passed through `utils._unique_name`.
        """
        names: tp.Set[str] = set()

        with to.add_field_info():
            # `jax.tree_util.tree_flatten` is the supported spelling; the
            # top-level `jax.tree_flatten` alias was deprecated and later
            # removed from JAX.
            fields_info: tp.List[to.FieldInfo] = jax.tree_util.tree_flatten(
                tree,
                is_leaf=lambda x: isinstance(x, types.Named)
                and not isinstance(x.value, to.Nothing),
            )[0]

        # pretend Named values are leaves: move the Named wrapper inside the
        # field info it carries, so every entry below is handled uniformly.
        for i, x in enumerate(fields_info):
            if isinstance(x, types.Named):
                field_info = x.value
                field_info.value = types.Named(x.name, field_info.value)
                fields_info[i] = field_info

        # De-duplicate names while building the mapping. Keying an
        # intermediate dict by the raw name first would let colliding metrics
        # overwrite each other before `_unique_name` ever saw the duplicate.
        # NOTE(review): assumes `_unique_name` records each returned name in
        # `names` -- confirm against `utils`.
        metrics: tp.Dict[str, jnp.ndarray] = {}
        for field_info in fields_info:
            value = (
                field_info.value.value
                if isinstance(field_info.value, types.Named)
                else field_info.value
            )
            unique_name = utils._unique_name(names, self.metric_name(field_info))
            metrics[unique_name] = value

        return metrics