Skip to content

treex.nn.HaikuModule

Source code in treex/nn/haiku_module.py
class HaikuModule(Module):
    """Module wrapper around a Haiku ``hk.TransformedWithState``.

    The wrapped transform's parameters and state are held on the instance
    (``params_`` and ``state_``), and every call pulls a fresh key from
    ``next_key`` before delegating to the transform's ``init`` / ``apply``.
    """

    # Haiku transform whose ``init`` / ``apply`` functions are invoked by
    # ``__call__``.
    transform: hk.TransformedWithState

    # Values declared as nodes: parameters use ``types.Parameter`` and state
    # uses ``types.BatchStat``. ``next_key`` supplies one key per call.
    params_: tp.Optional[hk.Params] = types.Parameter.node()
    state_: tp.Optional[hk.State] = types.BatchStat.node()
    next_key: KeySeq

    def __init__(
        self,
        transform: tp.Union[
            hk.TransformedWithState,
            tp.Callable[..., tp.Any],
        ],
        params: tp.Optional[hk.Params] = None,
        states: tp.Optional[hk.State] = None,
    ) -> None:
        """Store the transform and any pre-existing parameters / state.

        Arguments:
            transform: Either an ``hk.TransformedWithState`` (used as is) or
                a callable, which is passed through
                ``hk.transform_with_state``.
            params: Optional parameters; when given they are converted with
                ``to_mutable_dict``, otherwise ``params_`` is ``None``.
            states: Optional state; when given it is converted with
                ``to_mutable_dict``, otherwise ``state_`` is ``None``.
        """
        if isinstance(transform, hk.TransformedWithState):
            self.transform = transform
        else:
            self.transform = hk.transform_with_state(transform)

        self.next_key = KeySeq()
        self.params_ = None if params is None else to_mutable_dict(params)
        self.state_ = None if states is None else to_mutable_dict(states)

    def __call__(self, *args, **kwargs):
        """Run the wrapped transform on ``args`` / ``kwargs``.

        If the caller did not pass ``training`` and the transform's original
        function lists an argument with that name, ``training`` is filled in
        from ``self.training`` (or ``True`` while ``self.initialized`` is
        falsy).

        When ``self.initializing()`` is truthy and neither parameters nor
        state exist yet, ``init`` creates them and ``apply`` is run once with
        the same key; the state returned by that ``apply`` is discarded.
        Otherwise ``apply`` is run and ``state_`` is replaced by the state it
        returns.

        Returns:
            The output produced by the transform's ``apply``.
        """
        rng = self.next_key()

        if "training" not in kwargs:
            # Relies on the private ``_original_fn`` attribute that the
            # transform's ``init`` function carries.
            wrapped_fn = self.transform.init._original_fn
            accepted = utils._function_argument_names(wrapped_fn)

            if accepted is not None and "training" in accepted:
                if self.initialized:
                    kwargs["training"] = self.training
                else:
                    kwargs["training"] = True

        needs_init = (
            self.initializing() and self.params_ is None and self.state_ is None
        )

        if not needs_init:
            # Regular path: apply and keep the updated state.
            output, new_state = self.transform.apply(
                self.params_, self.state_, rng, *args, **kwargs
            )
            self.state_ = to_mutable_dict(new_state)
            return output

        # Initialization path: create params / state, then compute the output
        # with them while ignoring the state returned by this first apply.
        self.params_, self.state_ = self.transform.init(rng, *args, **kwargs)
        output, _ = self.transform.apply(
            self.params_, self.state_, rng, *args, **kwargs
        )
        self.params_ = to_mutable_dict(self.params_)
        self.state_ = to_mutable_dict(self.state_)
        return output