Optimizer
Optax is an amazing library, however its optimizers are not pytrees: their state and computation are kept separate, so you cannot pass them directly through `jit`. To solve this, Treex provides an `Optimizer` class which inherits from `treeo.Tree` and can wrap any Optax optimizer. `Optimizer` follows a similar API to `optax.GradientTransformation`, except that:
- There is no separate `opt_state`: the `Optimizer` contains its own state.
- `update` by default applies the updates to the parameters; if you want the gradient `updates` instead, you can set `apply_updates=False` (see the sketch after this list).
- `update` also updates the internal state of the `Optimizer` in-place.
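To make the two modes of `update` concrete, here is a minimal sketch. The parameter and gradient pytrees are made up for illustration; only `tx.Optimizer`, `update`, and `apply_updates=False` come from the API described above.

```python
import jax.numpy as jnp
import optax
import treex as tx

# Toy parameter and gradient pytrees, purely for illustration.
params = {"w": jnp.ones((3,))}
grads = {"w": jnp.full((3,), 0.1)}

optimizer = tx.Optimizer(optax.sgd(1e-2)).init(params)

# Default behavior: `update` applies the updates and returns the new
# parameters, updating the optimizer's internal state in-place.
params = optimizer.update(grads, params)

# With `apply_updates=False` you get the raw updates instead and can
# apply them yourself, e.g. with optax.apply_updates.
updates = optimizer.update(grads, params, apply_updates=False)
params = optax.apply_updates(params, updates)
```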
While in Optax you would define something like this:
```python
def main():
    ...
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(params)
    ...

@partial(jax.jit, static_argnums=(4,))
def train_step(model, x, y, opt_state, optimizer):  # optimizer has to be static
    ...
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    ...
    return model, loss, opt_state
```
With `tx.Optimizer` this can be simplified to:
```python
def main():
    ...
    optimizer = tx.Optimizer(optax.adam(1e-3)).init(params)
    ...

@jax.jit  # no static_argnums needed
def train_step(model, x, y, optimizer):
    ...
    params = optimizer.update(grads, params)
    ...
    return model, loss, optimizer
```
Notice that since `tx.Optimizer` is a pytree, it is passed through `jit` naturally without the need to specify `static_argnums`.
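For reference, here is a more self-contained sketch of the same pattern. The dict-of-arrays "model" and the MSE loss are made up for illustration and are not part of the Treex API; only `tx.Optimizer` and its `update` method are.

```python
import jax
import jax.numpy as jnp
import optax
import treex as tx

# Illustrative "model": just a dict of parameters with a toy MSE loss.
params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}
optimizer = tx.Optimizer(optax.adam(1e-3)).init(params)

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit  # optimizer is a pytree, so it flows through jit like any other argument
def train_step(params, x, y, optimizer):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = optimizer.update(grads, params)
    return params, loss, optimizer

x = jnp.ones((8, 1))
y = jnp.ones((8,))
params, loss, optimizer = train_step(params, x, y, optimizer)
```

Note that the updated `optimizer` is returned from `train_step`: inside `jit` the in-place update happens on a traced copy, so you keep the new state by returning it, just as in the simplified example above.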