treex.copy

Returns a deep copy of the tree, almost equivalent to:

jax.tree_map(lambda x: x, obj)

except that it also attempts to copy static nodes.
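
To make the identity-map idiom concrete, here is a minimal sketch on a plain JAX pytree. The dictionary `tree` is hypothetical stand-in data, not a treex object, and the sketch uses `jax.tree_util.tree_map` (the same function under its `tree_util` path) so that it does not depend on the top-level `jax.tree_map` alias. It only illustrates the generic JAX behavior; the handling of static nodes described above is specific to this library and is not reproduced here.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a tree object: a nested dict/list pytree.
tree = {"w": jnp.ones((2, 2)), "layers": [jnp.zeros(3), jnp.arange(3)]}

# Mapping the identity function flattens the tree and rebuilds it,
# so every container (dict, list) in the result is a new object.
copied = jax.tree_util.tree_map(lambda x: x, tree)

assert copied is not tree
assert copied["layers"] is not tree["layers"]

# The leaves are passed through the identity function unchanged,
# so the arrays themselves are shared between the two trees.
assert copied["w"] is tree["w"]

# Mutating a container in the copy leaves the original untouched.
copied["layers"].append(jnp.ones(1))
assert len(tree["layers"]) == 2
```

In this sketch the containers are rebuilt while the leaf arrays are reused, which is usually harmless because JAX arrays are immutable.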

Source code in treeo/tree.py
def copy(obj: A) -> A:
    """
    Returns a deep copy of the tree, almost equivalent to:
    ```python
    jax.tree_map(lambda x: x, self)
    ```
    but will try to copy static nodes as well.
    """
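    # Flatten with every field treated as a node, so that static fields are
    # traversed and rebuilt by the identity map as well.
    with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
        return jax.tree_map(lambda x: x, obj)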