treeo.copy

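Returns a copy of the tree with every pytree node rebuilt, almost equivalent to:

jax.tree_map(lambda x: x, obj)
but copy also traverses static fields, so Tree nodes stored in static fields are copied as well. Leaf values themselves are not duplicated: the copy refers to the same leaf objects as obj. With shallow=True, nested Tree instances are treated as leaves, so only the top-level Tree is rebuilt and the nested ones are shared with obj.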

Source code in treeo/tree.py
def copy(obj: A, shallow: bool = False) -> A:
    """
    Returns a deep copy of the tree, almost equivalent to:
    ```python
    jax.tree_map(lambda x: x, obj)
    ```
    but will try to copy static nodes as well.
    """
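    if shallow:
        # Treat every nested Tree (but not obj itself) as a leaf, so only the
        # top-level Tree is rebuilt and nested Trees are shared with obj.
        isleaf = lambda x: isinstance(x, Tree) and x is not obj
    else:
        # Default: recurse into all nested pytrees and rebuild them.
        isleaf = None

    # For the duration of the call, flatten every field (static ones included)
    # so that static fields are traversed and any Trees they hold are copied.
    with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
        return jax.tree_map(lambda x: x, obj, is_leaf=isleaf)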
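The shallow branch relies on the is_leaf argument of tree_map: any subtree for which the predicate returns True is handed to the mapped function whole instead of being flattened further. A minimal sketch with plain JAX, assuming a hypothetical Node pytree class in place of treeo.Tree and a hypothetical copy_sketch helper that mirrors the isleaf logic above:

```python
import jax


@jax.tree_util.register_pytree_node_class
class Node:
    """Hypothetical minimal pytree standing in for treeo.Tree."""

    def __init__(self, value, child=None):
        self.value = value
        self.child = child

    def tree_flatten(self):
        # Both fields are children; None is an empty pytree in JAX.
        return (self.value, self.child), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def copy_sketch(obj, shallow=False):
    """Hypothetical helper mirroring the isleaf logic of treeo.copy."""
    if shallow:
        # Nested Nodes (but not obj itself) become leaves and are reused as-is.
        isleaf = lambda x: isinstance(x, Node) and x is not obj
    else:
        # No predicate: tree_map recurses into every nested Node and rebuilds it.
        isleaf = None
    return jax.tree_util.tree_map(lambda x: x, obj, is_leaf=isleaf)


inner = Node(1.0)
outer = Node(2.0, inner)
deep = copy_sketch(outer)                        # deep.child is a new Node
shallow_copy = copy_sketch(outer, shallow=True)  # shallow_copy.child is inner
```

In shallow mode, mutating an attribute of shallow_copy.child also changes outer.child, since both names refer to the same object.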