treex.copy
Returns a deep copy of the tree, almost equivalent to `jax.tree_util.tree_map(lambda x: x, obj)`, except that it will also try to copy static nodes.
Source code in treeo/tree.py
def copy(obj: A) -> A:
    """
    Returns a copy of the tree ``obj``, almost equivalent to:

    ```python
    jax.tree_util.tree_map(lambda x: x, obj)
    ```

    but will try to copy static nodes as well.

    Every leaf goes through the identity function, so leaf objects themselves
    are returned as-is rather than duplicated. What gets rebuilt is the tree
    structure produced by flattening and unflattening ``obj``.

    Arguments:
        obj: The tree to copy.

    Returns:
        A new tree with the same structure and leaves as ``obj``.
    """
    # Flatten *all* fields for the duration of this call, not only the
    # dynamic ones, so that static nodes are also traversed and rebuilt.
    with _CONTEXT.update(flatten_mode=FlattenMode.all_fields):
        # `jax.tree_map` is deprecated and removed in recent JAX releases;
        # `jax.tree_util.tree_map` is the stable, equivalent API.
        return jax.tree_util.tree_map(lambda x: x, obj)