treex.nn.LayerNorm
LayerNorm Module.
`LayerNorm` is implemented as a wrapper over `flax.linen.LayerNorm`; its constructor accepts the same arguments, including any Flax artifacts such as initializers.
It normalizes the activations of the layer for each example independently, rather than across the batch as Batch Normalization does. That is, it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.
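As a quick illustration of that transformation (an illustrative sketch, not part of treex itself), the following computes the per-example normalization by hand with `jax.numpy` and checks it against the wrapped `flax.linen.LayerNorm`:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # (batch, features)

# Manual layer normalization over the last (feature) axis, per example.
eps = 1e-5
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
y_manual = (x - mean) / jnp.sqrt(var + eps)  # default scale=1, bias=0

# The same computation through flax.linen.LayerNorm, which treex wraps.
module = nn.LayerNorm(epsilon=eps)
variables = module.init(jax.random.PRNGKey(1), x)
y_flax = module.apply(variables, x)

assert jnp.allclose(y_manual, y_flax, atol=1e-5)
```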
Source code in treex/nn/norm.py
class LayerNorm(Module):
"""LayerNorm Module.
`LayerNorm` is implemented as a wrapper over `flax.linen.LayerNorm`, its constructor
arguments accept the same arguments including any Flax artifacts such as initializers.
It normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.
"""
# pytree
scale: tp.Optional[jnp.ndarray] = types.Parameter.node()
bias: tp.Optional[jnp.ndarray] = types.Parameter.node()
# props
epsilon: float
dtype: flax_module.Dtype
use_bias: bool
use_scale: bool
bias_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
flax_module.Array,
]
scale_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
flax_module.Array,
]
def __init__(
self,
*,
epsilon: float = 1e-5,
dtype: flax_module.Dtype = jnp.float32,
use_bias: bool = True,
use_scale: bool = True,
bias_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
flax_module.Array,
] = flax_module.initializers.zeros,
scale_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
flax_module.Array,
] = flax_module.initializers.ones,
):
"""
Arguments:
epsilon: a small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
use_bias: if True, bias (beta) is added.
use_scale: if True, multiply by scale (gamma).
When the next layer is linear (also e.g. nn.relu), this can be disabled
since the scaling will be done by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
"""
self.epsilon = epsilon
self.dtype = dtype
self.use_bias = use_bias
self.use_scale = use_scale
self.bias_init = bias_init
self.scale_init = scale_init
self.scale = None
self.bias = None
@property
def module(self) -> flax_module.LayerNorm:
return flax_module.LayerNorm(
epsilon=self.epsilon,
dtype=self.dtype,
use_bias=self.use_bias,
use_scale=self.use_scale,
bias_init=self.bias_init,
scale_init=self.scale_init,
)
def __call__(
self,
x: jnp.ndarray,
) -> jnp.ndarray:
"""Normalizes individual input on the last axis (channels) of the input data.
Arguments:
x: the input to be normalized.
Returns:
Normalized inputs (the same shape as inputs).
"""
if self.initializing():
variables = self.module.init(
next_key(),
x,
).unfreeze()
# Extract collections
if "params" in variables:
params = variables["params"]
if self.use_bias:
self.bias = params["bias"]
if self.use_scale:
self.scale = params["scale"]
params = {}
if self.use_bias:
params["bias"] = self.bias
if self.use_scale:
params["scale"] = self.scale
variables = dict(
params=params,
)
# call apply
output = self.module.apply(
variables,
x,
)
return tp.cast(jnp.ndarray, output)
__call__(self, x)
special
Normalizes individual input on the last axis (channels) of the input data.
Parameters:

Name | Type | Description | Default
---|---|---|---
x | ndarray | the input to be normalized. | required

Returns:

Type | Description
---|---
ndarray | Normalized inputs (the same shape as inputs).
Source code in treex/nn/norm.py
def __call__(
self,
x: jnp.ndarray,
) -> jnp.ndarray:
"""Normalizes individual input on the last axis (channels) of the input data.
Arguments:
x: the input to be normalized.
Returns:
Normalized inputs (the same shape as inputs).
"""
if self.initializing():
variables = self.module.init(
next_key(),
x,
).unfreeze()
# Extract collections
if "params" in variables:
params = variables["params"]
if self.use_bias:
self.bias = params["bias"]
if self.use_scale:
self.scale = params["scale"]
params = {}
if self.use_bias:
params["bias"] = self.bias
if self.use_scale:
params["scale"] = self.scale
variables = dict(
params=params,
)
# call apply
output = self.module.apply(
variables,
x,
)
return tp.cast(jnp.ndarray, output)
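For reference, here is a small flax-level sketch (illustrative only, using `flax.linen.LayerNorm` directly) of the `init`-then-`apply` flow that the two branches above wrap, showing where the `scale` and `bias` leaves come from:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 16))
module = nn.LayerNorm()

# Initializing path: flax `init` builds the `params` collection whose
# entries the treex module stores as its `scale` and `bias` leaves.
variables = module.init(jax.random.PRNGKey(0), x)
scale = variables["params"]["scale"]  # shape (16,), ones by default
bias = variables["params"]["bias"]    # shape (16,), zeros by default

# Forward path: the stored leaves are repacked and passed to `apply`.
y = module.apply({"params": {"scale": scale, "bias": bias}}, x)
assert y.shape == x.shape
```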
__init__(self, *, epsilon=1e-05, dtype=jnp.float32, use_bias=True, use_scale=True, bias_init=initializers.zeros, scale_init=initializers.ones)
special
Parameters:

Name | Type | Description | Default
---|---|---|---
epsilon | float | a small float added to variance to avoid dividing by zero. | 1e-05
dtype | Any | the dtype of the computation (default: float32). | jnp.float32
use_bias | bool | if True, bias (beta) is added. | True
use_scale | bool | if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer. | True
bias_init | Callable[[Any, Tuple[int], Any], Any] | initializer for bias, by default, zero. | initializers.zeros
scale_init | Callable[[Any, Tuple[int], Any], Any] | initializer for scale, by default, one. | initializers.ones
Source code in treex/nn/norm.py
def __init__(
self,
*,
epsilon: float = 1e-5,
dtype: flax_module.Dtype = jnp.float32,
use_bias: bool = True,
use_scale: bool = True,
bias_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
flax_module.Array,
] = flax_module.initializers.zeros,
scale_init: tp.Callable[
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
flax_module.Array,
] = flax_module.initializers.ones,
):
"""
Arguments:
epsilon: a small float added to variance to avoid dividing by zero.
dtype: the dtype of the computation (default: float32).
use_bias: if True, bias (beta) is added.
use_scale: if True, multiply by scale (gamma).
When the next layer is linear (also e.g. nn.relu), this can be disabled
since the scaling will be done by the next layer.
bias_init: initializer for bias, by default, zero.
scale_init: initializer for scale, by default, one.
"""
self.epsilon = epsilon
self.dtype = dtype
self.use_bias = use_bias
self.use_scale = use_scale
self.bias_init = bias_init
self.scale_init = scale_init
self.scale = None
self.bias = None
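A brief construction sketch; note that the `tx` import alias and the `init(key, inputs)` call are assumptions about the surrounding treex API rather than something defined in this file:

```python
import jax
import jax.numpy as jnp
import treex as tx  # assumed import alias

x = jnp.ones((4, 32))

# All constructor arguments are keyword-only and mirror flax.linen.LayerNorm.
ln = tx.LayerNorm(epsilon=1e-6, use_scale=False)
print(ln.scale, ln.bias)  # None None: parameters are created at init time

# Assumed treex initialization API; afterwards `bias` is a (32,) zeros array
# while `scale` stays None because use_scale=False.
ln = ln.init(key=jax.random.PRNGKey(0), inputs=x)
y = ln(x)
assert y.shape == x.shape
```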