treex.nn.LayerNorm

LayerNorm Module.

LayerNorm is implemented as a wrapper over flax.linen.LayerNorm; its constructor accepts the same arguments, including any Flax artifacts such as initializers.

It normalizes the activations of the layer for each example in a batch independently, rather than across the batch as Batch Normalization does. That is, it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.
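For intuition only, the per-example transformation can be sketched in plain JAX as below. This is a reference computation with a hypothetical helper name, not the treex implementation, which delegates to flax.linen.LayerNorm:

import jax.numpy as jnp

def layer_norm_reference(x, scale, bias, epsilon=1e-5):
    # Mean and variance are taken over the last (feature) axis of each example,
    # so every example in the batch is normalized independently.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normalized = (x - mean) / jnp.sqrt(var + epsilon)
    # Learned scale (gamma) and bias (beta) restore expressiveness.
    return normalized * scale + bias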

Source code in treex/nn/norm.py
class LayerNorm(Module):
    """LayerNorm Module.

    `LayerNorm` is implemented as a wrapper over `flax.linen.LayerNorm`, its constructor
    arguments accept the same arguments including any Flax artifacts such as initializers.

    It normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.

    """

    # pytree
    scale: tp.Optional[jnp.ndarray] = types.Parameter.node()
    bias: tp.Optional[jnp.ndarray] = types.Parameter.node()

    # props
    epsilon: float
    dtype: flax_module.Dtype
    use_bias: bool
    use_scale: bool
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]
    scale_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]

    def __init__(
        self,
        *,
        epsilon: float = 1e-5,
        dtype: flax_module.Dtype = jnp.float32,
        use_bias: bool = True,
        use_scale: bool = True,
        bias_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.initializers.zeros,
        scale_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.initializers.ones,
    ):
        """
        Arguments:
            epsilon: a small float added to variance to avoid dividing by zero.
            dtype: the dtype of the computation (default: float32).
            use_bias:  if True, bias (beta) is added.
            use_scale: if True, multiply by scale (gamma).
                When the next layer is linear (also e.g. nn.relu), this can be disabled
                since the scaling will be done by the next layer.
            bias_init: initializer for bias, by default, zero.
            scale_init: initializer for scale, by default, one.
        """

        self.epsilon = epsilon
        self.dtype = dtype
        self.use_bias = use_bias
        self.use_scale = use_scale
        self.bias_init = bias_init
        self.scale_init = scale_init

        self.scale = None
        self.bias = None

    @property
    def module(self) -> flax_module.LayerNorm:
        return flax_module.LayerNorm(
            epsilon=self.epsilon,
            dtype=self.dtype,
            use_bias=self.use_bias,
            use_scale=self.use_scale,
            bias_init=self.bias_init,
            scale_init=self.scale_init,
        )

    def __call__(
        self,
        x: jnp.ndarray,
    ) -> jnp.ndarray:
        """Normalizes individual input on the last axis (channels) of the input data.

        Arguments:
            x: the input to be normalized.

        Returns:
            Normalized inputs (the same shape as inputs).
        """
        if self.initializing():
            variables = self.module.init(
                next_key(),
                x,
            ).unfreeze()

            # Extract collections
            if "params" in variables:
                params = variables["params"]

                if self.use_bias:
                    self.bias = params["bias"]

                if self.use_scale:
                    self.scale = params["scale"]

        params = {}

        if self.use_bias:
            params["bias"] = self.bias

        if self.use_scale:
            params["scale"] = self.scale

        variables = dict(
            params=params,
        )

        # call apply
        output = self.module.apply(
            variables,
            x,
        )

        return tp.cast(jnp.ndarray, output)
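The module property above rebuilds a fresh flax.linen.LayerNorm from the stored hyperparameters on every access. A small sketch, assuming the conventional import treex as tx alias and the top-level tx.LayerNorm re-export:

import treex as tx

ln = tx.LayerNorm(epsilon=1e-6, use_bias=False)
flax_ln = ln.module      # a freshly constructed flax.linen.LayerNorm
print(flax_ln.epsilon)   # 1e-06
print(flax_ln.use_bias)  # False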

__call__(self, x) special

Normalizes individual input on the last axis (channels) of the input data.

Parameters:

    x (ndarray): the input to be normalized. Required.

Returns:

    ndarray: Normalized inputs (the same shape as inputs).

Source code in treex/nn/norm.py
def __call__(
    self,
    x: jnp.ndarray,
) -> jnp.ndarray:
    """Normalizes individual input on the last axis (channels) of the input data.

    Arguments:
        x: the input to be normalized.

    Returns:
        Normalized inputs (the same shape as inputs).
    """
    if self.initializing():
        variables = self.module.init(
            next_key(),
            x,
        ).unfreeze()

        # Extract collections
        if "params" in variables:
            params = variables["params"]

            if self.use_bias:
                self.bias = params["bias"]

            if self.use_scale:
                self.scale = params["scale"]

    params = {}

    if self.use_bias:
        params["bias"] = self.bias

    if self.use_scale:
        params["scale"] = self.scale

    variables = dict(
        params=params,
    )

    # call apply
    output = self.module.apply(
        variables,
        x,
    )

    return tp.cast(jnp.ndarray, output)
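On an initializing call, the method collects the params collection produced by flax.linen.LayerNorm.init into the module's scale and bias leaves; every subsequent call rebuilds that collection and forwards it to apply. A rough plain-Flax equivalent of that internal cycle (a sketch with assumed toy shapes, not the treex calling API):

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 16))  # assumed (batch, features) example input

flax_ln = nn.LayerNorm(epsilon=1e-5, use_bias=True, use_scale=True)

# init produces {"params": {"scale": ..., "bias": ...}}, which the wrapper
# stores on its scale and bias fields.
variables = flax_ln.init(jax.random.PRNGKey(0), x)

# apply consumes the same structure and returns an output shaped like x.
y = flax_ln.apply(variables, x)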

__init__(self, *, epsilon=1e-05, dtype=jnp.float32, use_bias=True, use_scale=True, bias_init=flax_module.initializers.zeros, scale_init=flax_module.initializers.ones) special

Parameters:

    epsilon (float, default 1e-05): a small float added to variance to avoid dividing by zero.
    dtype (Any, default jnp.float32): the dtype of the computation (default: float32).
    use_bias (bool, default True): if True, bias (beta) is added.
    use_scale (bool, default True): if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
    bias_init (Callable[[Any, Tuple[int], Any], Any], default flax_module.initializers.zeros): initializer for bias, by default, zero.
    scale_init (Callable[[Any, Tuple[int], Any], Any], default flax_module.initializers.ones): initializer for scale, by default, one.
Source code in treex/nn/norm.py
def __init__(
    self,
    *,
    epsilon: float = 1e-5,
    dtype: flax_module.Dtype = jnp.float32,
    use_bias: bool = True,
    use_scale: bool = True,
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.initializers.zeros,
    scale_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.initializers.ones,
):
    """
    Arguments:
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        use_bias:  if True, bias (beta) is added.
        use_scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be disabled
            since the scaling will be done by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
    """

    self.epsilon = epsilon
    self.dtype = dtype
    self.use_bias = use_bias
    self.use_scale = use_scale
    self.bias_init = bias_init
    self.scale_init = scale_init

    self.scale = None
    self.bias = None
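
As a construction sketch (assuming import treex as tx and the top-level tx.LayerNorm re-export), all arguments are keyword-only and the parameter leaves stay None until the module is initialized:

import treex as tx
from flax import linen

ln = tx.LayerNorm(
    epsilon=1e-6,
    use_bias=False,                      # no beta parameter will be created
    scale_init=linen.initializers.ones,  # explicit, same as the default
)

assert ln.scale is None and ln.bias is None  # populated only after initialization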