Skip to content

treex.nn.Embed

An embedding layer that looks up a `features`-sized vector for each integer index in the input.

Embed is implemented as a wrapper over flax.linen.Embed; its constructor accepts almost the same arguments, including any Flax artifacts such as initializers.

Source code in treex/nn/embed.py
class Embed(Module):
    """An embedding layer that looks up a feature vector for each input index.

    `Embed` is implemented as a wrapper over `flax.linen.Embed`; its constructor
    accepts almost the same arguments, including any Flax artifacts such as
    initializers.
    """

    # pytree
    num_embeddings: int
    features: int
    dtype: flax_module.Dtype
    embedding_init: tp.Callable[[PRNGKey, Shape, Dtype], Array]

    # Embedding table; stays `None` until `__call__` runs while initializing.
    embedding: tp.Optional[Array] = types.Parameter.node()

    def __init__(
        self,
        num_embeddings: int,
        features: int,
        *,
        dtype: flax_module.Dtype = jnp.float32,
        embedding_init: tp.Callable[
            [PRNGKey, Shape, Dtype], Array
        ] = flax_module.default_embed_init,
        name: tp.Optional[str] = None,
    ):
        """Stores the configuration; the embedding table itself is created later.

        Arguments:
            num_embeddings: number of embeddings.
            features: number of feature dimensions for each embedding.
            dtype: the dtype of the embedding vectors (default: float32).
            embedding_init: embedding initializer.
            name: optional name, forwarded to the base `Module` constructor.
        """
        super().__init__(name=name)
        self.num_embeddings = num_embeddings
        self.features = features
        self.dtype = dtype
        self.embedding_init = embedding_init

        self.embedding = None

    @property
    def module(self) -> flax_module.Embed:
        """Builds a `flax_module.Embed` from this module's configuration.

        A new Flax module object is constructed on every access.
        """
        return flax_module.Embed(
            num_embeddings=self.num_embeddings,
            features=self.features,
            dtype=self.dtype,
            embedding_init=self.embedding_init,
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Embeds the inputs along the last dimension.

        Arguments:
            x: input data, all dimensions are considered batch dimensions.

        Returns:
            Output which is embedded input data.  The output shape follows the input,
            with an additional `features` dimension appended.
        """
        if self.initializing():
            rngs = {"params": next_key()}
            variables = self.module.init(rngs, x)

            # Index straight into the "params" collection rather than calling
            # `.unfreeze()` on it first: indexing works whether Flax returns a
            # `FrozenDict` or a plain `dict`, whereas `.unfreeze()` only exists
            # on `FrozenDict`.
            self.embedding = variables["params"]["embedding"]

        assert (
            self.embedding is not None
        ), "Embed.embedding is not set; the module must be initialized before it is called."
        params = {"embedding": self.embedding}

        output = self.module.apply({"params": params}, x)
        return tp.cast(jnp.ndarray, output)

__call__(self, x) special

Embeds the inputs along the last dimension.

Parameters:

Name Type Description Default
x

input data, all dimensions are considered batch dimensions.

required

Returns:

Type Description
ndarray

Output which is embedded input data. The output shape follows the input, with an additional features dimension appended.

Source code in treex/nn/embed.py
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Embeds the inputs along the last dimension.

    Arguments:
        x: input data, all dimensions are considered batch dimensions.

    Returns:
        Output which is embedded input data.  The output shape follows the input,
        with an additional `features` dimension appended.
    """
    if self.initializing():
        rngs = {"params": next_key()}
        variables = self.module.init(rngs, x)

        # Index straight into the "params" collection rather than calling
        # `.unfreeze()` on it first: indexing works whether Flax returns a
        # `FrozenDict` or a plain `dict`, whereas `.unfreeze()` only exists
        # on `FrozenDict`.
        self.embedding = variables["params"]["embedding"]

    assert (
        self.embedding is not None
    ), "Embed.embedding is not set; the module must be initialized before it is called."
    params = {"embedding": self.embedding}

    output = self.module.apply({"params": params}, x)
    return tp.cast(jnp.ndarray, output)

__init__(self, num_embeddings, features, *, dtype=jnp.float32, embedding_init=flax_module.default_embed_init, name=None) special

Parameters:

Name Type Description Default
num_embeddings int

number of embeddings.

required
features int

number of feature dimensions for each embedding.

required
dtype Any

the dtype of the embedding vectors (default: float32).

jnp.float32
embedding_init Callable[[Any, Iterable[int], Any], Any]

embedding initializer.

flax_module.default_embed_init
Source code in treex/nn/embed.py
def __init__(
    self,
    num_embeddings: int,
    features: int,
    *,
    dtype: flax_module.Dtype = jnp.float32,
    embedding_init: tp.Callable[
        [PRNGKey, Shape, Dtype], Array
    ] = flax_module.default_embed_init,
    name: tp.Optional[str] = None,
):
    """Stores the embedding configuration on the module.

    Arguments:
        num_embeddings: number of embeddings.
        features: number of feature dimensions for each embedding.
        dtype: the dtype of the embedding vectors (default: float32).
        embedding_init: embedding initializer.
        name: optional name, forwarded to the base class constructor.
    """
    super().__init__(name=name)
    self.num_embeddings = num_embeddings
    self.features = features
    self.dtype = dtype
    self.embedding_init = embedding_init

    # No embedding values are created in the constructor; the attribute is
    # only declared here with a `None` placeholder.
    self.embedding = None