Skip to content

treex.nn.Lambda

A Module that applies a pure function to its input.

Source code in treex/nn/sequential.py
class Lambda(Module):
    """
    Wraps an arbitrary callable so it can be used as a `Module`.

    Calling an instance forwards its single input to the stored callable
    and returns the callable's result unchanged. The callable is intended
    to be pure; this class does not check that.
    """

    # Class-level annotation for the wrapped callable (set in `__init__`).
    f: tp.Callable[[jnp.ndarray], jnp.ndarray]

    def __init__(self, f: tp.Callable[[jnp.ndarray], jnp.ndarray]):
        """
        Arguments:
            f: The callable that `__call__` will apply to each input. It is
                stored as-is, without validation.
        """
        self.f = f

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Arguments:
            x: Value passed unchanged to the wrapped callable.
        Returns:
            Whatever the wrapped callable returns for `x`.
        """
        fn = self.f
        return fn(x)

`__call__(self, x)` (special method)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `ndarray` | The input to the function. | required |

Returns:

| Type | Description |
|------|-------------|
| `ndarray` | The output of the function. |

Source code in treex/nn/sequential.py
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """
    Arguments:
        x: Value passed unchanged to the wrapped callable `self.f`.
    Returns:
        Whatever `self.f` returns for `x`.
    """
    fn = self.f
    return fn(x)

`__init__(self, f)` (special method)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `f` | `Callable[[jnp.ndarray], jnp.ndarray]` | A function to apply to the input. | required |
Source code in treex/nn/sequential.py
def __init__(self, f: tp.Callable[[jnp.ndarray], jnp.ndarray]):
    """
    Arguments:
        f: Callable to store on the instance as `self.f`. It is kept
            as-is, without validation.
    """
    # Only the callable is recorded; no other attribute is set here.
    self.f = f