
treex.Linear

A linear transformation applied over the last dimension of the input.

Linear is implemented as a wrapper over flax.linen.Dense; its constructor accepts almost the same arguments, including any Flax artifacts such as initializers. Main differences (see the usage sketch after this list):

  • the number of input features is inferred from the input when the module is initialized, so there is no features_in argument.
  • the features argument is renamed to features_out.
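A minimal usage sketch: construct the module, initialize it with an example input so shapes can be inferred, then call it. The init(key, inputs) form below is an assumption about the treex version at hand, not a documented signature:

import jax
import jax.numpy as jnp
import treex as tx

x = jnp.ones((8, 3))                 # the input feature size (3) is inferred from x
linear = tx.Linear(features_out=16)  # features_out instead of Flax's features
linear = linear.init(key=jax.random.PRNGKey(0), inputs=x)  # assumed init signature
y = linear(x)
assert y.shape == (8, 16)
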
Source code in treex/nn/linear.py
class Linear(Module):
    """A linear transformation applied over the last dimension of the input.

    `Linear` is implemented as a wrapper over `flax.linen.Dense`; its constructor
    accepts almost the same arguments, including any Flax artifacts such as initializers.
    Main differences:

    * the number of input features is inferred from the input during initialization,
      so there is no `features_in` argument.
    * `features` argument is renamed to `features_out`.
    """

    # pytree
    kernel: tp.Optional[jnp.ndarray] = types.Parameter.node()
    bias: tp.Optional[jnp.ndarray] = types.Parameter.node()

    # static
    features_out: int
    use_bias: bool
    dtype: tp.Any
    precision: tp.Any
    kernel_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]
    bias_init: tp.Callable[
        [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
        flax_module.Array,
    ]

    def __init__(
        self,
        features_out: int,
        *,
        use_bias: bool = True,
        dtype: tp.Any = jnp.float32,
        precision: tp.Any = None,
        kernel_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.default_kernel_init,
        bias_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.zeros,
        name: tp.Optional[str] = None,
        axis_name: tp.Optional[tp.Any] = None
    ):
        """
        Arguments:
            features_out: the number of output features.
            use_bias: whether to add a bias to the output (default: True).
            dtype: the dtype of the computation (default: float32).
            precision: numerical precision of the computation; see `jax.lax.Precision`
                for details.
            kernel_init: initializer function for the weight matrix.
            bias_init: initializer function for the bias.
        """
        super().__init__(name=name)
        self.features_out = features_out
        self.use_bias = use_bias
        self.dtype = dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.axis_name = axis_name

        self.kernel = None
        self.bias = None

    @property
    def module(self) -> flax_module.Dense:
        return flax_module.Dense(
            features=self.features_out,
            use_bias=self.use_bias,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Applies a linear transformation to the inputs along the last dimension.

        Arguments:
            x: The nd-array to be transformed.

        Returns:
            The transformed input.
        """
        if self.initializing():
            rngs = {"params": next_key(axis_name=self.axis_name)}
            variables = self.module.init(rngs, x)

            # Extract collections
            params = variables["params"].unfreeze()

            self.kernel = params["kernel"]

            if self.use_bias:
                self.bias = params["bias"]

        assert self.kernel is not None
        params = {"kernel": self.kernel}

        if self.use_bias:
            assert self.bias is not None
            params["bias"] = self.bias

        output = self.module.apply({"params": params}, x)
        return tp.cast(jnp.ndarray, output)
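
Because kernel and bias are declared as pytree leaves, an uninitialized Linear simply carries None in their place; initialization materializes them, with the input feature size inferred from the example input. A small sketch of what that implies for user code (same assumed init(key, inputs) signature as above):

import jax
import jax.numpy as jnp
import treex as tx

linear = tx.Linear(features_out=4)
print(linear.kernel, linear.bias)  # None None: leaves exist but are not materialized yet

x = jnp.ones((1, 3))
linear = linear.init(key=jax.random.PRNGKey(0), inputs=x)  # assumed signature
print(linear.kernel.shape)  # (3, 4): input features inferred from x
print(linear.bias.shape)    # (4,)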

__call__(self, x) special

Applies a linear transformation to the inputs along the last dimension.

Parameters:

  • x (ndarray, required): The nd-array to be transformed.

Returns:

  • ndarray: The transformed input.

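Like flax.linen.Dense, the transformation touches only the trailing axis, so leading batch or sequence dimensions pass through unchanged. A sketch, again assuming the init(key, inputs) form used earlier:

import jax
import jax.numpy as jnp
import treex as tx

linear = tx.Linear(features_out=8)
x = jnp.ones((2, 5, 3))  # (batch, sequence, features)
linear = linear.init(key=jax.random.PRNGKey(0), inputs=x)  # assumed signature
y = linear(x)
assert y.shape == (2, 5, 8)  # only the last dimension changes
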

__init__(self, features_out, *, use_bias=True, dtype=jnp.float32, precision=None, kernel_init=default_kernel_init, bias_init=zeros, name=None, axis_name=None) special

Parameters:

  • features_out (int, required): the number of output features.
  • use_bias (bool, default: True): whether to add a bias to the output.
  • dtype (Any, default: jnp.float32): the dtype of the computation.
  • precision (Any, default: None): numerical precision of the computation; see jax.lax.Precision for details.
  • kernel_init (Callable[[Any, Iterable[int], Any], Any], default: default_kernel_init): initializer function for the weight matrix.
  • bias_init (Callable[[Any, Iterable[int], Any], Any], default: zeros): initializer function for the bias.
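
kernel_init and bias_init are plain Flax/JAX initializer callables (Flax's default_kernel_init is a lecun_normal variance-scaling initializer), so anything from jax.nn.initializers can be dropped in for the defaults. A sketch, with the same assumed init(key, inputs) signature as above:

import jax
import jax.numpy as jnp
import treex as tx

linear = tx.Linear(
    features_out=4,
    kernel_init=jax.nn.initializers.orthogonal(),  # replaces the variance-scaling default
    bias_init=jax.nn.initializers.constant(0.1),   # replaces zeros
)
x = jnp.ones((2, 3))
linear = linear.init(key=jax.random.PRNGKey(0), inputs=x)  # assumed signature
print(linear.kernel.shape)  # (3, 4)
print(linear.bias)          # array filled with 0.1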