treex.merge
Creates a new Tree with the same structure whose values are merged from the incoming Trees. For more information, see merge's user guide.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
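`obj` | `~A` | Main pytree to merge. | required |
`other` | `~A` | The first pytree to take values from when merging. | required |
`*rest` | `~A` | Additional pytrees to merge, applied in order from left to right. | `()` |
`inplace` | `bool` | If `True`, the input `obj` is mutated and returned. | `False` |
`flatten_mode` | `Union[treeo.tree.FlattenMode, str]` | Sets a new `FlattenMode` context for the operation; if `None`, the current context is used. If the current flatten context is `None` and `flatten_mode` is not passed, `FlattenMode.all_fields` is used. | `None` |
`ignore_static` | `bool` | If `True`, bypasses static fields during the process, and the static fields of the output are taken from the first input (`obj`). | `False` |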
Returns:
Type | Description |
---|---|
`~A` | A new pytree with the updated values. If `inplace` is `True`, `obj` is returned. |
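To make "merged in order from left to right" concrete, here is a minimal standalone sketch; it is not treeo's implementation. It assumes that a placeholder value marks leaves that should be skipped, so the right-most non-placeholder value wins for each leaf. Plain flat dicts stand in for pytrees, and `Nothing`, `NOTHING`, `merge_leaves`, and `merge_dicts` are hypothetical names defined here purely for illustration.

```python
# Hypothetical stand-in for a "no value here" sentinel (illustration only).
class Nothing:
    def __repr__(self):
        return "NOTHING"


NOTHING = Nothing()


def merge_leaves(*xs):
    """Return the right-most value that is not NOTHING, else NOTHING."""
    for x in reversed(xs):
        if not isinstance(x, Nothing):
            return x
    return NOTHING


def merge_dicts(obj, other, *rest):
    """Merge flat dicts that share the same keys, leaf by leaf."""
    trees = (obj, other, *rest)
    return {key: merge_leaves(*(t[key] for t in trees)) for key in obj}


base = {"w": 1.0, "b": 2.0, "step": 0}
update_1 = {"w": 10.0, "b": NOTHING, "step": NOTHING}
update_2 = {"w": NOTHING, "b": NOTHING, "step": 5}

# Later inputs override earlier ones, except where they hold NOTHING.
merged = merge_dicts(base, update_1, update_2)
print(merged)  # {'w': 10.0, 'b': 2.0, 'step': 5}
```

In this sketch `base` is left untouched and a new dict is returned, which mirrors the default `inplace=False` behaviour described above.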
Source code in treeo/api.py
def merge(
    obj: A,
    other: A,
    *rest: A,
    inplace: bool = False,
    flatten_mode: tp.Union[FlattenMode, str, None] = None,
    ignore_static: bool = False,
) -> A:
    """
    Creates a new Tree with the same structure but its values merged based on the values from the incoming Trees. For more information see
    [merge's user guide](https://cgarciae.github.io/treeo/user-guide/api/merge).

    Arguments:
        obj: Main pytree to merge.
        other: The pytree first to get the values to merge with.
        *rest: Additional pytree to perform the merge in order from left to right.
        inplace: If `True`, the input `obj` is mutated and returned.
        flatten_mode: Sets a new `FlattenMode` context for the operation, if `None` the current context is used. If the current flatten context is `None` and `flatten_mode` is not passed then `FlattenMode.all_fields` is used.
        ignore_static: If `True`, bypasses static fields during the process and the statics fields for output are taken from the first input (`obj`).

    Returns:
        A new pytree with the updated values. If `inplace` is `True`, `obj` is returned.
    """
    if inplace and not hasattr(obj, "__dict__"):
        raise TypeError(
            f"Cannot update inplace on objects with no __dict__ property, got {obj}"
        )

    if flatten_mode is None and tree_m._CONTEXT.flatten_mode is None:
        flatten_mode = FlattenMode.all_fields

    input_obj = obj

    def merge_fn(*xs):
        for x in reversed(xs):
            if not isinstance(x, types.Nothing):
                return x
        return types.NOTHING

    tree_map_fn = _looser_tree_map if ignore_static else jax.tree_map

    with _flatten_context(flatten_mode):
        obj = tree_map_fn(
            merge_fn,
            obj,
            other,
            *rest,
            is_leaf=lambda x: isinstance(x, LEAF_TYPES),
        )

    if inplace:
        input_obj.__dict__.update(obj.__dict__)
        return input_obj
    else:
        return obj
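
The `inplace` branch of the listing copies the merged object's `__dict__` into the original object and returns that original. The following standalone sketch shows that pattern on an ordinary Python class, under the assumption that attribute state lives in `__dict__`. `Params` and `update_inplace` are hypothetical names introduced for illustration and are not part of treeo.

```python
class Params:
    """Plain attribute container used only for this illustration."""

    def __init__(self, w, b):
        self.w = w
        self.b = b


def update_inplace(target, source):
    """Copy all instance attributes of `source` onto `target` and return `target`."""
    # Objects such as tuples have no __dict__, so they cannot be mutated this way.
    if not hasattr(target, "__dict__"):
        raise TypeError(
            f"Cannot update inplace on objects with no __dict__ property, got {target}"
        )
    target.__dict__.update(source.__dict__)
    return target


p = Params(w=1.0, b=2.0)
q = Params(w=10.0, b=20.0)

result = update_inplace(p, q)
print(result is p)   # True: the same object is returned
print(p.w, p.b)      # 10.0 20.0
```

Because the original object is mutated rather than replaced, any other reference to `p` observes the new values; that is the trade-off to weigh when choosing `inplace=True`.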