treex.nn.Lambda
A Module that applies a pure function to its input.
Source code in treex/nn/sequential.py
class Lambda(Module):
    """
    Module wrapper around a plain function.

    The callable given at construction is kept on the instance as ``f`` and
    is applied to the input every time the module is called. The function is
    expected to be pure; this is a convention and is not checked here.
    """

    # Kept as an annotated class-level field, matching how the wrapped
    # callable is declared on the module.
    f: tp.Callable[[jnp.ndarray], jnp.ndarray]

    def __init__(self, f: tp.Callable[[jnp.ndarray], jnp.ndarray]):
        """
        Arguments:
            f: Callable stored on the instance and applied to the input on
                each call.
        """
        self.f = f

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Arguments:
            x: Value forwarded unchanged to the wrapped function.

        Returns:
            Whatever the wrapped function returns for ``x``.
        """
        fn = self.f
        return fn(x)
`__call__(self, x)` — *special*

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ndarray` | The input to the function. | *required* |

Returns:

| Type | Description |
|---|---|
| `ndarray` | The output of the function. |
Source code in treex/nn/sequential.py
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """
    Arguments:
        x: Value forwarded unchanged to the wrapped function.

    Returns:
        Whatever the wrapped function returns for ``x``.
    """
    fn = self.f
    return fn(x)
`__init__(self, f)` — *special*

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[[jax.numpy.ndarray], jax.numpy.ndarray]` | A function to apply to the input. | *required* |
Source code in treex/nn/sequential.py
def __init__(self, f: tp.Callable[[jnp.ndarray], jnp.ndarray]):
    """
    Store the function this module will apply.

    Arguments:
        f: Callable kept on the instance as ``f`` and applied to the input
            on each call.
    """
    self.f = f