treex.Flatten

Source code in treex/nn/flatten.py
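```python
import einops
import jax.numpy as jnp
from treex import Module
class Flatten(Module):
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return einops.rearrange(x, "batch ... -> batch (...)")  # keep axis 0, merge all remaining axes into one
```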
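The einops pattern `"batch ... -> batch (...)"` keeps the leading batch axis and merges every trailing axis into one, in row-major (C) order. The sketch below reproduces that shape behaviour with plain NumPy so it can be run without treex or einops installed. The helper name `flatten_batch` is hypothetical and is not part of treex. It is meant only to illustrate what the pattern does to shapes, not to replace `treex.Flatten`.

```python
import math

import numpy as np


def flatten_batch(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper mirroring the einops pattern "batch ... -> batch (...)".

    Keeps axis 0 (the batch axis) and merges every remaining axis into a single
    axis in row-major (C) order, the same order einops uses when composing axes.
    """
    batch = x.shape[0]
    features = math.prod(x.shape[1:])  # product of trailing dims; 1 if there are none
    return x.reshape(batch, features)


# Example: a batch of 8 images of shape 28x28x1 becomes 8 vectors of length 784.
images = np.arange(8 * 28 * 28 * 1).reshape(8, 28, 28, 1)
flat = flatten_batch(images)
print(flat.shape)  # (8, 784)
```

Computing the feature size explicitly with `math.prod` (instead of `reshape(batch, -1)`) also keeps the sketch working for an empty batch, where `-1` cannot be inferred.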