Compact

Treeo's compact decorator and Compact mixin allow the initialization of fields and the definition of Tree nodes during a function call. compact enables a simpler syntax for Trees whose computation structure follows the Tree's structure.

For example, if you have Trees with the following behavior:

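from dataclasses import dataclass

import treeo as to

# dataclass generates the __init__ that Child(10) below relies on
@dataclass
class Child(to.Tree):
    some_node: float = to.node()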

    def __call__(self, x):
        ...
        return x

class Parent(to.Tree):

    def __init__(self):
        self.child1 = Child(10)
        self.child2 = Child(20)
        self.child3 = Child(30)

    def __call__(self, x):
        x = self.child1(x)
        x = self.child2(x)
        x = self.child3(x)
        return x

Notice how you have to specify/use the same fields in __init__ and __call__. To reduce the amount of boilerplate you can use the compact decorator:

class Parent(to.Tree):
    @to.compact
    def __call__(self, x):
        x = Child(10)(x)
        x = Child(20)(x)
        x = Child(30)(x)
        return x

While it may look as if the Child Trees are being created on every call, compact keeps track of the Trees created during the first call, assigns them as fields to Parent, and reuses them on subsequent calls; their constructors are called only once.
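As a rough mental model only, the reuse behaviour described above can be imitated in plain Python. This sketch is not Treeo's implementation: ToyMeta, ToyTree, toy_compact, _active and _toy_children are hypothetical names, and the sketch assumes that sub-Trees are matched purely by creation order.

```python
import functools

# Stack of toy-compact calls currently in progress (innermost last).
_active = []


class ToyMeta(type):
    """Intercepts `SomeTree(...)` so a compact call can reuse earlier instances."""

    def __call__(cls, *args, **kwargs):
        if not _active:
            # Not inside a compact method: construct normally.
            return super().__call__(*args, **kwargs)
        frame = _active[-1]
        index = frame["index"]
        frame["index"] += 1
        if frame["first_run"]:
            # First call: really construct the sub-Tree and remember it.
            # (Nested construction inside a child's __init__ is not handled here.)
            child = super().__call__(*args, **kwargs)
            frame["parent"]._toy_children.append(child)
            return child
        # Later calls: skip the constructor and reuse by creation order.
        return frame["parent"]._toy_children[index]


class ToyTree(metaclass=ToyMeta):
    pass


def toy_compact(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        first_run = not hasattr(self, "_toy_children")
        if first_run:
            self._toy_children = []
        _active.append({"parent": self, "index": 0, "first_run": first_run})
        try:
            return method(self, *args, **kwargs)
        finally:
            _active.pop()

    return wrapper


class Child(ToyTree):
    constructed = 0  # counts how many times __init__ actually runs

    def __init__(self, value):
        Child.constructed += 1
        self.value = value

    def __call__(self, x):
        return x + self.value


class Parent(ToyTree):
    @toy_compact
    def __call__(self, x):
        x = Child(10)(x)
        x = Child(20)(x)
        return x


parent = Parent()
first = parent(1)
second = parent(1)
print(first, second)      # 31 31
print(Child.constructed)  # 2
```

Because this toy looks sub-Trees up by position, a call that creates them in a different number or order would pick up the wrong instance, which is one way to picture why the structure of a compact method has to stay fixed.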

Warning

You cannot conditionally construct Trees in a compact method unless the condition stays the same for the Tree's entire lifespan. Adding the following to the previous example will cause trouble, because the branch depends on the shape of the input:

if x.shape[0] > 10:
    x = Child(10)(x)

The number of sub-Trees defined inside compact, and the order in which they are defined, should always be the same.

Naming

The names of the created Trees are stored in order of creation in ._subtrees. The name of each field is determined as follows:

  • If the Tree has a name attribute, it will be used as the name of the field.
  • Else if it has a __name__ attribute, it will be used.
  • Else a snake_case version of the Tree's class name will be used.
  • If a field with the same name already exists, a number will be appended to the name.

The previous example will result in the following fields: child, child2, child3.
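The naming rules above can be sketched in a few lines of plain Python. This is an illustration rather than Treeo's own code: pick_field_name and to_snake_case are hypothetical helpers, and starting the numeric suffix at 2 is an assumption inferred from the child, child2, child3 example.

```python
import re


def to_snake_case(class_name):
    # Insert "_" before every capital letter except the first, then lowercase.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()


def pick_field_name(tree, existing_names):
    # Rule 1: an explicit `name` attribute wins.
    if hasattr(tree, "name"):
        base = tree.name
    # Rule 2: otherwise fall back to `__name__` if present.
    elif hasattr(tree, "__name__"):
        base = tree.__name__
    # Rule 3: otherwise use the snake_case class name.
    else:
        base = to_snake_case(type(tree).__name__)

    # Rule 4: append a number until the name is unique (assumed to start at 2).
    name, suffix = base, 2
    while name in existing_names:
        name = f"{base}{suffix}"
        suffix += 1
    return name


class Child:  # stand-in for a Tree subclass without a `name` attribute
    pass


names = []
for tree in (Child(), Child(), Child()):
    names.append(pick_field_name(tree, names))
print(names)  # ['child', 'child2', 'child3']
```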

Compact Mixin

With the Compact mixin you can add the get_field method and the first_run property to a Tree subclass. Together they provide a way to initialize fields at runtime, potentially based on properties of the input. As an example, let's code a Linear Tree that does shape inference for its w and b parameters:

import jax
import jax.numpy as jnp

class Linear(to.Tree, to.Compact):
    w: float = to.node()
    b: float = to.node()

    def __init__(self, dout, key):
        self.dout = dout
        self.key = key

    @to.compact
    def __call__(self, x):
        din = x.shape[-1]
        w = self.get_field("w", lambda: jax.random.uniform(self.key, [din, self.dout]))
        b = self.get_field("b", lambda: jnp.zeros(shape=[self.dout]))

        return jnp.dot(x, w) + b

get_field will initialize the w and b fields on the first run and fetch their values on subsequent runs. You can also use the first_run property and manually initialize the fields:

class Linear(to.Tree, to.Compact):
    w: float = to.node()
    b: float = to.node()

    def __init__(self, dout, key):
        self.dout = dout
        self.key = key

    @to.compact
    def __call__(self, x):
        if self.first_run:
            din = x.shape[-1]
            self.w = jax.random.uniform(self.key, [din, self.dout])
            self.b = jnp.zeros(shape=[self.dout])

        return jnp.dot(x, self.w) + self.b

This is useful if you want to perform more complex initialization procedures.
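The "initialize on the first run, fetch afterwards" contract of get_field can be illustrated without Treeo or JAX. The sketch below is a plain-Python stand-in, not the library's implementation: ToyLazyFields and ToyLinear are hypothetical names, and detecting an initialized field by looking in the instance dictionary is an assumption made for the example.

```python
class ToyLazyFields:
    """Toy stand-in for the 'initialize once, then fetch' contract."""

    def get_field(self, name, initializer):
        # Assumption: a field counts as initialized once it is set on the instance.
        if name not in vars(self):
            setattr(self, name, initializer())
        return getattr(self, name)


class ToyLinear(ToyLazyFields):
    def __init__(self, dout):
        self.dout = dout
        self.init_calls = 0  # counts how often the initializer actually runs

    def _make_w(self, din):
        self.init_calls += 1
        return [[0.0] * self.dout for _ in range(din)]

    def __call__(self, x):
        din = len(x)  # infer the input size from the input itself
        w = self.get_field("w", lambda: self._make_w(din))
        # Plain-Python vector-matrix product.
        return [sum(x[i] * w[i][j] for i in range(din)) for j in range(self.dout)]


layer = ToyLinear(dout=2)
out1 = layer([1.0, 2.0, 3.0])
out2 = layer([4.0, 5.0, 6.0])
print(len(layer.w), len(layer.w[0]))  # 3 2
print(layer.init_calls)               # 1
```

The initializer is passed as a callable so that it only runs when the field is missing; on the second call the stored w is returned and the lambda is never invoked.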