
treex.BatchNorm

BatchNorm Module.

BatchNorm is implemented as a wrapper over flax.linen.BatchNorm; its constructor accepts almost the same arguments, including any Flax artifacts such as initializers. Main differences:

  • use_running_average is not a constructor argument, but remains a __call__ argument.
  • the self.training state indicates how BatchNorm should behave; internally, use_running_average = not self.training or self.frozen is used unless use_running_average is explicitly passed via __call__ (see the usage sketch below).
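
A minimal usage sketch follows. The init call (a key plus example inputs for shape inference) and the eval() toggle reflect common treex usage and should be treated as assumptions; check the treex documentation for the exact signatures in your version.

import jax.numpy as jnp
import treex as tx

x = jnp.ones((8, 32))  # (batch, features); the feature size is inferred from the input

bn = tx.BatchNorm().init(key=42, inputs=x)  # assumed init signature (key + example inputs)

y = bn(x)                             # training mode: batch statistics are computed and stored
y = bn.eval()(x)                      # eval mode: the stored running statistics are used
y = bn(x, use_running_average=True)   # explicit override, regardless of training/eval mode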

__call__(self, x, use_running_average=None) special

Normalizes the input using batch statistics.

Parameters:

  • x (ndarray, required): the input to be normalized.
  • use_running_average (Optional[bool], default: None): if true, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

Returns:

  • ndarray: Normalized inputs (the same shape as inputs).

Source code in treex/nn/batch_norm.py
def __call__(
    self, x: jnp.ndarray, use_running_average: tp.Optional[bool] = None
) -> jnp.ndarray:
    """Normalizes the input using batch statistics.

    Arguments:
        x: the input to be normalized.
        use_running_average: if true, the statistics stored in batch_stats
            will be used instead of computing the batch statistics on the input.

    Returns:
        Normalized inputs (the same shape as inputs).
    """
    if self.initializing():
        variables = self.module.init(
            next_key(),
            x,
            use_running_average=True,
        ).unfreeze()

        # Extract collections
        if "params" in variables:
            params = variables["params"]

            if self.use_bias:
                self.bias = params["bias"]

            if self.use_scale:
                self.scale = params["scale"]

        self.mean = variables["batch_stats"]["mean"]
        self.var = variables["batch_stats"]["var"]

    params = {}

    if self.use_bias:
        params["bias"] = self.bias

    if self.use_scale:
        params["scale"] = self.scale

    variables = dict(
        batch_stats=dict(
            mean=self.mean,
            var=self.var,
        ),
        params=params,
    )
    # use_running_average = True means batch_stats will not be mutated
    # self.training = True means batch_stats will be mutated
    training = (
        not use_running_average
        if use_running_average is not None
        else self.training and not self.frozen and self.initialized
    )

    # call apply
    output, variables = self.module.apply(
        variables,
        x,
        mutable=["batch_stats"] if training else [],
        use_running_average=not training,
    )
    variables = variables.unfreeze()

    # update batch_stats
    if "batch_stats" in variables:
        self.mean = variables["batch_stats"]["mean"]
        self.var = variables["batch_stats"]["var"]

    return tp.cast(jnp.ndarray, output)
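
To make the precedence explicit, the sketch below assumes an initialized bn and that train()/eval() return a copy of the module with self.training flipped, as in standard treex modules:

# `bn` is an initialized tx.BatchNorm, `x` a batched input
y = bn.train()(x)                            # self.training=True: batch stats computed, mean/var updated
y = bn.eval()(x)                             # self.training=False: stored mean/var used, nothing mutated
y = bn.train()(x, use_running_average=True)  # explicit flag wins: running stats used even in training mode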

__init__(self, *, axis=-1, momentum=0.99, epsilon=1e-05, dtype=jnp.float32, use_bias=True, use_scale=True, bias_init=initializers.zeros, scale_init=initializers.ones, axis_name=None, axis_index_groups=None) special

Parameters:

  • axis (int, default: -1): the feature or non-batch axis of the input.
  • momentum (Union[float, jnp.ndarray], default: 0.99): decay rate for the exponential moving average of the batch statistics.
  • epsilon (float, default: 1e-05): a small float added to variance to avoid dividing by zero.
  • dtype (Any, default: jnp.float32): the dtype of the computation.
  • use_bias (bool, default: True): if True, bias (beta) is added.
  • use_scale (bool, default: True): if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
  • bias_init (Callable[[Any, Tuple[int], Any], Any], default: initializers.zeros): initializer for bias, by default, zero.
  • scale_init (Callable[[Any, Tuple[int], Any], Any], default: initializers.ones): initializer for scale, by default, one.
  • axis_name (Optional[str], default: None): the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names.
  • axis_index_groups (Any, default: None): groups of axis indices within that named axis representing subsets of devices to reduce over. For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.
Source code in treex/nn/batch_norm.py
def __init__(
    self,
    *,
    axis: int = -1,
    momentum: tp.Union[float, jnp.ndarray] = 0.99,
    epsilon: float = 1e-5,
    dtype: flax_module.Dtype = jnp.float32,
    use_bias: bool = True,
    use_scale: bool = True,
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.initializers.zeros,
    scale_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.initializers.ones,
    axis_name: tp.Optional[str] = None,
    axis_index_groups: tp.Any = None,
):
    """
    Arguments:
        features_in: the number of input features.
        axis: the feature or non-batch axis of the input.
        momentum: decay rate for the exponential moving average of
            the batch statistics.
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        use_bias:  if True, bias (beta) is added.
        use_scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be disabled
            since the scaling will be done by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
        axis_name: the axis name used to combine batch statistics from multiple
            devices. See `jax.pmap` for a description of axis names (default: None).
        axis_index_groups: groups of axis indices within that named axis
            representing subsets of devices to reduce over (default: None). For
            example, `[[0, 1], [2, 3]]` would independently batch-normalize over
            the examples on the first two and last two devices. See `jax.lax.psum`
            for more details.
    """

    self.axis = axis
    self.momentum = jnp.asarray(momentum)
    self.epsilon = epsilon
    self.dtype = dtype
    self.use_bias = use_bias
    self.use_scale = use_scale
    self.bias_init = bias_init
    self.scale_init = scale_init
    self.axis_name = axis_name
    self.axis_index_groups = axis_index_groups

    self.mean = None
    self.var = None
    self.scale = None
    self.bias = None
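
As an illustration, the constructor might be used as follows for NCHW-style inputs (a sketch; the init/call pattern relies on the same assumptions as the earlier examples):

import jax.numpy as jnp
import treex as tx

# feature axis 1 for (N, C, H, W) inputs, a slower-moving average,
# and no scale because the next layer is linear
bn = tx.BatchNorm(axis=1, momentum=0.9, use_scale=False)

x = jnp.ones((4, 16, 8, 8))       # (N, C, H, W); the channel size is inferred at init
bn = bn.init(key=0, inputs=x)     # assumed init signature, as above
y = bn(x)                         # normalized output, same shape as x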