treex.flatten_mode
A context manager that defines how Trees are flattened. Options are:
'normal': Fields are treated as nodes or as static exactly as declared in the class definition (default behavior).
'all_fields': All fields are treated as nodes during flattening.
'no_fields': All fields are treated as static; Trees produce no leaves.
None: The context is not changed; the current flatten mode is preserved.
Examples:
from dataclasses import dataclass
import jax
import treeo as to
from treeo import Tree, flatten_mode

@dataclass
class MyTree(Tree):
    x: int  # static
    y: int = to.node()

tree = MyTree(x=1, y=3)
jax.tree_map(lambda x: x * 2, tree)  # MyTree(x=1, y=6)

with flatten_mode('all_fields'):
    jax.tree_map(lambda x: x + 1, tree)  # MyTree(x=2, y=4)
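The example above only exercises 'normal' and 'all_fields'. The pure-Python sketch below, which uses neither treeo nor JAX, illustrates which values each mode would expose as leaves for the same x/y fields. FlattenMode, toy_leaves, and the dict-of-fields representation are hypothetical stand-ins chosen for illustration, not treeo's flattening code.

```python
from enum import Enum
from typing import Any, Dict, List, Set


class FlattenMode(Enum):
    # Hypothetical mirror of the three mode names; not treeo's own class.
    normal = "normal"
    all_fields = "all_fields"
    no_fields = "no_fields"


def toy_leaves(fields: Dict[str, Any], node_fields: Set[str], mode: FlattenMode) -> List[Any]:
    """Return the values a toy tree would expose as leaves under `mode`."""
    if mode is FlattenMode.all_fields:
        # Every field is treated as a node.
        return list(fields.values())
    if mode is FlattenMode.no_fields:
        # Every field is treated as static, so nothing is a leaf.
        return []
    # 'normal': only fields declared as nodes become leaves.
    return [value for name, value in fields.items() if name in node_fields]


fields = {"x": 1, "y": 3}  # same values as MyTree(x=1, y=3)
node_fields = {"y"}        # y is declared with to.node(); x is static

print(toy_leaves(fields, node_fields, FlattenMode.normal))      # [3]
print(toy_leaves(fields, node_fields, FlattenMode.all_fields))  # [1, 3]
print(toy_leaves(fields, node_fields, FlattenMode.no_fields))   # []
```

Under 'no_fields' a mapped function would have nothing to act on, so the tree would come back unchanged.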
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mode | Optional[Union[treeo.tree.FlattenMode, str]] | The new flatten mode; None keeps the current mode. | required |
Source code in treeo/api.py
@contextmanager
def flatten_mode(mode: tp.Optional[tp.Union[FlattenMode, str]]):
"""
A context manager that defines how `Tree`s are flattened. Options are:
* `'normal'`: Fields are selected as nodes as declared in the class definition (default behavior).
* `'all_fields'`: All fields are treated as nodes during flattening.
* `'no_fields'`: All fields are treated as static, `Tree`s produce no leaves.
* `None`: Context is not changed, current flatten mode is preserved.
Example:
```python
@dataclass
class MyTree(Tree):
x: int # static
y: int = to.node()
tree = MyTree(x=1, y=3)
jax.tree_map(lambda x: x * 2, tree) # MyTree(x=1, y=6)
with flatten_mode('all_fields'):
jax.tree_map(lambda x: x + 1, tree) # MyTree(x=2, y=6)
```
Arguments:
mode: The new flatten mode.
"""
if mode is not None:
if isinstance(mode, str):
mode = FlattenMode(mode)
with tree_m._CONTEXT.update(flatten_mode=mode):
yield
else:
yield
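The implementation coerces strings to FlattenMode, skips the context update entirely for None, and relies on tree_m._CONTEXT.update to scope the change to the with block. A minimal, self-contained sketch of that pattern follows, assuming a contextvars.ContextVar in place of treeo's private _CONTEXT object (whose internals are not shown here); FlattenMode, _FLATTEN_MODE, and current_mode are hypothetical names.

```python
import contextvars
from contextlib import contextmanager
from enum import Enum
from typing import Iterator, Optional, Union


class FlattenMode(Enum):
    # Hypothetical stand-in for treeo.tree.FlattenMode.
    normal = "normal"
    all_fields = "all_fields"
    no_fields = "no_fields"


# Hypothetical context variable standing in for tree_m._CONTEXT.
_FLATTEN_MODE = contextvars.ContextVar("flatten_mode", default=FlattenMode.normal)


@contextmanager
def flatten_mode(mode: Optional[Union[FlattenMode, str]]) -> Iterator[None]:
    if mode is None:
        # None leaves whatever mode is currently active untouched.
        yield
        return
    if isinstance(mode, str):
        # Strings are coerced to the enum; unknown names raise ValueError.
        mode = FlattenMode(mode)
    token = _FLATTEN_MODE.set(mode)
    try:
        yield
    finally:
        # Restore the previous mode even if the body raises.
        _FLATTEN_MODE.reset(token)


def current_mode() -> FlattenMode:
    return _FLATTEN_MODE.get()


with flatten_mode("all_fields"):
    with flatten_mode(None):
        print(current_mode())  # FlattenMode.all_fields
print(current_mode())          # FlattenMode.normal
```

The try/finally guarantees that nested or failing blocks restore the outer mode. A ContextVar also keeps the setting local to the current thread or async task; whether treeo's _CONTEXT behaves the same way is not documented in this section.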