treex.preserve_state
Takes a function transformation such as `jit` or `vmap` and returns a version of it that, when applied to a function `f`, produces the transformed function with the expected behaviour while additionally preserving the state of the first argument of `f`.
For example, within a `Module`, suppose you try to `vmap` over a method with stateful operations like this:

    @jax.vmap
    def __call__(self, x):
        self.n += 1
        return 2.0 * x

This will not work: the composed function `vmap(__call__)` is a pure function, so any change to `self` will not be reflected outside.
To solve this, wrap `vmap` with `preserve_state` like this:

    @preserve_state(jax.vmap)
    def __call__(self, x):
        self.n += 1
        return 2.0 * x

This guarantees that the state of `self` is propagated to the outside.
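
The problem and the fix described above can be imitated without JAX. The following is a minimal sketch under stated assumptions, not how treex or JAX are implemented: `Counter` and `pure_transform` are hypothetical stand-ins, and deep-copying the inputs is only an analogy for the way a pure transformation never lets mutations reach the caller's objects.

```python
import copy


class Counter:
    """Hypothetical stand-in for a stateful Module (not a treex class)."""

    def __init__(self):
        self.n = 0


def pure_transform(f):
    """Toy stand-in for jit/vmap (assumption): f only ever sees copies of
    its inputs, so mutations made inside f never reach the caller."""

    def transformed(*args, **kwargs):
        args, kwargs = copy.deepcopy((args, kwargs))
        return f(*args, **kwargs)

    return transformed


def call(counter, x):
    counter.n += 1  # stateful operation
    return 2.0 * x


# Without state preservation the increment happens on a copy and is lost.
c = Counter()
y = pure_transform(call)(c, 3.0)
lost_n = c.n


# Manual fix: return the mutated first argument as an extra output,
# then copy its attributes back onto the caller's object.
def call_returning_state(counter, x):
    out = call(counter, x)
    return counter, out


c2 = Counter()
new_state, y2 = pure_transform(call_returning_state)(c2, 3.0)
c2.__dict__.update(new_state.__dict__)
kept_n = c2.n
```

In this sketch `lost_n` stays at its initial value while `kept_n` reflects the increment. `preserve_state` packages the same two steps (return the first argument, then write its state back) so that they do not have to be written by hand.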
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`transformation` | ~C | The transformation to be applied to the function `f`. | required |
`f` | | The function to be transformed. It is passed to the decorator that `preserve_state` returns, not to `preserve_state` itself. | required |
`*transformation_args` | | Additional positional arguments to be passed to the transformation. | `()` |
`**transformation_kwargs` | | Additional keyword arguments to be passed to the transformation. | `{}` |
Returns:
Type | Description |
---|---|
~C | A state-preserving version of `transformation`: a decorator that takes `f` and returns the transformed function. |
Source code in treex/module.py

    def preserve_state(
        transformation: C, *transformation_args, **transformation_kwargs
    ) -> C:
        """
        Takes in a function transformation such as `jit` or `vmap` and returns a
        version of it that, when applied to a function `f`, produces the transformed
        function with the expected behaviour while additionally preserving the state
        of the first argument of `f`.

        For example, within a `Module`, suppose you try to `vmap` over a method with
        stateful operations like this:

        ```python
        @jax.vmap
        def __call__(self, x):
            self.n += 1
            return 2.0 * x
        ```

        This will not work: the composed function `vmap(__call__)` is a pure function,
        so any change to `self` will not be reflected outside.

        To solve this, wrap `vmap` with `preserve_state` like this:

        ```python
        @preserve_state(jax.vmap)
        def __call__(self, x):
            self.n += 1
            return 2.0 * x
        ```

        This guarantees that the state of `self` is propagated to the outside.

        Arguments:
            transformation: The transformation to be applied to the function `f`.
            *transformation_args: Additional arguments to be passed to the transformation.
            **transformation_kwargs: Additional keyword arguments to be passed to the transformation.

        Returns:
            A decorator that takes the function `f` and returns the transformed function.
        """

        @functools.wraps(transformation)
        def new_transformation(f):
            f_original = f

            # Make f also return its first argument so the state survives the transformation.
            f = _return_first(f)

            # Apply the transformation, then write the returned state back to the caller's object.
            f = _update_first(
                transformation(f, *transformation_args, **transformation_kwargs)
            )

            @functools.wraps(f_original)
            def wrapper(*args, **kwargs):
                return f(*args, **kwargs)

            return wrapper

        return new_transformation
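
The listing relies on two private helpers, `_return_first` and `_update_first`, whose bodies are not shown here. The sketch below is one plausible reading of what they do, written with the standard library only. Every name in it (`return_first`, `update_first`, `preserve_state_sketch`, `pure_transform`, `Counter`) is hypothetical, and copying `__dict__` is an assumption standing in for whatever state update treex actually performs.

```python
import copy
import functools


def return_first(f):
    """Make f also return its (possibly mutated) first argument."""

    @functools.wraps(f)
    def wrapper(first, *args, **kwargs):
        out = f(first, *args, **kwargs)
        return first, out

    return wrapper


def update_first(f):
    """Write the returned state back onto the caller's first argument."""

    @functools.wraps(f)
    def wrapper(first, *args, **kwargs):
        new_first, out = f(first, *args, **kwargs)
        first.__dict__.update(new_first.__dict__)  # assumption: plain attribute copy
        return out

    return wrapper


def preserve_state_sketch(transformation, *t_args, **t_kwargs):
    """Same composition as the listing, using the hypothetical helpers above."""

    def new_transformation(f):
        inner = update_first(transformation(return_first(f), *t_args, **t_kwargs))

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return inner(*args, **kwargs)

        return wrapper

    return new_transformation


def pure_transform(f, log=None):
    """Toy stand-in for jit/vmap: calls f on copies of its inputs.
    The optional `log` list only exists to show keyword forwarding."""

    def transformed(*args, **kwargs):
        if log is not None:
            log.append(f.__name__)
        args, kwargs = copy.deepcopy((args, kwargs))
        return f(*args, **kwargs)

    return transformed


calls = []


class Counter:
    def __init__(self):
        self.n = 0

    # Extra keyword arguments are forwarded to the transformation.
    @preserve_state_sketch(pure_transform, log=calls)
    def __call__(self, x):
        self.n += 1
        return 2.0 * x


counter = Counter()
y = counter(3.0)
```

After the call, `counter.n` has been incremented even though the method body ran on a copy, and `calls` records that the keyword argument reached the toy transformation. Note that extra positional arguments are passed after `f`, as in `transformation(f, *t_args, **t_kwargs)`.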