Optimizer
Optax is an amazing library; however, its optimizers are not Pytrees, which means that their state and computation live separately and you cannot `jit` them together. To solve this, Treex provides an `Optimizer` class that inherits from `treeo.Tree` and can wrap any Optax optimizer. `Optimizer` follows an API similar to `optax.GradientTransformation`, except that:
- There is no separate `opt_state`: the `Optimizer` contains its own state.
- `update` by default applies the updates to the parameters; if you want the raw gradient updates instead, you can set `apply_updates=False` (see the sketch after this list).
- `update` also updates the internal state of the `Optimizer` in-place.
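As a rough sketch of how these pieces fit together (the parameter array, gradients, and `optax.sgd` choice here are illustrative, not part of the original example):

```python
import jax.numpy as jnp
import optax
import treex as tx

params = jnp.zeros((3,))  # illustrative parameter pytree
grads = jnp.ones((3,))    # pretend these came from jax.grad

# init returns the Optimizer with its internal state created for `params`
optimizer = tx.Optimizer(optax.sgd(0.1)).init(params)

# default behavior: returns the updated parameters
params = optimizer.update(grads, params)

# apply_updates=False: returns the raw gradient updates instead,
# which you can apply yourself via optax.apply_updates
updates = optimizer.update(grads, params, apply_updates=False)
params = optax.apply_updates(params, updates)
```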
While in Optax you would define something like this:
```python
def main():
    ...
    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(params)
    ...

@partial(jax.jit, static_argnums=(4,))
def train_step(model, x, y, opt_state, optimizer):  # optimizer has to be static
    ...
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    ...
    return model, loss, opt_state
```
With `tx.Optimizer` this can be simplified to:
```python
def main():
    ...
    optimizer = tx.Optimizer(optax.adam(1e-3)).init(params)
    ...

@jax.jit  # no static_argnums needed
def train_step(model, x, y, optimizer):
    ...
    params = optimizer.update(grads, params)
    ...
    return model, loss, optimizer
```
Notice that since `tx.Optimizer` is a Pytree, it was passed through `jit` naturally without the need to specify `static_argnums`.
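For a fuller picture, here is a self-contained sketch of the pattern above; the linear model, loss function, and data are hypothetical stand-ins for the parts elided with `...` in the example:

```python
import jax
import jax.numpy as jnp
import optax
import treex as tx

# hypothetical model: plain linear-regression parameters as a pytree
params = {"w": jnp.zeros((1, 1)), "b": jnp.zeros((1,))}
optimizer = tx.Optimizer(optax.adam(1e-3)).init(params)


def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # the Optimizer is a Pytree, so it traces like any other argument
def train_step(params, x, y, optimizer):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # update applies the gradients and refreshes the optimizer state in-place;
    # the optimizer is returned so its new state survives the jit boundary
    params = optimizer.update(grads, params)
    return params, loss, optimizer


x = jnp.ones((8, 1))
y = 3.0 * x + 1.0
for _ in range(100):
    params, loss, optimizer = train_step(params, x, y, optimizer)
```

Note that `train_step` returns the `Optimizer` and the caller rebinds it, exactly as in the documented example: the in-place mutation happens on the traced copy inside `jit`, so returning it is what carries the updated state back out.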