Creates a new Module with the same structure but its values merged based on the values from the incoming Modules.
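Example:

```python
from dataclasses import dataclass

import treex as tx
from treeo import Nothing  # placeholder type from treeo, the tree library treex builds on

@dataclass
class MyModule(tx.Module):
    x: int = tx.field(node=True, kind=tx.Parameter)
    y: int = tx.field(node=True, kind=tx.BatchStat)
    z: int = tx.field(node=True, kind=tx.Parameter)

m1 = MyModule(x=Nothing, y=2, z=3)
m2 = MyModule(x=1, y=Nothing, z=4)

m1.merge(m2)  # MyModule(x=1, y=2, z=4)
```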
Updates are performed using the following rules:
- For a list of equivalent leaves `l1, l2, ..., ln`, the first non-`Nothing` leaf, scanning from right to left, is returned.
- If no `flatten_mode()` context manager is active and `flatten_mode` is not given, all fields are updated.
- If `flatten_mode="normal"` is set, static fields are not updated and the output has exactly the same static components as the first input (`obj`).
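The per-leaf rule in the first bullet can be sketched in plain Python. This is a minimal illustration, not treex's implementation: `NOTHING` is a hypothetical stand-in for treex's `Nothing` placeholder, `merge_leaves` is a hypothetical helper, and dicts stand in for modules.

```python
class _Nothing:
    """Hypothetical stand-in for treex's `Nothing` placeholder."""

    def __repr__(self) -> str:
        return "Nothing"


NOTHING = _Nothing()


def merge_leaves(*leaves):
    """Return the first non-NOTHING leaf, scanning right to left (hypothetical helper)."""
    for leaf in reversed(leaves):
        if leaf is not NOTHING:
            return leaf
    # every input was a placeholder, so the result stays a placeholder
    return NOTHING


# field-by-field merge mirroring the MyModule example
m1 = {"x": NOTHING, "y": 2, "z": 3}
m2 = {"x": 1, "y": NOTHING, "z": 4}
merged = {k: merge_leaves(m1[k], m2[k]) for k in m1}
print(merged)  # {'x': 1, 'y': 2, 'z': 4}
```

Scanning from the right means later inputs take priority, while `NOTHING` acts as "no value here, fall through to the next input on the left".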
When merging with multiple Modules, the following equivalence holds:

```
m1.merge(m2, m3) = m1.merge(m2.merge(m3))
```
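The equivalence follows from the right-to-left rule applied leaf by leaf. A quick exhaustive check in plain Python, assuming a hypothetical sentinel `NOTHING` and a hypothetical per-leaf `merge` that mimics the rule above (not treex's own code):

```python
import itertools

NOTHING = object()  # hypothetical placeholder sentinel


def merge(*leaves):
    # first non-placeholder leaf, scanning right to left
    for leaf in reversed(leaves):
        if leaf is not NOTHING:
            return leaf
    return NOTHING


# check merge(a, b, c) == merge(a, merge(b, c)) for every mix of values and placeholders
candidates = [NOTHING, 1, 2]
for a, b, c in itertools.product(candidates, repeat=3):
    assert merge(a, b, c) == merge(a, merge(b, c))
print("equivalence holds for all", len(candidates) ** 3, "combinations")
```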
If you want to merge into the current module instead of creating a new one, use `inplace=True`. This is useful when applying a transformation inside a method, where reassigning `self` is not possible:
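```python
import jax

def double_params(self):
    # this is not doing what you expect: it only rebinds the local name `self`
    self = jax.tree_util.tree_map(lambda x: 2 * x, self)
```

Instead, do this:

```python
def double_params(self):
    # build a doubled copy, then write its values back into `self`
    doubled = jax.tree_util.tree_map(lambda x: 2 * x, self)
    self.merge(doubled, inplace=True)
```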
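Why reassigning `self` has no effect is ordinary Python name binding, independent of treex. A minimal sketch with a hypothetical `Params` class, where `__dict__.update` stands in for the in-place merge:

```python
class Params:
    """Hypothetical container standing in for a Module."""

    def __init__(self, w: float, b: float):
        self.w = w
        self.b = b

    def double_rebind(self):
        # rebinds the local name `self` to a new object; the caller's object is untouched
        self = Params(2 * self.w, 2 * self.b)

    def double_inplace(self):
        # builds a doubled copy, then copies its attributes back into this object
        doubled = Params(2 * self.w, 2 * self.b)
        self.__dict__.update(doubled.__dict__)


p = Params(1.0, 3.0)
p.double_rebind()
print(p.w, p.b)  # 1.0 3.0
p.double_inplace()
print(p.w, p.b)  # 2.0 6.0
```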
With `inplace=True`, the input `obj` is mutated and returned. Merging in place only works if `obj` has a `__dict__` attribute; otherwise a `TypeError` is raised.
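The `__dict__` requirement can be illustrated with a minimal sketch. `merge_into` is a hypothetical helper, and copying attributes through `vars()` is an assumption about how an in-place merge might write values back; it only shows why objects without a per-instance `__dict__` (for example, classes that declare `__slots__`) cannot be updated this way.

```python
def merge_into(obj, new):
    """Hypothetical in-place update: copy `new`'s attributes onto `obj`."""
    if not hasattr(obj, "__dict__"):
        raise TypeError(f"cannot merge in place: {type(obj).__name__} has no __dict__")
    obj.__dict__.update(vars(new))
    return obj


class WithDict:
    def __init__(self, a):
        self.a = a


class Slotted:
    __slots__ = ("a",)  # instances get no __dict__

    def __init__(self, a):
        self.a = a


target = WithDict(1)
result = merge_into(target, WithDict(5))
print(result is target, target.a)  # True 5

try:
    merge_into(Slotted(1), WithDict(5))
except TypeError as err:
    print(err)  # cannot merge in place: Slotted has no __dict__
```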
When static fields are ignored (the corresponding option is `True`), they are bypassed during the merge (which fields count as static depends on the flattening mode), and the final output keeps the same static components as the first input (`obj`). This strategy is somewhat less safe in general: it flattens all trees with `jax.tree_leaves` instead of `PyTreeDef.flatten_up_to`, which skips some checks and so effectively ignores the trees' static components. The only requirement is that the flattened structure of all trees matches, i.e. they produce the same number of leaves in the same order.
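The difference between the two flattening strategies can be seen with plain JAX pytrees, where dict keys play the role that static fields play for a Module: they live in the tree definition, not in the leaves. A minimal sketch, assuming a recent JAX with `jax.tree_util`; `merge_ignoring_structure` and `_Nothing` are hypothetical and only mimic the strategy described above, not treex's code.

```python
import jax


class _Nothing:
    """Hypothetical placeholder; unregistered classes are treated as leaves by JAX."""


NOTHING = _Nothing()

a = {"w": 1.0, "b": 2.0}
c = {"w": NOTHING, "scale": 5.0}  # same leaf count, different keys (structure)

# strict path: flatten `c` against `a`'s tree definition, which checks the keys
treedef = jax.tree_util.tree_structure(a)
try:
    treedef.flatten_up_to(c)
    strict_ok = True
except (ValueError, TypeError):
    strict_ok = False


def merge_ignoring_structure(first, *rest):
    # flatten every tree to its leaves only; structure beyond leaf order is ignored
    leaves = [jax.tree_util.tree_leaves(t) for t in (first, *rest)]
    if len({len(ls) for ls in leaves}) != 1:
        raise ValueError("all trees must flatten to the same number of leaves")
    merged = []
    for column in zip(*leaves):
        # first non-placeholder leaf, scanning right to left
        merged.append(next((x for x in reversed(column) if x is not NOTHING), NOTHING))
    # rebuild using the first tree's structure (its "static components")
    return jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(first), merged)


merged = merge_ignoring_structure(a, c)
print(strict_ok)  # False
print(merged)  # {'b': 5.0, 'w': 1.0}
```

The leaf-only path succeeds, but the value stored under `scale` ends up under `b`: JAX orders dict keys when flattening, so leaves are paired purely by position. A structure-aware check like `flatten_up_to` is what would have rejected this mismatch.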