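treex.Flatten

Flattens every axis of the input except the leading batch axis, so an array of shape (batch, d1, ..., dn) becomes (batch, d1 * ... * dn).

Source code in treex/nn/flatten.py

```python
import einops
import jax.numpy as jnp
from treex.module import Module

class Flatten(Module):
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return einops.rearrange(x, "batch ... -> batch (...)")  # keep batch axis, merge all others
```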
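To make the einops pattern `"batch ... -> batch (...)"` concrete, here is a minimal sketch of the same reshape in plain NumPy. It does not use treex or einops, and `flatten_batch` is a hypothetical helper name introduced only for illustration. It assumes row-major (C-order) element layout, which is what a standard reshape produces.

```python
import math

import numpy as np


def flatten_batch(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: merge every axis after the first into one.

    Sketch of the reshape that "batch ... -> batch (...)" describes,
    written with a plain row-major NumPy reshape.
    """
    if x.ndim < 1:
        raise ValueError("expected at least a batch axis")
    # Product of all trailing axis sizes; math.prod(()) == 1, so a 1-D
    # input of shape (batch,) becomes (batch, 1).
    features = math.prod(x.shape[1:])
    # Pass the size explicitly instead of -1 so an empty batch still reshapes.
    return x.reshape(x.shape[0], features)


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = flatten_batch(x)
print(y.shape)  # (2, 60)
```

Each row of the result holds one example's values in their original row-major order, so the second example's first entry is the first element of `x[1]`.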