
treeo.merge

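Creates a new Tree with the same structure as `obj` whose leaf values are merged from the incoming Trees: at each leaf, the value from the rightmost input that is not `Nothing` is kept, so later inputs take precedence. For more information see [merge's user guide](https://cgarciae.github.io/treeo/user-guide/api/merge).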
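To make the merge rule concrete, here is a minimal pure-Python sketch that uses neither treeo nor JAX. Nested dicts, lists and tuples stand in for pytrees, `_Nothing`/`NOTHING` are hypothetical stand-ins for treeo's `Nothing` placeholder, and `merge_sketch` is a hypothetical helper, not part of the treeo API. At each leaf it keeps the value from the rightmost input that is not `NOTHING`, which mirrors the `merge_fn` in the source listing.

```python
from typing import Any


class _Nothing:
    """Hypothetical stand-in for treeo's `Nothing` placeholder."""

    def __repr__(self) -> str:
        return "NOTHING"


NOTHING = _Nothing()


def _is_container(x: Any) -> bool:
    # In this sketch only dicts, lists and tuples count as inner nodes.
    return isinstance(x, (dict, list, tuple))


def merge_sketch(obj: Any, other: Any, *rest: Any) -> Any:
    """Return a new tree shaped like `obj`; each leaf is the rightmost non-NOTHING value."""
    trees = (obj, other, *rest)
    if isinstance(obj, dict):
        # Every input must have the same keys, i.e. the same structure.
        if any(not isinstance(t, dict) or t.keys() != obj.keys() for t in trees):
            raise ValueError("inputs must have the same structure")
        return {k: merge_sketch(*(t[k] for t in trees)) for k in obj}
    if isinstance(obj, (list, tuple)):
        if any(type(t) is not type(obj) or len(t) != len(obj) for t in trees):
            raise ValueError("inputs must have the same structure")
        return type(obj)(merge_sketch(*xs) for xs in zip(*trees))
    if any(_is_container(t) for t in trees):
        raise ValueError("inputs must have the same structure")
    # Leaf position: scan inputs from right to left and keep the first real value.
    for x in reversed(trees):
        if not isinstance(x, _Nothing):
            return x
    return NOTHING
```

For example, merging `{"w": 1.0, "b": NOTHING}` with `{"w": NOTHING, "b": 0.5}` fills each hole from whichever input has a value and yields `{"w": 1.0, "b": 0.5}`; when several inputs hold a value at the same leaf, the last one wins. The inputs themselves are left unchanged.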

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `obj` | `~A` | Main pytree to merge. | required |
| `other` | `~A` | The first pytree whose values are merged into `obj`. | required |
| `*rest` | `~A` | Additional pytrees to merge, applied in order from left to right. | `()` |
| `flatten_mode` | `Union[treeo.tree.FlattenMode, str, None]` | Sets a new `FlattenMode` context for the operation. If `None`, the current context is used; if the current flatten context is also `None`, `FlattenMode.all_fields` is used. | `None` |
| `ignore_static` | `bool` | If `True`, static fields are bypassed during the merge and the static fields of the output are taken from the first input (`obj`). | `False` |
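The `ignore_static` flag is easiest to see on a toy model. The sketch assumes, as in JAX, that static fields belong to a tree's structure rather than to its leaves, so a strict leafwise map refuses inputs whose static fields differ. `Node`, `merge_nodes`, `_pick` and `NOTHING` are hypothetical names for illustration only; this is not how treeo implements its looser map.

```python
from dataclasses import dataclass
from typing import Any, Tuple


class _Nothing:
    """Hypothetical stand-in for treeo's `Nothing` placeholder."""


NOTHING = _Nothing()


@dataclass(frozen=True)
class Node:
    leaves: Tuple[Any, ...]  # dynamic values that take part in the merge
    static: Any              # static field, e.g. a layer name


def _pick(*xs: Any) -> Any:
    # Rightmost value that is not NOTHING wins.
    for x in reversed(xs):
        if not isinstance(x, _Nothing):
            return x
    return NOTHING


def merge_nodes(obj: Node, *others: Node, ignore_static: bool = False) -> Node:
    nodes = (obj, *others)
    if any(len(n.leaves) != len(obj.leaves) for n in nodes):
        raise ValueError("leaf structure differs")
    if not ignore_static and any(n.static != obj.static for n in nodes):
        # Strict mode: static fields are part of the structure and must match.
        raise ValueError("static fields differ")
    leaves = tuple(_pick(*xs) for xs in zip(*(n.leaves for n in nodes)))
    # The output's static field always comes from the first input.
    return Node(leaves, obj.static)
```

With `a = Node((1.0, NOTHING), "encoder")` and `b = Node((NOTHING, 2.0), "encoder-copy")`, the strict call `merge_nodes(a, b)` raises, while `merge_nodes(a, b, ignore_static=True)` returns `Node((1.0, 2.0), "encoder")`, keeping `a`'s static field.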

Returns:

| Type | Description |
| --- | --- |
| `~A` | A new pytree with the updated values. |

Source code in treeo/api.py
```python
# Excerpt: relies on names defined or imported at module level in treeo/api.py
# (tp, jax, tree_m, types, FlattenMode, LEAF_TYPES, A, _flatten_context,
# _looser_tree_map).
def merge(
    obj: A,
    other: A,
    *rest: A,
    flatten_mode: tp.Union[FlattenMode, str, None] = None,
    ignore_static: bool = False,
) -> A:
    """
    Creates a new Tree with the same structure but its values merged based on the values from the incoming Trees. For more information see
    [merge's user guide](https://cgarciae.github.io/treeo/user-guide/api/merge).

    Arguments:
        obj: Main pytree to merge.
        other: The first pytree whose values are merged into `obj`.
        *rest: Additional pytrees to merge, applied in order from left to right.
        flatten_mode: Sets a new `FlattenMode` context for the operation, if `None` the current context is used. If the current flatten context is `None` and `flatten_mode` is not passed then `FlattenMode.all_fields` is used.
        ignore_static: If `True`, bypasses static fields during the process and the static fields of the output are taken from the first input (`obj`).

    Returns:
        A new pytree with the updated values.
    """

    # Fall back to FlattenMode.all_fields when neither the argument nor the
    # active context specifies a flatten mode.
    if flatten_mode is None and tree_m._CONTEXT.flatten_mode is None:
        flatten_mode = FlattenMode.all_fields

    # For each leaf position, keep the value from the rightmost input that is
    # not `Nothing`; if every input is `Nothing`, the result is `NOTHING`.
    def merge_fn(*xs):
        for x in reversed(xs):
            if not isinstance(x, types.Nothing):
                return x
        return types.NOTHING

    # `ignore_static=True` switches to a looser map that bypasses static fields.
    tree_map_fn = _looser_tree_map if ignore_static else jax.tree_util.tree_map

    with _flatten_context(flatten_mode):
        obj = tree_map_fn(
            merge_fn,
            obj,
            other,
            *rest,
            # Stop recursion at LEAF_TYPES instances and treat them as leaves.
            is_leaf=lambda x: isinstance(x, LEAF_TYPES),
        )

    return obj
```
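Taken together with the docstring, the start of `merge` resolves the flatten mode with a simple precedence: an explicit `flatten_mode` argument wins, otherwise the active context is used, and `FlattenMode.all_fields` is the final fallback. A minimal sketch of that precedence, using a hypothetical stand-in `FlattenMode` enum (only `all_fields` comes from the documentation; `normal` is an assumed second member) and a hypothetical `effective_flatten_mode` helper. String arguments are not modeled here.

```python
import enum
from typing import Optional


class FlattenMode(enum.Enum):
    normal = enum.auto()      # assumption: a second, non-default mode
    all_fields = enum.auto()


def effective_flatten_mode(
    flatten_mode: Optional[FlattenMode],
    context_mode: Optional[FlattenMode],
) -> Optional[FlattenMode]:
    """Return the mode merge would run under, given the argument and the active context."""
    # Same check as the top of merge: no argument and no active context
    # means all_fields.
    if flatten_mode is None and context_mode is None:
        flatten_mode = FlattenMode.all_fields
    # An explicit mode sets a new context; None keeps the current one.
    return flatten_mode if flatten_mode is not None else context_mode
```

So `effective_flatten_mode(None, None)` gives `FlattenMode.all_fields`, `effective_flatten_mode(None, FlattenMode.normal)` keeps the context's `normal`, and an explicit argument always overrides the context.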