
treex.Optimizer

Wraps an optax optimizer and turns it into a Pytree while maintaining a similar API.

The main difference from optax is that tx.Optimizer holds its own state, so there is no separate opt_state to manage.

Examples:

def main():
    ...
    optimizer = tx.Optimizer(optax.adam(1e-3))
    optimizer = optimizer.init(params)
    ...

@jax.jit
def train_step(model, x, y, optimizer):
    ...
    params = optimizer.update(grads, params)
    ...
    return model, loss, optimizer

Notice that, since the optimizer is a Pytree, it can naturally be passed through jit.
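
For reference, a minimal self-contained sketch of the same pattern, assuming a plain dict as the params pytree; loss_fn, train_step internals and the keys w, b are hypothetical and not part of the treex API.

import jax
import jax.numpy as jnp
import optax
import treex as tx

# Any pytree of arrays works as params; Optimizer flattens it internally.
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}

optimizer = tx.Optimizer(optax.adam(1e-3))
optimizer = optimizer.init(params)

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, optimizer):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # update mutates the optimizer's internal state and returns the new params,
    # so the optimizer must be returned from the jitted function to keep its state.
    params = optimizer.update(grads, params)
    return params, loss, optimizer

x, y = jnp.ones((8, 3)), jnp.zeros((8,))
params, loss, optimizer = train_step(params, x, y, optimizer)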

Differences with Optax

  • init returns a new optimizer instance; there is no opt_state.
  • update does not take opt_state as an argument; instead, it updates its internal state in place.
  • update applies the updates to the params and returns them by default; pass apply_updates=False to get the raw parameter updates instead (see the sketch after this list).
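
To make these differences concrete, here is a hedged side-by-side sketch with a hypothetical dict of params and hand-made gradients; only tx.Optimizer, optax.adam and optax.apply_updates are real API calls, the rest is illustrative.

import jax.numpy as jnp
import optax
import treex as tx

params = {"w": jnp.ones((2,))}          # hypothetical params pytree
grads = {"w": jnp.full((2,), 0.5)}      # pretend gradients

# Plain optax: opt_state is created and threaded through by hand.
gradient_transform = optax.adam(1e-3)
opt_state = gradient_transform.init(params)
updates, opt_state = gradient_transform.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)

# treex: init returns a new Optimizer that owns its state; update applies the
# updates by default (pass apply_updates=False to get the raw updates instead).
optimizer = tx.Optimizer(optax.adam(1e-3))
optimizer = optimizer.init(params)
new_params = optimizer.update(grads, params)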

Parameters:

  Name       Type  Description          Default
  optimizer        An optax optimizer.  required

init(self, params)

Initialize the optimizer from an initial set of parameters.

Parameters:

  Name    Type  Description                    Default
  params  Any   An initial set of parameters.  required

Returns:

  Type  Description
  ~O    A new optimizer instance.

Source code in treex/optimizer.py
def init(self: O, params: tp.Any) -> O:
    """
    Initialize the optimizer from an initial set of parameters.

    Arguments:
        params: An initial set of parameters.

    Returns:
        A new optimizer instance.
    """
    module = to.copy(self)
    params = jax.tree_leaves(params)
    module.opt_state = module.optimizer.init(params)
    module._n_params = len(params)
    module._initialized = True
    return module
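
A small usage sketch for init, assuming a plain dict as the params pytree (the keys and shapes are hypothetical); the initialized flag is the same one checked by update below.

import jax.numpy as jnp
import optax
import treex as tx

params = {"w": jnp.zeros((4,)), "b": jnp.zeros(())}

optimizer = tx.Optimizer(optax.sgd(0.1))
optimizer = optimizer.init(params)   # returns a new, initialized instance
assert optimizer.initialized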

update(self, grads, params=None, apply_updates=True)

Applies the parameter updates and updates the optimizer's internal state in place.

Parameters:

  Name           Type          Description                                                           Default
  grads          ~A            The gradients to perform the update.                                  required
  params         Optional[~A]  The parameters to update. If None then apply_updates must be False.   None
  apply_updates  bool          If False then the updates are returned instead of being applied.      True

Returns:

  Type  Description
  ~A    The updated parameters. If apply_updates is False then the updates are returned instead.

Source code in treex/optimizer.py
def update(
    self, grads: A, params: tp.Optional[A] = None, apply_updates: bool = True
) -> A:
    """
    Applies the parameters updates and updates the optimizers internal state inplace.

    Arguments:
        grads: the gradients to perform the update.
        params: the parameters to update. If `None` then `apply_updates` has to be `False`.
        apply_updates: if `False` then the updates are returned instead of being applied.

    Returns:
        The updated parameters. If `apply_updates` is `False` then the updates are returned instead.
    """
    if not self.initialized:
        raise RuntimeError("Optimizer is not initialized")

    assert self.opt_state is not None
    if apply_updates and params is None:
        raise ValueError("params must be provided if updates are being applied")

    opt_grads, treedef = jax.tree_flatten(grads)
    opt_params = jax.tree_leaves(params)

    if len(opt_params) != self._n_params:
        raise ValueError(
            f"params must have length {self._n_params}, got {len(opt_params)}"
        )
    if len(opt_grads) != self._n_params:
        raise ValueError(
            f"grads must have length {self._n_params}, got {len(opt_grads)}"
        )

    param_updates: A
    param_updates, self.opt_state = self.optimizer.update(
        opt_grads,
        self.opt_state,
        opt_params,
    )

    output: A
    if apply_updates:
        output = optax.apply_updates(opt_params, param_updates)
    else:
        output = param_updates

    return jax.tree_unflatten(treedef, output)
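
A short usage sketch for update showing both modes; the params pytree and loss function are hypothetical, and note that every call to update advances the optimizer's internal state.

import jax
import jax.numpy as jnp
import optax
import treex as tx

params = {"w": jnp.ones((3,))}
optimizer = tx.Optimizer(optax.adam(1e-3)).init(params)

grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2))(params)

# Default: apply the updates and return the new parameters.
new_params = optimizer.update(grads, params)

# apply_updates=False: return the raw updates and apply them manually.
updates = optimizer.update(grads, params, apply_updates=False)
new_params = optax.apply_updates(params, updates)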