treex.nn.GroupNorm

Group normalization Module (arxiv.org/abs/1803.08494).

GroupNorm is implemented as a wrapper over flax.linen.GroupNorm; its constructor accepts the same arguments, including any Flax artifacts such as initializers.

This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across the batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should specify either the total number of channel groups or the number of channels per group.
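As a point of reference, the per-example computation can be written directly in JAX: the channel axis is split into equally-sized groups, and the mean and variance are taken over the spatial axes and the channels within each group, never over the batch axis. The snippet below is a minimal sketch, not the library implementation; the helper name group_norm_reference is made up here, it assumes channels-last (NHWC) inputs, and it omits the learnable scale and bias.

import jax.numpy as jnp

def group_norm_reference(x, num_groups, epsilon=1e-5):
    # Assumes channels-last (NHWC) input; scale/bias are omitted for clarity.
    *batch_dims, channels = x.shape
    assert channels % num_groups == 0, "channels must split evenly into groups"
    # Split the channel axis into (num_groups, channels_per_group).
    grouped = x.reshape(*batch_dims, num_groups, channels // num_groups)
    # Reduce over the spatial axes and the channels within each group, keeping
    # the batch and group axes: statistics are per example, per group.
    reduce_axes = tuple(range(1, grouped.ndim - 2)) + (grouped.ndim - 1,)
    mean = grouped.mean(axis=reduce_axes, keepdims=True)
    var = grouped.var(axis=reduce_axes, keepdims=True)
    normalized = (grouped - mean) / jnp.sqrt(var + epsilon)
    return normalized.reshape(x.shape)

x = jnp.ones((8, 16, 16, 64))               # 8 NHWC images with 64 channels
y = group_norm_reference(x, num_groups=32)  # 32 groups of 2 channels each
assert y.shape == x.shape                   # output shape matches the input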

Source code in treex/nn/norm.py
class GroupNorm(Module):
    """Group normalization Module (arxiv.org/abs/1803.08494).


    `GroupNorm` is implemented as a wrapper over `flax.linen.GroupNorm`; its constructor accepts the same arguments, including any Flax artifacts such as initializers.

    This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across the batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should specify either the total number of channel groups or the number of channels per group.

    """

    # pytree
    scale: tp.Optional[jnp.ndarray] = types.Parameter.node()
    bias: tp.Optional[jnp.ndarray] = types.Parameter.node()

    # props
    num_groups: tp.Optional[int]
    group_size: tp.Optional[int]
    epsilon: float
    dtype: flax_module.Dtype
    use_bias: bool
    use_scale: bool
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]
    scale_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]

    def __init__(
        self,
        *,
        num_groups: tp.Optional[int] = 32,
        group_size: tp.Optional[int] = None,
        epsilon: float = 1e-5,
        dtype: flax_module.Dtype = jnp.float32,
        use_bias: bool = True,
        use_scale: bool = True,
        bias_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.initializers.zeros,
        scale_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.initializers.ones,
    ):
        """
        Arguments:
            num_groups: the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.
            group_size: the number of channels in a group.
            epsilon: a small float added to variance to avoid dividing by zero.
            dtype: the dtype of the computation (default: float32).
            use_bias:  if True, bias (beta) is added.
            use_scale: if True, multiply by scale (gamma).
                When the next layer is linear (also e.g. nn.relu), this can be disabled
                since the scaling will be done by the next layer.
            bias_init: initializer for bias, by default, zero.
            scale_init: initializer for scale, by default, one.
        """

        self.num_groups = num_groups
        self.group_size = group_size
        self.epsilon = epsilon
        self.dtype = dtype
        self.use_bias = use_bias
        self.use_scale = use_scale
        self.bias_init = bias_init
        self.scale_init = scale_init

        self.scale = None
        self.bias = None

    @property
    def module(self) -> flax_module.GroupNorm:
        return flax_module.GroupNorm(
            num_groups=self.num_groups,
            group_size=self.group_size,
            epsilon=self.epsilon,
            dtype=self.dtype,
            use_bias=self.use_bias,
            use_scale=self.use_scale,
            bias_init=self.bias_init,
            scale_init=self.scale_init,
        )

    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        """Normalizes the individual input over equally-sized group of channels.

        Arguments:
            x: the input to be normalized.

        Returns:
            Normalized inputs (the same shape as inputs).
        """
        if self.initializing():
            variables = self.module.init(
                next_key(),
                x,
            ).unfreeze()

            # Extract collections
            if "params" in variables:
                params = variables["params"]

                if self.use_bias:
                    self.bias = params["bias"]

                if self.use_scale:
                    self.scale = params["scale"]

        params = {}

        if self.use_bias:
            params["bias"] = self.bias

        if self.use_scale:
            params["scale"] = self.scale

        variables = dict(
            params=params,
        )

        # call apply
        output = self.module.apply(
            variables,
            x,
        )

        return tp.cast(jnp.ndarray, output)
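Since the wrapper delegates to flax.linen.GroupNorm, the __call__ logic above amounts to the plain Flax init/apply round trip sketched below: init is used once to create the "params" collection (scale and bias), and apply reuses those parameters on every call. Shapes and the PRNG key are illustrative only.

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 16, 16, 64))  # NHWC input, channels last
module = nn.GroupNorm(num_groups=32)

# First call: build the "params" collection (scale and bias), mirroring what
# GroupNorm.__call__ does under self.initializing().
variables = module.init(jax.random.PRNGKey(0), x)

# Subsequent calls: apply with the stored parameters.
y = module.apply(variables, x)
assert y.shape == x.shape  # normalization preserves the input shape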

__call__(self, x) special

Normalizes the individual input over equally-sized groups of channels.

Parameters:

    x (ndarray): the input to be normalized. Required.

Returns:

    ndarray: Normalized inputs (the same shape as inputs).

Source code in treex/nn/norm.py
def __call__(
    self,
    x: jnp.ndarray,
) -> jnp.ndarray:
    """Normalizes the individual input over equally-sized group of channels.

    Arguments:
        x: the input to be normalized.

    Returns:
        Normalized inputs (the same shape as inputs).
    """
    if self.initializing():
        variables = self.module.init(
            next_key(),
            x,
        ).unfreeze()

        # Extract collections
        if "params" in variables:
            params = variables["params"]

            if self.use_bias:
                self.bias = params["bias"]

            if self.use_scale:
                self.scale = params["scale"]

    params = {}

    if self.use_bias:
        params["bias"] = self.bias

    if self.use_scale:
        params["scale"] = self.scale

    variables = dict(
        params=params,
    )

    # call apply
    output = self.module.apply(
        variables,
        x,
    )

    return tp.cast(jnp.ndarray, output)

__init__(self, *, num_groups=32, group_size=None, epsilon=1e-05, dtype=jnp.float32, use_bias=True, use_scale=True, bias_init=flax_module.initializers.zeros, scale_init=flax_module.initializers.ones) special

Parameters:

    num_groups (Optional[int], default: 32): the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.
    group_size (Optional[int], default: None): the number of channels in a group.
    epsilon (float, default: 1e-05): a small float added to variance to avoid dividing by zero.
    dtype (Any, default: jnp.float32): the dtype of the computation (default: float32).
    use_bias (bool, default: True): if True, bias (beta) is added.
    use_scale (bool, default: True): if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
    bias_init (Callable[[Any, Tuple[int], Any], Any], default: flax_module.initializers.zeros): initializer for bias, by default, zero.
    scale_init (Callable[[Any, Tuple[int], Any], Any], default: flax_module.initializers.ones): initializer for scale, by default, one.
Source code in treex/nn/norm.py
def __init__(
    self,
    *,
    num_groups: tp.Optional[int] = 32,
    group_size: tp.Optional[int] = None,
    epsilon: float = 1e-5,
    dtype: flax_module.Dtype = jnp.float32,
    use_bias: bool = True,
    use_scale: bool = True,
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.initializers.zeros,
    scale_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.initializers.ones,
):
    """
    Arguments:
        num_groups: the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.
        group_size: the number of channels in a group.
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        use_bias:  if True, bias (beta) is added.
        use_scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be disabled
            since the scaling will be done by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
    """

    self.num_groups = num_groups
    self.group_size = group_size
    self.epsilon = epsilon
    self.dtype = dtype
    self.use_bias = use_bias
    self.use_scale = use_scale
    self.bias_init = bias_init
    self.scale_init = scale_init

    self.scale = None
    self.bias = None
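Only one of the two grouping arguments is meant to be used at a time: num_groups fixes how many groups the channels are split into, while group_size fixes how many channels each group contains. A hedged sketch for 64-channel inputs, assuming the wrapped Flax module expects exactly one of the two grouping arguments to be set (so num_groups is cleared when group_size is given):

import treex as tx

# For 64 input channels, these two configurations describe the same grouping:
# 32 groups of 2 channels each.
gn_by_groups = tx.nn.GroupNorm(num_groups=32)

# group_size fixes the channels per group instead; num_groups is set to None
# on the assumption that exactly one grouping argument should be provided.
gn_by_size = tx.nn.GroupNorm(num_groups=None, group_size=2)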