treex.Initializer

Initializes a field from a function that takes a single argument, a PRNGKey.

Initializers are called by Module.init and replace the value of the field they are assigned to.
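
The mechanism described above can be sketched without treex itself. The following is only an illustration under stated assumptions, not the library's implementation: the `Initializer` class here is a stand-in that mirrors the source shown on this page, `init_fields` and `random_weights` are hypothetical names, and an integer seed stands in for a PRNGKey so the example runs with the standard library alone. With JAX, the key would come from `jax.random.PRNGKey(0)` and the wrapped function could be something like `lambda key: jax.random.uniform(key, (3,))`.

```python
import random
import typing as tp


class Initializer:
    """Stand-in mirroring treex.Initializer: wraps a function of one key argument."""

    def __init__(self, f: tp.Callable[[tp.Any], tp.Any]):
        self.f = f

    def __call__(self, key: tp.Any) -> tp.Any:
        return self.f(key)


def init_fields(fields: tp.Dict[str, tp.Any], key: tp.Any) -> tp.Dict[str, tp.Any]:
    """Hypothetical helper: replace every Initializer with the value it produces.

    Simplification: the same key is passed to every Initializer.
    """
    return {
        name: value(key) if isinstance(value, Initializer) else value
        for name, value in fields.items()
    }


def random_weights(seed: int) -> tp.List[float]:
    """Hypothetical init function; the integer seed stands in for a PRNGKey."""
    rng = random.Random(seed)
    return [rng.uniform(-1.0, 1.0) for _ in range(3)]


# "w" is deferred behind an Initializer, "b" already holds a concrete value.
fields = {"w": Initializer(random_weights), "b": [0.0, 0.0, 0.0]}

# After initialization, "w" holds the generated value and "b" is unchanged.
initialized = init_fields(fields, key=42)
```

The point of the pattern is that a field can be declared before any random key exists; the value is only computed once a key is supplied. The sketch reuses one key for all fields for brevity, whereas an implementation built on JAX would typically derive a distinct key for each field.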

Source code in treex/types.py
class Initializer(to.Tree):
    """Initialize a field from a function that expects a single argument with a PRNGKey.

    Initializers are called by `Module.init` and replace the value of the field they are assigned to.
    """

    f: tp.Callable[[jnp.ndarray], tp.Any]

    def __init__(self, f: tp.Callable[[jnp.ndarray], tp.Any]):
        """
        Arguments:
            f: A function that takes a PRNGKey and returns the initial value of the field.
        """
        self.f = f

    def __call__(self, x: jnp.ndarray) -> np.ndarray:
        return self.f(x)

    def __repr__(self) -> str:
        return "Initializer"

__init__(self, f) (special method)

Parameters:

Name: f
Type: Callable[[jax._src.numpy.lax_numpy.ndarray], Any]
Description: A function that takes a PRNGKey and returns the initial value of the field.
Default: required

Source code in treex/types.py
def __init__(self, f: tp.Callable[[jnp.ndarray], tp.Any]):
    """
    Arguments:
        f: A function that takes a PRNGKey and returns the initial value of the field.
    """
    self.f = f