treex.preserve_state

Takes a function transformation such as jit or vmap and returns a decorator. Applying that decorator to a function f yields the transformed function with the expected behaviour, which additionally preserves the state of the first argument of f.

For example, suppose you try to vmap over a method with stateful operations within a Module, like this:

@jax.vmap
def __call__(self, x):
    self.n += 1
    return 2.0 * x

This will not work: the composed function vmap(__call__) is a pure function, so any change to self is not reflected outside.

To solve this, wrap vmap with preserve_state like this:

@preserve_state(jax.vmap)
def __call__(self, x):
    self.n += 1
    return 2.0 * x

This will guarantee that the state of self is propagated to the outside.
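
The mechanism can be pictured with a small pure-Python sketch. It is not the treex implementation: `pure_transform` is a hypothetical stand-in for `jit`/`vmap` that runs the function on deep copies of its arguments, and `return_first_sketch` / `update_first_sketch` are assumed approximations of the `_return_first` / `_update_first` helpers referenced in the source listing further down, whose bodies are not shown on this page.

```python
import copy
import functools


def pure_transform(f):
    """Hypothetical stand-in for jit/vmap: calls f on deep copies of its
    arguments, so in-place mutations never reach the caller's objects."""
    @functools.wraps(f)
    def transformed(*args, **kwargs):
        return f(*copy.deepcopy(args), **copy.deepcopy(kwargs))
    return transformed


def return_first_sketch(f):
    # Assumed behaviour: return the (possibly mutated) first argument
    # together with the original output.
    @functools.wraps(f)
    def wrapper(first, *args, **kwargs):
        out = f(first, *args, **kwargs)
        return first, out
    return wrapper


def update_first_sketch(f):
    # Assumed behaviour: copy the returned state back onto the caller's
    # object and hand back only the original output.
    @functools.wraps(f)
    def wrapper(first, *args, **kwargs):
        new_first, out = f(first, *args, **kwargs)
        first.__dict__.update(new_first.__dict__)
        return out
    return wrapper


def preserve_state_sketch(transformation, *t_args, **t_kwargs):
    # Same composition order as the source listing: wrap f, transform it,
    # then write the returned state back to the first argument.
    def new_transformation(f):
        return update_first_sketch(
            transformation(return_first_sketch(f), *t_args, **t_kwargs)
        )
    return new_transformation


class Counter:
    def __init__(self):
        self.n = 0

    @pure_transform
    def lossy(self, x):
        self.n += 1  # mutates a copy; the caller never sees it
        return 2.0 * x

    @preserve_state_sketch(pure_transform)
    def kept(self, x):
        self.n += 1  # propagated back by update_first_sketch
        return 2.0 * x


c = Counter()
lossy_out = c.lossy(3.0)
n_after_lossy = c.n
kept_out = c.kept(3.0)
n_after_kept = c.n
```

In this sketch the counter is unchanged after `lossy` and incremented after `kept`, while both methods return the same output. The real helpers operate on treex modules and JAX pytrees, so their details will differ.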

Parameters:

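| Name | Type | Description | Default |
| --- | --- | --- | --- |
| transformation | ~C | The transformation to be applied to the function f. | required |
| f | | The function to be transformed. It is passed to the returned decorator, not to preserve_state itself. | required |
| *transformation_args | | Additional arguments to be passed to the transformation. | required |
| **transformation_kwargs | | Additional keyword arguments to be passed to the transformation. | required |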
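
The source listing further down shows the extra arguments being forwarded as `transformation(f, *transformation_args, **transformation_kwargs)`, so a call such as `preserve_state(jax.vmap, in_axes=...)` would presumably hand `in_axes` on to `jax.vmap`. The forwarding pattern alone can be sketched in plain Python; `with_transformation` and `scale_output` below are hypothetical names used only for illustration, and no state handling is included.

```python
import functools


def with_transformation(transformation, *t_args, **t_kwargs):
    """Decorator factory: captures the extra arguments now and forwards
    them to the transformation once the function f is known."""
    def decorator(f):
        return transformation(f, *t_args, **t_kwargs)
    return decorator


def scale_output(f, factor=1.0):
    """Hypothetical transformation that multiplies f's output by factor."""
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        return factor * f(*args, **kwargs)
    return wrapped


@with_transformation(scale_output, factor=3.0)
def double(x):
    return 2.0 * x


result = double(1.0)
```

Here `factor=3.0` is given to the factory but consumed by `scale_output`, which is the same two-step shape as `preserve_state(transformation, ...)(f)`.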

Returns:

| Type | Description |
| --- | --- |
| ~C | A new transformation (decorator) that, applied to f, returns the transformed function. |

Source code in treex/module.py
def preserve_state(
    transformation: C, *transformation_args, **transformation_kwargs
) -> C:
    """
    Takes a function transformation such as `jit` or `vmap` and returns a decorator.
    Applying that decorator to a function `f` yields the transformed function with the
    expected behaviour, which additionally preserves the state of the first argument of `f`.

    For example, suppose you try to `vmap` over a method with stateful operations within a `Module`, like this:

    ```python
    @jax.vmap
    def __call__(self, x):
        self.n += 1
        return 2.0 * x
    ```

    This will not work: the composed function `vmap(__call__)` is a pure function, so any change
    to `self` is not reflected outside.

    To solve this, wrap `vmap` with `preserve_state` like this:

    ```python
    @preserve_state(jax.vmap)
    def __call__(self, x):
        self.n += 1
        return 2.0 * x
    ```

    This will guarantee that the state of `self` is propagated to the outside.

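    Arguments:
        transformation: The transformation to be applied to the function `f`.
        f: The function to be transformed. It is passed to the returned decorator,
            not to `preserve_state` itself.
        *transformation_args: Additional arguments to be passed to the transformation.
        **transformation_kwargs: Additional keyword arguments to be passed to the transformation.

    Returns:
        A new transformation (decorator) that, applied to `f`, returns the transformed function.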
    """

    @functools.wraps(transformation)
    def new_transformation(f):
        f_original = f

        f = _return_first(f)
        f = _update_first(
            transformation(f, *transformation_args, **transformation_kwargs)
        )

        @functools.wraps(f_original)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)

        return wrapper

    return new_transformation