treex.GroupNorm
Group normalization Module (arxiv.org/abs/1803.08494).
`GroupNorm` is implemented as a wrapper over `flax.linen.GroupNorm`; it accepts the same constructor arguments, including any Flax artifacts such as initializers.
This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across the batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should specify either the total number of channel groups or the number of channels per group.
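Since the wrapper delegates to `flax.linen.GroupNorm` (see the `module` property in the source below), the expected shapes can be illustrated by exercising the underlying Flax layer directly. This is a minimal sketch, assuming NHWC inputs with 32 channels; the shapes and PRNG keys are illustrative only:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical batch of 4 feature maps with 32 channels (NHWC layout).
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 8, 32))

# The underlying Flax layer that treex.GroupNorm wraps; it takes the same constructor arguments.
gn = nn.GroupNorm(num_groups=8)
variables = gn.init(jax.random.PRNGKey(1), x)  # creates the scale and bias parameters
y = gn.apply(variables, x)                     # normalized output

print(y.shape)  # (4, 8, 8, 32), same shape as the input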
Source code in treex/nn/norm.py
class GroupNorm(Module):
    """Group normalization Module (arxiv.org/abs/1803.08494).

    `GroupNorm` is implemented as a wrapper over `flax.linen.GroupNorm`, its constructor
    arguments accept the same arguments including any Flax artifacts such as initializers.

    This op is similar to batch normalization, but statistics are shared across
    equally-sized groups of channels and not shared across the batch dimension. Thus,
    group normalization does not depend on the batch composition and does not require
    maintaining internal state for storing statistics. The user should specify either
    the total number of channel groups or the number of channels per group.
    """

    # pytree
    scale: tp.Optional[jnp.ndarray] = types.Parameter.node()
    bias: tp.Optional[jnp.ndarray] = types.Parameter.node()

    # props
    num_groups: tp.Optional[int]
    group_size: tp.Optional[int]
    epsilon: float
    dtype: flax_module.Dtype
    use_bias: bool
    use_scale: bool
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]
    scale_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]

    def __init__(
        self,
        *,
        num_groups: tp.Optional[int] = 32,
        group_size: tp.Optional[int] = None,
        epsilon: float = 1e-5,
        dtype: flax_module.Dtype = jnp.float32,
        use_bias: bool = True,
        use_scale: bool = True,
        bias_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.initializers.zeros,
        scale_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.initializers.ones,
    ):
        """
        Arguments:
            num_groups: the total number of channel groups. The default value of 32 is
                proposed by the original group normalization paper.
            group_size: the number of channels in a group.
            epsilon: a small float added to variance to avoid dividing by zero.
            dtype: the dtype of the computation (default: float32).
            use_bias: if True, bias (beta) is added.
            use_scale: if True, multiply by scale (gamma). When the next layer is linear
                (also e.g. nn.relu), this can be disabled since the scaling will be done
                by the next layer.
            bias_init: initializer for bias, by default, zero.
            scale_init: initializer for scale, by default, one.
        """
        self.num_groups = num_groups
        self.group_size = group_size
        self.epsilon = epsilon
        self.dtype = dtype
        self.use_bias = use_bias
        self.use_scale = use_scale
        self.bias_init = bias_init
        self.scale_init = scale_init
        self.scale = None
        self.bias = None

    @property
    def module(self) -> flax_module.GroupNorm:
        return flax_module.GroupNorm(
            num_groups=self.num_groups,
            group_size=self.group_size,
            epsilon=self.epsilon,
            dtype=self.dtype,
            use_bias=self.use_bias,
            use_scale=self.use_scale,
            bias_init=self.bias_init,
            scale_init=self.scale_init,
        )

    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        """Normalizes the individual input over equally-sized groups of channels.

        Arguments:
            x: the input to be normalized.

        Returns:
            Normalized inputs (the same shape as inputs).
        """
        if self.initializing():
            variables = self.module.init(
                next_key(),
                x,
            ).unfreeze()

            # Extract collections
            if "params" in variables:
                params = variables["params"]

                if self.use_bias:
                    self.bias = params["bias"]

                if self.use_scale:
                    self.scale = params["scale"]

        params = {}

        if self.use_bias:
            params["bias"] = self.bias

        if self.use_scale:
            params["scale"] = self.scale

        variables = dict(
            params=params,
        )

        # call apply
        output = self.module.apply(
            variables,
            x,
        )

        return tp.cast(jnp.ndarray, output)
__call__(self, x)
special
Normalizes the individual input over equally-sized groups of channels.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
x | ndarray | the input to be normalized. | required |

Returns:

Type | Description |
---|---|
ndarray | Normalized inputs (the same shape as inputs). |
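To make "equally-sized groups of channels" concrete, the sketch below recomputes the normalization by hand and compares it against the Flax layer this wrapper delegates to. With the default initializers (scale of ones, bias of zeros) the affine step is the identity, so the two results should match; the shapes and tolerance are illustrative assumptions, not part of the treex API:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 32))  # (N, H, W, C)

gn = nn.GroupNorm(num_groups=4)
variables = gn.init(jax.random.PRNGKey(1), x)
y = gn.apply(variables, x)

# Statistics are shared within each group of 32 / 4 = 8 channels and across the
# spatial dimensions, but never across the batch dimension.
xg = x.reshape(2, 8, 8, 4, 8)                   # split the channel axis into 4 groups
mean = xg.mean(axis=(1, 2, 4), keepdims=True)   # per-sample, per-group mean
var = xg.var(axis=(1, 2, 4), keepdims=True)     # per-sample, per-group variance
manual = ((xg - mean) / jnp.sqrt(var + 1e-5)).reshape(x.shape)

print(jnp.allclose(y, manual, atol=1e-4))       # expected: True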
Source code in treex/nn/norm.py
def __call__(
    self,
    x: jnp.ndarray,
) -> jnp.ndarray:
    """Normalizes the individual input over equally-sized groups of channels.

    Arguments:
        x: the input to be normalized.

    Returns:
        Normalized inputs (the same shape as inputs).
    """
    if self.initializing():
        variables = self.module.init(
            next_key(),
            x,
        ).unfreeze()

        # Extract collections
        if "params" in variables:
            params = variables["params"]

            if self.use_bias:
                self.bias = params["bias"]

            if self.use_scale:
                self.scale = params["scale"]

    params = {}

    if self.use_bias:
        params["bias"] = self.bias

    if self.use_scale:
        params["scale"] = self.scale

    variables = dict(
        params=params,
    )

    # call apply
    output = self.module.apply(
        variables,
        x,
    )

    return tp.cast(jnp.ndarray, output)
__init__(self, *, num_groups=32, group_size=None, epsilon=1e-05, dtype=jnp.float32, use_bias=True, use_scale=True, bias_init=flax_module.initializers.zeros, scale_init=flax_module.initializers.ones)
special
Parameters:

Name | Type | Description | Default |
---|---|---|---|
num_groups | Optional[int] | the total number of channel groups. The default value of 32 is proposed by the original group normalization paper. | 32 |
group_size | Optional[int] | the number of channels in a group. | None |
epsilon | float | a small float added to variance to avoid dividing by zero. | 1e-05 |
dtype | Any | the dtype of the computation (default: float32). | jnp.float32 |
use_bias | bool | if True, bias (beta) is added. | True |
use_scale | bool | if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. | True |
bias_init | Callable[[Any, Tuple[int], Any], Any] | initializer for bias, by default, zero. | flax_module.initializers.zeros |
scale_init | Callable[[Any, Tuple[int], Any], Any] | initializer for scale, by default, one. | flax_module.initializers.ones |
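Regarding num_groups and group_size: in the underlying Flax layer exactly one of them should be specified (the other left as None), and the number of input channels must be divisible by the chosen value. A constructor-only sketch, assuming the package is imported as tx and an input with 64 channels:

import treex as tx

# For 64 channels these two configurations describe the same grouping:
# 8 groups of 8 channels each. Note that num_groups defaults to 32, so it
# must be set to None explicitly when grouping by group_size instead.
gn_by_groups = tx.GroupNorm(num_groups=8, group_size=None)
gn_by_size = tx.GroupNorm(num_groups=None, group_size=8)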
Source code in treex/nn/norm.py
def __init__(
    self,
    *,
    num_groups: tp.Optional[int] = 32,
    group_size: tp.Optional[int] = None,
    epsilon: float = 1e-5,
    dtype: flax_module.Dtype = jnp.float32,
    use_bias: bool = True,
    use_scale: bool = True,
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.initializers.zeros,
    scale_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.initializers.ones,
):
    """
    Arguments:
        num_groups: the total number of channel groups. The default value of 32 is
            proposed by the original group normalization paper.
        group_size: the number of channels in a group.
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        use_bias: if True, bias (beta) is added.
        use_scale: if True, multiply by scale (gamma). When the next layer is linear
            (also e.g. nn.relu), this can be disabled since the scaling will be done
            by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
    """
    self.num_groups = num_groups
    self.group_size = group_size
    self.epsilon = epsilon
    self.dtype = dtype
    self.use_bias = use_bias
    self.use_scale = use_scale
    self.bias_init = bias_init
    self.scale_init = scale_init
    self.scale = None
    self.bias = None