treex.Initializer
Initialize a field from a function that takes a single PRNGKey argument.
Initializers are called by `Module.init` and replace the value of the field they are assigned to.
Source code in treex/types.py
class Initializer(to.Tree):
    """Initialize a field from a function that expects a single argument with a PRNGKey.

    Initializers are called by `Module.init` and replace the value of the field
    they are assigned to.
    """

    # The wrapped initialization function; it is called with a PRNGKey.
    f: tp.Callable[[jnp.ndarray], tp.Any]

    def __init__(self, f: tp.Callable[[jnp.ndarray], tp.Any]):
        """
        Arguments:
            f: A function that takes a PRNGKey and returns the initial value of the field.
        """
        self.f = f

    def __call__(self, x: jnp.ndarray) -> tp.Any:
        """Call the wrapped function with `x` and return its result unchanged.

        Arguments:
            x: The PRNGKey forwarded to `f`.

        Returns:
            Whatever `f` returns. `f` is declared to return `tp.Any`, so the
            result is annotated `tp.Any` rather than `np.ndarray`.
        """
        return self.f(x)

    def __repr__(self) -> str:
        # Constant string; the wrapped function is not included in the repr.
        return "Initializer"
`__init__(self, f)` (special method)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `f` | `Callable[[jax._src.numpy.lax_numpy.ndarray], Any]` | A function that takes a PRNGKey and returns the initial value of the field. | required |
Source code in treex/types.py
def __init__(self, f: tp.Callable[[jnp.ndarray], tp.Any]):
    """Store the initialization function on the instance.

    Args:
        f: Callable that receives a PRNGKey and produces the field's initial value.
    """
    self.f = f