Skip to content

treex.metrics.AuxLosses

Source code in treex/metrics/losses.py
class AuxLosses(Metric):
    """Running mean of every auxiliary loss found in a pytree.

    For each named loss the metric keeps a float32 running sum (``totals``)
    and a uint32 number of updates (``counts``). ``compute`` reports the
    per-loss means together with their sum.
    """

    # Per-loss running sum, keyed by the log name produced by `as_logs`.
    totals: tp.Dict[str, jnp.ndarray] = types.MetricState.node()
    # Per-loss number of `update` calls, keyed like `totals`.
    counts: tp.Dict[str, jnp.ndarray] = types.MetricState.node()

    def __init__(
        self,
        aux_losses: tp.Any,
        on: tp.Optional[types.IndexLike] = None,
        name: tp.Optional[str] = None,
        dtype: tp.Optional[jnp.dtype] = None,
    ):
        """Create zeroed accumulators for every loss found in ``aux_losses``.

        Arguments:
            aux_losses: pytree of auxiliary losses; only used to discover the
                loss names (via `as_logs`), its values are not accumulated.
            on: forwarded unchanged to `Metric.__init__`.
            name: forwarded unchanged to `Metric.__init__`.
            dtype: forwarded unchanged to `Metric.__init__`. The accumulators
                themselves are always float32 / uint32.
        """
        super().__init__(on=on, name=name, dtype=dtype)
        logs = self.as_logs(aux_losses)
        # `key` (not `name`) so the comprehension does not shadow the argument.
        self.totals = {key: jnp.array(0.0, dtype=jnp.float32) for key in logs}
        self.counts = {key: jnp.array(0, dtype=jnp.uint32) for key in logs}

    def update(self, aux_losses: tp.Any) -> None:
        """Add the losses found in ``aux_losses`` to the running sums.

        Arguments:
            aux_losses: pytree of auxiliary losses. After `as_logs` it must
                contain every name already tracked in `totals`; a missing
                name raises `KeyError`. Names not already tracked are ignored.
        """
        logs = self.as_logs(aux_losses)

        self.totals = {
            name: (self.totals[name] + logs[name]).astype(jnp.float32)
            for name in self.totals
        }
        self.counts = {
            name: (self.counts[name] + 1).astype(dtype=jnp.uint32)
            for name in self.counts
        }

    def compute(self) -> tp.Tuple[jnp.ndarray, tp.Dict[str, jnp.ndarray]]:
        """Return the summed mean loss and the per-name mean losses.

        Returns:
            A tuple ``(total_loss, losses)`` where ``losses`` maps each
            tracked name to ``totals[name] / counts[name]`` and ``total_loss``
            is the sum of those means (float32 zero when nothing is tracked).

        Note:
            Counts start at zero, so calling this before any `update`
            divides zero by zero for every tracked loss.
        """
        losses = {name: self.totals[name] / self.counts[name] for name in self.totals}
        total_loss = sum(losses.values(), jnp.array(0.0, dtype=jnp.float32))

        return total_loss, losses

    def __call__(
        self, aux_losses: tp.Any
    ) -> tp.Tuple[jnp.ndarray, tp.Dict[str, jnp.ndarray]]:
        """Delegate to `Metric.__call__`, passing ``aux_losses`` by keyword."""
        return super().__call__(aux_losses=aux_losses)

    @staticmethod
    def loss_name(field_info: to.FieldInfo) -> str:
        """Pick the raw (not yet de-duplicated) name for one loss leaf.

        Precedence: the name of a `types.Named` wrapper, then the field name
        recorded in ``field_info``, then the fallback ``"aux_loss"``.
        """
        if isinstance(field_info.value, types.Named):
            return field_info.value.name
        if field_info.name is not None:
            return field_info.name
        return "aux_loss"

    def as_logs(self, tree: tp.Any) -> tp.Dict[str, jnp.ndarray]:
        """Flatten ``tree`` into a ``{unique_loss_name: value}`` dictionary.

        Every leaf is named with `loss_name`, gets a ``_loss`` suffix unless
        the name already ends in ``"loss"``, and is then passed through
        `utils._unique_name` so that leaves sharing a name are all kept
        instead of overwriting one another.

        Arguments:
            tree: pytree of auxiliary losses, possibly containing
                `types.Named` wrappers.

        Returns:
            Dictionary from log name to loss value, in flattening order.
        """
        with to.add_field_info():
            # `jax.tree_util.tree_flatten` replaces the deprecated top-level
            # alias `jax.tree_flatten`; same signature. Named wrappers holding
            # an actual value are kept as leaves so their name is not lost.
            fields_info: tp.List[to.FieldInfo] = jax.tree_util.tree_flatten(
                tree,
                is_leaf=lambda x: isinstance(x, types.Named)
                and not isinstance(x.value, to.Nothing),
            )[0]

        # pretend Named values are leaves: a Named leaf wraps a FieldInfo, so
        # swap the nesting to get a FieldInfo that wraps a Named value.
        for i, x in enumerate(fields_info):
            if isinstance(x, types.Named):
                field_info = x.value
                field_info.value = types.Named(x.name, field_info.value)
                fields_info[i] = field_info

        names: tp.Set[str] = set()
        losses: tp.Dict[str, jnp.ndarray] = {}

        # De-duplicate while inserting. Collecting into a dict keyed by the
        # raw name first would silently drop all but the last leaf of any
        # group sharing that name.
        # NOTE(review): assumes utils._unique_name returns a name not yet in
        # `names` and records it there — confirm against its definition.
        for field_info in fields_info:
            name = self.loss_name(field_info)
            value = field_info.value
            if isinstance(value, types.Named):
                value = value.value
            if not name.endswith("loss"):
                name = f"{name}_loss"
            losses[utils._unique_name(names, name)] = value

        return losses