treex.Optimizer
Wraps an optax optimizer and turns it into a Pytree while maintaining a similar API. The main difference with optax is that `tx.Optimizer` contains its own state, thus there is no `opt_state`.
Examples:
```python
def main():
    ...
    optimizer = tx.Optimizer(optax.adam(1e-3))
    optimizer = optimizer.init(params)
    ...

@jax.jit
def train_step(model, x, y, optimizer):
    ...
    params = optimizer.update(grads, params)
    ...
    return model, loss, optimizer
```
Notice that since the optimizer is a `Pytree` it can naturally pass through `jit`.
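Because the optimizer is just another Pytree, it can be passed into and returned from a jitted function like any other argument. Below is a minimal, self-contained sketch of a full training step; the plain-dict params, the linear `loss_fn`, and the toy data are illustrative assumptions, not part of the treex API.

```python
import jax
import jax.numpy as jnp
import optax
import treex as tx

# Placeholder params: any pytree of arrays works, the optimizer only
# operates on the leaves.
params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}

optimizer = tx.Optimizer(optax.sgd(1e-2))
optimizer = optimizer.init(params)

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, optimizer):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # `update` advances the optimizer's internal state and returns the
    # updated params; return the optimizer so the new state is kept.
    params = optimizer.update(grads, params)
    return params, loss, optimizer

x = jnp.ones((8, 3))
y = jnp.ones((8, 1))
params, loss, optimizer = train_step(params, x, y, optimizer)
```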
Differences with Optax

* `init` returns a new optimizer instance, there is no `opt_state`.
* `update` doesn't take `opt_state` as an argument, instead it performs updates to its internal state in place.
* `update` applies the updates to the params and returns them by default, use `apply_updates=False` to get the param updates instead (see the sketch below).
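A short sketch of these differences in one place; the `params` and `grads` dicts below are made-up placeholders with matching structure:

```python
import jax.numpy as jnp
import optax
import treex as tx

params = {"w": jnp.ones((2,))}        # placeholder pytree of arrays
grads = {"w": jnp.full((2,), 0.5)}    # same structure as params

# `init` returns a new Optimizer instance; there is no separate `opt_state` to track.
optimizer = tx.Optimizer(optax.adam(1e-3))
optimizer = optimizer.init(params)

# `update` mutates the optimizer's internal state in place and, by default,
# returns the already-updated params.
params = optimizer.update(grads, params)

# With `apply_updates=False` the raw param updates are returned instead
# (note that the internal state still advances on every call).
updates = optimizer.update(grads, params, apply_updates=False)
params = optax.apply_updates(params, updates)
```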
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optimizer | GradientTransformation | An optax optimizer. | required |
Source code in treex/optimizer.py
````python
class Optimizer(Treex):
    """Wraps an optax optimizer and turns it into a Pytree while maintaining a similar API.

    The main difference with optax is that tx.Optimizer contains its own state, thus, there is
    no `opt_state`.

    Example:
    ```python
    def main():
        ...
        optimizer = tx.Optimizer(optax.adam(1e-3))
        optimizer = optimizer.init(params)
        ...

    @jax.jit
    def train_step(model, x, y, optimizer):
        ...
        params = optimizer.update(grads, params)
        ...
        return model, loss, optimizer
    ```

    Notice that since the optimizer is a `Pytree` it can naturally pass through `jit`.

    ### Differences with Optax

    * `init` returns a new optimizer instance, there is no `opt_state`.
    * `update` doesn't take `opt_state` as an argument, instead it performs updates
      to its internal state inplace.
    * `update` applies the updates to the params and returns them by default, use
      `apply_updates=False` to get the param updates instead.

    Arguments:
        optimizer: An optax optimizer.
    """

    optimizer: optax.GradientTransformation
    opt_state: tp.Optional[tp.Any] = types.OptState.node(None, init=False)
    _n_params: tp.Optional[int] = to.static(None, init=False)

    # use to.field to copy class vars to instance
    _initialized: bool = to.static(False)

    def __init__(self, optimizer: optax.GradientTransformation) -> None:
        self.optimizer = optimizer

    @property
    def initialized(self) -> bool:
        return self._initialized

    def init(self: O, params: tp.Any) -> O:
        """
        Initialize the optimizer from an initial set of parameters.

        Arguments:
            params: An initial set of parameters.

        Returns:
            A new optimizer instance.
        """
        module = to.copy(self)
        params = jax.tree_leaves(params)
        module.opt_state = module.optimizer.init(params)
        module._n_params = len(params)
        module._initialized = True
        return module

    # NOTE: params are flattened because:
    # - The flat list is not a Module, thus all of its internal parameters in the list are marked as
    #   OptState by a single annotation (no need to rewrite the module's annotations)
    # - It ignores the static part of Modules which, if changed, makes Optax yield an error.

    def update(
        self, grads: A, params: tp.Optional[A] = None, apply_updates: bool = True
    ) -> A:
        """
        Applies the parameter updates and updates the optimizer's internal state inplace.

        Arguments:
            grads: the gradients to perform the update.
            params: the parameters to update. If `None` then `apply_updates` has to be `False`.
            apply_updates: if `False` then the updates are returned instead of being applied.

        Returns:
            The updated parameters. If `apply_updates` is `False` then the updates are returned instead.
        """
        if not self.initialized:
            raise RuntimeError("Optimizer is not initialized")

        assert self.opt_state is not None

        if apply_updates and params is None:
            raise ValueError("params must be provided if updates are being applied")

        opt_grads, treedef = jax.tree_flatten(grads)
        opt_params = jax.tree_leaves(params)

        if len(opt_params) != self._n_params:
            raise ValueError(
                f"params must have length {self._n_params}, got {len(opt_params)}"
            )
        if len(opt_grads) != self._n_params:
            raise ValueError(
                f"grads must have length {self._n_params}, got {len(opt_grads)}"
            )

        param_updates: A
        param_updates, self.opt_state = self.optimizer.update(
            opt_grads,
            self.opt_state,
            opt_params,
        )

        output: A
        if apply_updates:
            output = optax.apply_updates(opt_params, param_updates)
        else:
            output = param_updates

        return jax.tree_unflatten(treedef, output)
````
init(self, params)
Initialize the optimizer from an initial set of parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | Any | An initial set of parameters. | required |
Returns:
Type | Description |
---|---|
~O | A new optimizer instance. |
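A quick usage sketch (not taken from the library docs): `init` accepts any pytree of arrays and returns a fresh, initialized optimizer, leaving the original instance untouched. The dict params below are placeholders.

```python
import jax.numpy as jnp
import optax
import treex as tx

params = {"w": jnp.zeros((4, 2)), "b": jnp.zeros((2,))}   # placeholder params

optimizer = tx.Optimizer(optax.adamw(1e-3))
initialized = optimizer.init(params)   # new instance holding the optax state

assert not optimizer.initialized   # the original instance is unchanged
assert initialized.initialized
```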
Source code in treex/optimizer.py
```python
def init(self: O, params: tp.Any) -> O:
    """
    Initialize the optimizer from an initial set of parameters.

    Arguments:
        params: An initial set of parameters.

    Returns:
        A new optimizer instance.
    """
    module = to.copy(self)
    params = jax.tree_leaves(params)
    module.opt_state = module.optimizer.init(params)
    module._n_params = len(params)
    module._initialized = True
    return module
```
update(self, grads, params=None, apply_updates=True)
Applies the parameter updates and updates the optimizer's internal state in place.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
grads | ~A | the gradients to perform the update. | required |
params | Optional[~A] | the parameters to update. If `None` then `apply_updates` has to be `False`. | None |
apply_updates | bool | if `False` then the updates are returned instead of being applied. | True |
Returns:
Type | Description |
---|---|
~A | The updated parameters. If `apply_updates` is `False` then the updates are returned instead. |
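As a hedged sketch of the `apply_updates=False` path inside a jitted step, assuming placeholder params, loss, and data (none of these names come from treex itself):

```python
import jax
import jax.numpy as jnp
import optax
import treex as tx

params = {"w": jnp.zeros((3,))}                           # placeholder params
optimizer = tx.Optimizer(optax.adam(1e-3)).init(params)

def loss_fn(params, x, y):                                # placeholder loss
    return jnp.mean((x @ params["w"] - y) ** 2)

@jax.jit
def manual_train_step(params, x, y, optimizer):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Ask for the raw updates instead of updated params, then apply them
    # manually with optax.
    updates = optimizer.update(grads, params, apply_updates=False)
    params = optax.apply_updates(params, updates)
    return params, loss, optimizer

x, y = jnp.ones((8, 3)), jnp.ones((8,))
params, loss, optimizer = manual_train_step(params, x, y, optimizer)
```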
Source code in treex/optimizer.py
```python
def update(
    self, grads: A, params: tp.Optional[A] = None, apply_updates: bool = True
) -> A:
    """
    Applies the parameter updates and updates the optimizer's internal state inplace.

    Arguments:
        grads: the gradients to perform the update.
        params: the parameters to update. If `None` then `apply_updates` has to be `False`.
        apply_updates: if `False` then the updates are returned instead of being applied.

    Returns:
        The updated parameters. If `apply_updates` is `False` then the updates are returned instead.
    """
    if not self.initialized:
        raise RuntimeError("Optimizer is not initialized")

    assert self.opt_state is not None

    if apply_updates and params is None:
        raise ValueError("params must be provided if updates are being applied")

    opt_grads, treedef = jax.tree_flatten(grads)
    opt_params = jax.tree_leaves(params)

    if len(opt_params) != self._n_params:
        raise ValueError(
            f"params must have length {self._n_params}, got {len(opt_params)}"
        )
    if len(opt_grads) != self._n_params:
        raise ValueError(
            f"grads must have length {self._n_params}, got {len(opt_grads)}"
        )

    param_updates: A
    param_updates, self.opt_state = self.optimizer.update(
        opt_grads,
        self.opt_state,
        opt_params,
    )

    output: A
    if apply_updates:
        output = optax.apply_updates(opt_params, param_updates)
    else:
        output = param_updates

    return jax.tree_unflatten(treedef, output)
```