
treex.flatten_mode

A context manager that defines how Trees are flattened. Options are:

  • 'normal': Fields are treated as nodes or as static exactly as declared in the class definition (default behavior).
  • 'all_fields': All fields, including static ones, are treated as nodes during flattening.
  • 'no_fields': All fields are treated as static, so Trees produce no leaves.
  • None: The context is not changed; the current flatten mode is preserved.

Examples:

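```python
from dataclasses import dataclass

import jax
import treeo as to
from treeo import Tree, flatten_mode

@dataclass
class MyTree(Tree):
    x: int  # static (not declared as a node)
    y: int = to.node()  # node

tree = MyTree(x=1, y=3)

# Normal mode: only the node field y is mapped.
jax.tree_util.tree_map(lambda x: x * 2, tree)  # MyTree(x=1, y=6)

# tree_map returns a new tree, so this call starts again from MyTree(x=1, y=3).
with flatten_mode('all_fields'):
    jax.tree_util.tree_map(lambda x: x + 1, tree)  # MyTree(x=2, y=4)
```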
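Because `tree_map` builds a new tree instead of mutating `tree`, each call in the example starts from `MyTree(x=1, y=3)`. The sketch below models that behavior with the standard library only, so the effect of each mode can be checked without JAX. It is a hypothetical illustration, not treeo's implementation: `ToyTree`, the `"node"` metadata key, and `toy_tree_map` are invented names, and the mode is passed as an argument rather than read from a context.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Callable

# Hypothetical mode names mirroring the options listed above.
MODES = ("normal", "all_fields", "no_fields")

@dataclass
class ToyTree:
    """Hypothetical stand-in for MyTree; "node" metadata marks node fields."""
    x: int  # static: no "node" metadata
    y: int = field(default=0, metadata={"node": True})

def toy_tree_map(fn: Callable[[Any], Any], tree: Any, mode: str = "normal") -> Any:
    """Apply fn to the fields selected by `mode` and return a new instance."""
    if mode not in MODES:
        raise ValueError(f"unknown flatten mode: {mode!r}")
    updates = {}
    for f in dataclasses.fields(tree):
        is_node = f.metadata.get("node", False)
        # 'all_fields' ignores declarations, 'normal' honors them,
        # and 'no_fields' selects nothing, so the copy equals the input.
        if mode == "all_fields" or (mode == "normal" and is_node):
            updates[f.name] = fn(getattr(tree, f.name))
    return dataclasses.replace(tree, **updates)  # the input is never mutated

tree = ToyTree(x=1, y=3)
print(toy_tree_map(lambda v: v * 2, tree))                     # ToyTree(x=1, y=6)
print(toy_tree_map(lambda v: v + 1, tree, mode="all_fields"))  # ToyTree(x=2, y=4)
print(toy_tree_map(lambda v: v + 1, tree, mode="no_fields"))   # ToyTree(x=1, y=3)
print(tree)                                                    # ToyTree(x=1, y=3)
```

Under `'no_fields'` the mapped function is never called, which matches the option description: the Tree contributes no leaves, so there is nothing to transform.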

Parameters:

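| Name | Type | Description | Default |
|------|------|-------------|---------|
| `mode` | `Optional[Union[treeo.tree.FlattenMode, str]]` | The new flatten mode, or `None` to keep the current one. | required |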
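The `mode` argument may be a `FlattenMode` member, its string name, or `None`; the function body converts strings with `FlattenMode(mode)`, which is an ordinary enum lookup by value. A minimal sketch of that normalization, assuming the enum's member values are the option strings listed above (the stand-in `FlattenMode` class and the `coerce_mode` helper are hypothetical, not imports from treeo):

```python
import enum
from typing import Optional, Union

class FlattenMode(enum.Enum):
    """Hypothetical stand-in for treeo.tree.FlattenMode (member values assumed)."""
    normal = "normal"
    all_fields = "all_fields"
    no_fields = "no_fields"

def coerce_mode(mode: Optional[Union[FlattenMode, str]]) -> Optional[FlattenMode]:
    """Normalize `mode` the same way the flatten_mode body does."""
    if isinstance(mode, str):
        # Lookup by value; raises ValueError for an unknown string.
        mode = FlattenMode(mode)
    return mode  # enum members and None pass through unchanged

print(coerce_mode("all_fields"))           # FlattenMode.all_fields
print(coerce_mode(FlattenMode.no_fields))  # FlattenMode.no_fields
print(coerce_mode(None))                   # None
```

In this sketch `coerce_mode("everything")` raises `ValueError`, so a misspelled mode string fails at the `with` statement instead of silently falling back to `'normal'`.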
Source code in treeo/api.py
````python
@contextmanager
def flatten_mode(mode: tp.Optional[tp.Union[FlattenMode, str]]):
    """
    A context manager that defines how `Tree`s are flattened. Options are:

    * `'normal'`: Fields are treated as nodes or as static exactly as declared in the class definition (default behavior).
    * `'all_fields'`: All fields, including static ones, are treated as nodes during flattening.
    * `'no_fields'`: All fields are treated as static, so `Tree`s produce no leaves.
    * `None`: The context is not changed; the current flatten mode is preserved.

    Example:

    ```python
    from dataclasses import dataclass

    import jax
    import treeo as to
    from treeo import Tree, flatten_mode

    @dataclass
    class MyTree(Tree):
        x: int  # static (not declared as a node)
        y: int = to.node()  # node

    tree = MyTree(x=1, y=3)

    jax.tree_util.tree_map(lambda x: x * 2, tree)  # MyTree(x=1, y=6)

    # tree_map returns a new tree, so this call starts again from MyTree(x=1, y=3).
    with flatten_mode('all_fields'):
        jax.tree_util.tree_map(lambda x: x + 1, tree)  # MyTree(x=2, y=4)
    ```

    Arguments:
        mode: The new flatten mode, or `None` to keep the current one.
    """
    if mode is not None:
        if isinstance(mode, str):
            # Accept plain strings by converting them to the enum.
            mode = FlattenMode(mode)

        with tree_m._CONTEXT.update(flatten_mode=mode):
            yield
    else:
        # None: leave the current flatten mode untouched.
        yield
````
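The source delegates to treeo's internal `tree_m._CONTEXT.update(...)`, whose internals are not shown here. The sketch below swaps in the standard-library `contextvars` module to illustrate the behavior the source suggests: the mode applies inside the `with` block, nested blocks override it, `None` leaves it alone, and the previous mode returns on exit even if the body raises. The `_FLATTEN_MODE` variable and the `current_mode` helper are hypothetical, and modes are plain strings for brevity.

```python
import contextvars
from contextlib import contextmanager
from typing import Iterator, Optional

# Hypothetical holder for the active mode (treeo uses its own context object).
_FLATTEN_MODE = contextvars.ContextVar("flatten_mode", default="normal")

def current_mode() -> str:
    """Return the mode active in the current context."""
    return _FLATTEN_MODE.get()

@contextmanager
def flatten_mode(mode: Optional[str]) -> Iterator[None]:
    if mode is None:
        # None: keep whatever mode is already active.
        yield
        return
    token = _FLATTEN_MODE.set(mode)
    try:
        yield
    finally:
        # Restore the previous mode, even if the body raised.
        _FLATTEN_MODE.reset(token)

with flatten_mode("all_fields"):
    print(current_mode())      # all_fields
    with flatten_mode(None):
        print(current_mode())  # all_fields
    with flatten_mode("no_fields"):
        print(current_mode())  # no_fields
    print(current_mode())      # all_fields
print(current_mode())          # normal
```

The `try`/`finally` around `yield` is what guarantees restoration; setting the variable before `yield` without it would leave the mode changed whenever the block raised.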