treeo.copy
Returns a deep copy of the tree, almost equivalent to:
jax.tree_map(lambda x: x, self)
but will try to copy static nodes as well.
Source code in treeo/tree.py
def copy(obj: A, shallow: bool = False) -> A:
    """
    Returns a copy of the tree, almost equivalent to:

    ```python
    jax.tree_util.tree_map(lambda x: x, obj)
    ```

    but will try to copy static nodes as well.

    Arguments:
        obj: The tree to copy.
        shallow: If `False` (default) the whole tree is traversed and rebuilt.
            If `True`, every `Tree` instance other than `obj` itself is treated
            as a leaf, so nested `Tree` nodes are handed to the identity
            function as-is instead of being traversed.

    Returns:
        The result of mapping the identity function over `obj`.
    """
    if shallow:
        # Stop the traversal at any nested Tree; `obj` itself must still be
        # flattened, otherwise nothing would be rebuilt at all.
        def is_leaf(x):
            return isinstance(x, Tree) and x is not obj

    else:
        is_leaf = None

    # Flatten with `all_fields` mode for the duration of the mapping; per the
    # contract above this is what lets static nodes be copied as well.
    with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias was deprecated and removed in newer JAX releases.
        return jax.tree_util.tree_map(lambda x: x, obj, is_leaf=is_leaf)