# treex.preserve_state

Takes in a function transformation such as `jit` or `vmap`, plus any extra arguments for it, and returns a new transformation. Applying that new transformation to a function `f` gives the transformed function with the expected behaviour, but one that also preserves the state of the first argument of `f`.

For example, within a `Module`, if you try to `vmap` over a method with stateful operations like this:

```python
@jax.vmap
def __call__(self, x):
    self.n += 1
    return 2.0 * x
```


It will not work: the composed function `vmap(__call__)` is pure, so any change made to `self` inside it is not reflected outside.
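The problem can be reproduced without JAX. The sketch below uses a hypothetical `pure_transform` as a stand-in for a transformation like `jax.vmap`. JAX transformations operate on rebuilt versions of their pytree inputs, so in-place mutations do not reach the caller; deep-copying the arguments roughly approximates that. `Counter` is a toy class, not a `treex.Module`.

```python
import copy


def pure_transform(f):
    """Hypothetical stand-in for a JAX transformation such as jax.vmap.

    f only ever sees deep copies of its arguments, so in-place mutations
    made inside f never reach the caller (an approximation of how JAX
    transformations work on rebuilt pytree inputs).
    """

    def transformed(*args, **kwargs):
        args, kwargs = copy.deepcopy((args, kwargs))
        return f(*args, **kwargs)

    return transformed


class Counter:
    """Toy stateful object (not a treex.Module)."""

    def __init__(self):
        self.n = 0

    @pure_transform
    def __call__(self, x):
        self.n += 1  # increments the copy's counter, not the caller's
        return 2.0 * x


counter = Counter()
print(counter(3.0), counter.n)  # 6.0 0
```

The output is computed correctly, but the increment of `n` is lost.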

To solve this, wrap `vmap` with `preserve_state`:

```python
@preserve_state(jax.vmap)
def __call__(self, x):
    self.n += 1
    return 2.0 * x
```


This guarantees that changes to the state of `self` are propagated to the outside.
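A minimal sketch of how the lost update can be recovered, following the same structure as the `preserve_state` source code on this page. The bodies of treex's private `_return_first` and `_update_first` helpers do not appear here, so the versions below are hypothetical: one makes `f` also return its (possibly mutated) first argument, the other copies that argument's attributes back onto the caller's object. `pure_transform` is a made-up stand-in for a JAX transformation that deep-copies its inputs, and its `copy_inputs` flag only exists to show how extra arguments are forwarded to the transformation.

```python
import copy
import functools


def pure_transform(f, copy_inputs=True):
    """Hypothetical stand-in for jax.vmap/jax.jit: with copy_inputs=True,
    f only ever sees deep copies of its arguments."""

    def transformed(*args, **kwargs):
        if copy_inputs:
            args, kwargs = copy.deepcopy((args, kwargs))
        return f(*args, **kwargs)

    return transformed


def _return_first(f):
    # Hypothetical helper: also return the (possibly mutated) first argument.
    @functools.wraps(f)
    def wrapper(first, *args, **kwargs):
        output = f(first, *args, **kwargs)
        return first, output

    return wrapper


def _update_first(f):
    # Hypothetical helper: copy the returned first argument's attributes back
    # onto the caller's original object, then return only the real output.
    # treex's actual helper may update state differently.
    @functools.wraps(f)
    def wrapper(first, *args, **kwargs):
        new_first, output = f(first, *args, **kwargs)
        first.__dict__.update(new_first.__dict__)
        return output

    return wrapper


def preserve_state(transformation, *transformation_args, **transformation_kwargs):
    @functools.wraps(transformation)
    def new_transformation(f):
        f_original = f
        f = _return_first(f)
        # Extra arguments are forwarded to the wrapped transformation.
        f = _update_first(
            transformation(f, *transformation_args, **transformation_kwargs)
        )

        @functools.wraps(f_original)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)

        return wrapper

    return new_transformation


class Counter:
    """Toy stateful object (not a treex.Module)."""

    def __init__(self):
        self.n = 0

    @preserve_state(pure_transform, copy_inputs=True)
    def __call__(self, x):
        self.n += 1
        return 2.0 * x


counter = Counter()
print(counter(3.0), counter.n)  # 6.0 1
```

In this sketch the transformation itself stays pure: it only ever sees and returns copies, and the side effect on the caller's object happens in the outer `_update_first` wrapper, after the transformation has returned.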

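Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `transformation` | `~C` | The transformation to be applied to the function `f`. | required |
| `*transformation_args` | | Additional positional arguments to be passed to the transformation. | `()` |
| `**transformation_kwargs` | | Additional keyword arguments to be passed to the transformation. | `{}` |

The function `f` to be transformed is not a parameter of `preserve_state` itself; it is the single argument of the transformation that `preserve_state` returns.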

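Returns:

| Type | Description |
| --- | --- |
| `~C` | A new transformation; applying it to a function `f` returns the transformed, state-preserving function. |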

Source code in treex/module.py
````python
# Excerpt: `functools`, the type variable `C`, and the `_return_first` /
# `_update_first` helpers are imported or defined elsewhere in treex/module.py.
def preserve_state(
    transformation: C, *transformation_args, **transformation_kwargs
) -> C:
    """
    Takes in a function transformation such as `jit` or `vmap`, plus any extra
    arguments for it, and returns a new transformation. Applying that new
    transformation to a function `f` gives the transformed function with the
    expected behaviour, but one that also preserves the state of the first
    argument of `f`.

    For example, within a `Module`, if you try to `vmap` over a method with
    stateful operations like this:

    ```python
    @jax.vmap
    def __call__(self, x):
        self.n += 1
        return 2.0 * x
    ```

    It will not work: the composed function `vmap(__call__)` is pure, so any
    change made to `self` inside it is not reflected outside.

    To solve this, wrap `vmap` with `preserve_state`:

    ```python
    @preserve_state(jax.vmap)
    def __call__(self, x):
        self.n += 1
        return 2.0 * x
    ```

    This guarantees that changes to the state of `self` are propagated to the outside.

    Arguments:
        transformation: The transformation to be applied to the function `f`.
        *transformation_args: Additional positional arguments to be passed to the transformation.
        **transformation_kwargs: Additional keyword arguments to be passed to the transformation.

    Returns:
        A new transformation; applying it to a function `f` returns the
        transformed, state-preserving function.
    """

    @functools.wraps(transformation)
    def new_transformation(f):
        f_original = f

        # Apply the transformation to the `_return_first`-wrapped f,
        # forwarding the extra arguments, then wrap the result with
        # `_update_first`.
        f = _return_first(f)
        f = _update_first(
            transformation(f, *transformation_args, **transformation_kwargs)
        )

        # Expose the original function's name and docstring on the result.
        @functools.wraps(f_original)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)

        return wrapper

    return new_transformation
````