treex.nn.Conv
Convolution Module wrapping lax.conv_general_dilated.
`Conv` is implemented as a wrapper over `flax.linen.Conv`; its constructor
accepts almost the same arguments, including any Flax artifacts such as initializers.
Main differences:
- does not receive `features_in`: the number of input features is inferred from the input passed to the module during initialization.
- the `features` argument is renamed to `features_out`.
Source code in treex/nn/conv.py
class Conv(Module):
    """Convolution Module wrapping lax.conv_general_dilated.

    `Conv` is implemented as a wrapper over `flax.linen.Conv`; its constructor
    accepts almost the same arguments, including any Flax artifacts such as
    initializers.

    Main differences:

    * `features` argument is renamed to `features_out`.
    * The number of input features is not passed to the constructor: the
      parameters are created on the first (initializing) call from the shape
      of the input given to it.
    """

    # pytree
    kernel: tp.Optional[jnp.ndarray] = types.Parameter.node()
    bias: tp.Optional[jnp.ndarray] = types.Parameter.node()

    # props
    features_out: int
    kernel_size: tp.Union[int, tp.Iterable[int]]
    strides: tp.Optional[tp.Iterable[int]]
    padding: tp.Union[str, tp.Iterable[tp.Tuple[int, int]]]
    input_dilation: tp.Optional[tp.Iterable[int]]
    kernel_dilation: tp.Optional[tp.Iterable[int]]
    feature_group_count: int
    use_bias: bool
    dtype: flax_module.Dtype
    # declared like every other hyper-parameter; it is set in __init__ and
    # forwarded to flax.linen.Conv in `module`.
    param_dtype: flax_module.Dtype
    precision: tp.Any
    kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]

    def __init__(
        self,
        features_out: int,
        kernel_size: tp.Union[int, tp.Iterable[int]],
        *,
        strides: tp.Optional[tp.Iterable[int]] = None,
        padding: tp.Union[str, tp.Iterable[tp.Tuple[int, int]]] = "SAME",
        input_dilation: tp.Optional[tp.Iterable[int]] = None,
        kernel_dilation: tp.Optional[tp.Iterable[int]] = None,
        feature_group_count: int = 1,
        use_bias: bool = True,
        dtype: flax_module.Dtype = jnp.float32,
        param_dtype: flax_module.Dtype = jnp.float32,
        precision: tp.Any = None,
        kernel_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.default_kernel_init,
        bias_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.zeros,
    ):
        """
        Arguments:
            features_out: number of convolution filters.
            kernel_size: shape of the convolutional kernel. For 1D convolution,
                the kernel size can be passed as an integer. For all other cases, it must
                be a sequence of integers.
            strides: a sequence of `n` integers, representing the inter-window
                strides.
            padding: either the string `'SAME'`, the string `'VALID'`, the string
                `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
                high)` integer pairs that give the padding to apply before and after each
                spatial dimension.
            input_dilation: `None`, or a sequence of `n` integers, giving the
                dilation factor to apply in each spatial dimension of `inputs`.
                Convolution with input dilation `d` is equivalent to transposed
                convolution with stride `d`.
            kernel_dilation: `None`, or a sequence of `n` integers, giving the
                dilation factor to apply in each spatial dimension of the convolution
                kernel. Convolution with kernel dilation is also known as 'atrous
                convolution'.
            feature_group_count: integer, default 1. If specified divides the input
                features into groups.
            use_bias: whether to add a bias to the output (default: True).
            dtype: the dtype of the computation (default: float32).
            param_dtype: the dtype passed to parameter initializers (default: float32).
            precision: numerical precision of the computation see `jax.lax.Precision`
                for details.
            kernel_init: initializer for the convolutional kernel.
            bias_init: initializer for the bias.
        """
        self.features_out = features_out
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.input_dilation = input_dilation
        self.kernel_dilation = kernel_dilation
        self.feature_group_count = feature_group_count
        self.use_bias = use_bias
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init

        # Parameters stay unset until the first initializing call of __call__.
        self.kernel = None
        self.bias = None

    @property
    def module(self) -> flax_module.Conv:
        """Build a fresh `flax.linen.Conv` from this module's hyper-parameters."""
        return flax_module.Conv(
            features=self.features_out,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            input_dilation=self.input_dilation,
            kernel_dilation=self.kernel_dilation,
            feature_group_count=self.feature_group_count,
            use_bias=self.use_bias,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies a convolution to the inputs.

        On an initializing call the wrapped `flax.linen.Conv` is initialized
        with `x`, and the resulting `kernel` (plus `bias` when `use_bias` is
        set) are stored on this module before the convolution is applied.

        Arguments:
            x: input data with dimensions (batch, spatial_dims..., features).

        Returns:
            The convolved data.
        """
        # `module` builds a new Flax object on every access; build it once.
        conv = self.module

        if self.initializing():
            variables = conv.init({"params": next_key()}, x)
            # Depending on the Flax version/config, `init` returns either a
            # FrozenDict or a plain dict; `.unfreeze()` only exists on the
            # former, while `dict(...)` accepts both. The params mapping is
            # flat ({"kernel", "bias"}), so a shallow copy is sufficient.
            params = dict(variables["params"])
            self.kernel = params["kernel"]
            if self.use_bias:
                self.bias = params["bias"]

        assert self.kernel is not None
        params = {"kernel": self.kernel}
        if self.use_bias:
            assert self.bias is not None
            params["bias"] = self.bias

        output = conv.apply({"params": params}, x)
        return tp.cast(jnp.ndarray, output)
__call__(self, x)
special
Applies a convolution to the inputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ndarray` | input data with dimensions (batch, spatial_dims..., features). | required |
Returns:

| Type | Description |
|---|---|
| `ndarray` | The convolved data. |
Source code in treex/nn/conv.py
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Applies a convolution to the inputs.

    On an initializing call the wrapped `flax.linen.Conv` is initialized with
    `x`, and the resulting `kernel` (plus `bias` when `use_bias` is set) are
    stored on this module before the convolution is applied.

    Arguments:
        x: input data with dimensions (batch, spatial_dims..., features).

    Returns:
        The convolved data.
    """
    # `module` may build a new Flax object on every access; build it once.
    conv = self.module

    if self.initializing():
        variables = conv.init({"params": next_key()}, x)
        # Depending on the Flax version/config, `init` returns either a
        # FrozenDict or a plain dict; `.unfreeze()` only exists on the former,
        # while `dict(...)` accepts both. The params mapping is flat
        # ({"kernel", "bias"}), so a shallow copy is sufficient.
        params = dict(variables["params"])
        self.kernel = params["kernel"]
        if self.use_bias:
            self.bias = params["bias"]

    assert self.kernel is not None
    params = {"kernel": self.kernel}
    if self.use_bias:
        assert self.bias is not None
        params["bias"] = self.bias

    output = conv.apply({"params": params}, x)
    return tp.cast(jnp.ndarray, output)
__init__(self, features_out, kernel_size, *, strides=None, padding='SAME', input_dilation=None, kernel_dilation=None, feature_group_count=1, use_bias=True, dtype=jnp.float32, param_dtype=jnp.float32, precision=None, kernel_init=flax_module.default_kernel_init, bias_init=flax_module.zeros)
special
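Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `features_out` | `int` | number of convolution filters. | required |
| `kernel_size` | `Union[int, Iterable[int]]` | shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For all other cases, it must be a sequence of integers. | required |
| `strides` | `Optional[Iterable[int]]` | a sequence of `n` integers, representing the inter-window strides. | `None` |
| `padding` | `Union[str, Iterable[Tuple[int, int]]]` | either the string `'SAME'`, the string `'VALID'`, the string `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low, high)` integer pairs that give the padding to apply before and after each spatial dimension. | `'SAME'` |
| `input_dilation` | `Optional[Iterable[int]]` | `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of `inputs`. Convolution with input dilation `d` is equivalent to transposed convolution with stride `d`. | `None` |
| `kernel_dilation` | `Optional[Iterable[int]]` | `None`, or a sequence of `n` integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as 'atrous convolution'. | `None` |
| `feature_group_count` | `int` | integer, default 1. If specified divides the input features into groups. | `1` |
| `use_bias` | `bool` | whether to add a bias to the output (default: True). | `True` |
| `dtype` | `Any` | the dtype of the computation (default: float32). | `jnp.float32` |
| `param_dtype` | `Any` | the dtype passed to parameter initializers (default: float32). | `jnp.float32` |
| `precision` | `Any` | numerical precision of the computation; see `jax.lax.Precision` for details. | `None` |
| `kernel_init` | `Callable[[Any, Iterable[int], Any], Any]` | initializer for the convolutional kernel. | `flax_module.default_kernel_init` |
| `bias_init` | `Callable[[Any, Iterable[int], Any], Any]` | initializer for the bias. | `flax_module.zeros` |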
Source code in treex/nn/conv.py
def __init__(
    self,
    features_out: int,
    kernel_size: tp.Union[int, tp.Iterable[int]],
    *,
    strides: tp.Optional[tp.Iterable[int]] = None,
    padding: tp.Union[str, tp.Iterable[tp.Tuple[int, int]]] = "SAME",
    input_dilation: tp.Optional[tp.Iterable[int]] = None,
    kernel_dilation: tp.Optional[tp.Iterable[int]] = None,
    feature_group_count: int = 1,
    use_bias: bool = True,
    dtype: flax_module.Dtype = jnp.float32,
    param_dtype: flax_module.Dtype = jnp.float32,
    precision: tp.Any = None,
    kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.default_kernel_init,
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ] = flax_module.zeros,
):
    """
    Record the convolution hyper-parameters; parameters start unset.

    Arguments:
        features_out: how many convolution filters (output features) to produce.
        kernel_size: spatial shape of the convolutional kernel. A plain integer
            is only accepted for 1D convolution; otherwise pass a sequence of
            integers.
        strides: sequence of `n` integers giving the inter-window strides.
        padding: `'SAME'`, `'VALID'`, `'CIRCULAR'` (periodic boundary
            conditions), or a sequence of `n` `(low, high)` integer pairs with
            the padding applied before and after each spatial dimension.
        input_dilation: `None` or a sequence of `n` integers with the dilation
            factor for each spatial dimension of `inputs`. Input dilation `d`
            is equivalent to a transposed convolution with stride `d`.
        kernel_dilation: `None` or a sequence of `n` integers with the dilation
            factor for each spatial dimension of the kernel ('atrous
            convolution').
        feature_group_count: integer, default 1. When set, the input features
            are split into this many groups.
        use_bias: whether a bias is added to the output (default: True).
        dtype: dtype used for the computation (default: float32).
        param_dtype: dtype handed to the parameter initializers (default: float32).
        precision: numerical precision of the computation; see
            `jax.lax.Precision`.
        kernel_init: initializer for the convolutional kernel.
        bias_init: initializer for the bias.
    """
    # (attribute name, value) pairs, kept in the original assignment order.
    settings = (
        ("features_out", features_out),
        ("kernel_size", kernel_size),
        ("strides", strides),
        ("padding", padding),
        ("input_dilation", input_dilation),
        ("kernel_dilation", kernel_dilation),
        ("feature_group_count", feature_group_count),
        ("use_bias", use_bias),
        ("dtype", dtype),
        ("param_dtype", param_dtype),
        ("precision", precision),
        ("kernel_init", kernel_init),
        ("bias_init", bias_init),
    )
    for attr_name, attr_value in settings:
        setattr(self, attr_name, attr_value)

    # Trainable parameters start empty; they are expected to be filled in
    # later (e.g. by an initializing call).
    self.kernel = None
    self.bias = None