
treeo.filter

The `filter` function selects a subtree by filtering leaves on a predicate or a `kind` type. Leaves that pass all filters are kept, and the rest are set to `Nothing`. For more information, see [filter's user guide](https://cgarciae.github.io/treeo/user-guide/api/filter).
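The keep-or-replace behaviour can be pictured with a small plain-Python model. This sketch does not use treeo or JAX. `Leaf`, `NOTHING`, `filter_by_kind`, `Parameter` and `BatchStat` are hypothetical names, nested `dict`s and lists stand in for a pytree, and a type filter is assumed to match when the leaf's kind is a subclass of it.

```python
from dataclasses import dataclass
from typing import Any


class _NothingType:
    """Hypothetical stand-in for treeo's `Nothing` placeholder."""

    def __repr__(self) -> str:
        return "Nothing"


NOTHING = _NothingType()


@dataclass(frozen=True)
class Leaf:
    """Hypothetical leaf: a value tagged with a `kind` type."""

    value: Any
    kind: type


def filter_by_kind(obj: Any, kind: type) -> Any:
    """Return a new tree in which leaves whose kind is a subclass of `kind`
    keep their value and every other leaf becomes NOTHING."""
    if isinstance(obj, dict):
        return {k: filter_by_kind(v, kind) for k, v in obj.items()}
    if type(obj) in (list, tuple):
        # Rebuild the container instead of mutating it, so the input is untouched.
        return type(obj)(filter_by_kind(v, kind) for v in obj)
    # Anything else is assumed to be a Leaf.
    return obj.value if issubclass(obj.kind, kind) else NOTHING


class Parameter: ...


class BatchStat: ...


tree = {
    "linear": {"w": Leaf(1.0, Parameter), "b": Leaf(0.0, Parameter)},
    "norm": [Leaf(0.5, BatchStat)],
}
params = filter_by_kind(tree, Parameter)
print(params)  # {'linear': {'w': 1.0, 'b': 0.0}, 'norm': [Nothing]}
```

In this model the container layout is kept and only the rejected leaves change, so the `BatchStat` entry survives as a `Nothing` placeholder rather than disappearing from the dictionary.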

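Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `obj` | `~A` | A pytree (possibly containing `to.Tree`s) to be filtered. | required |
| `*filters` | `Union[Type[Any], Callable[[FieldInfo], bool]]` | Types to filter by, where membership is determined by `issubclass`, or callables that take a `FieldInfo` and return a `bool`. A leaf is kept only if it passes every filter; with no filters, every leaf is kept. | `()` |
| `flatten_mode` | `Union[treeo.tree.FlattenMode, str, None]` | Sets a new `FlattenMode` context for the operation. | `None` |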

Returns:

| Type | Description |
|------|-------------|
| `~A` | A new pytree with the filtered fields. |
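The `*filters` parameter accepts two forms, types and callables. The sketch below, again independent of treeo, shows one way they can be normalized and combined: types become `issubclass` checks on the field's kind, callables are used as-is, and a field survives only if every predicate returns `True`. `FieldInfo` here is a hypothetical reduced mirror with only `name`, `value` and `kind`, and `kind_filter`, `normalize` and `passes` are illustrative helpers, not treeo functions.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union


@dataclass(frozen=True)
class FieldInfo:
    """Hypothetical mirror of treeo's FieldInfo, reduced to the attributes used here."""

    name: Optional[str]
    value: Any
    kind: type


Filter = Union[type, Callable[[FieldInfo], bool]]


def kind_filter(t: type) -> Callable[[FieldInfo], bool]:
    """Turn a type into a predicate; assumes membership means issubclass(info.kind, t)."""
    return lambda info: issubclass(info.kind, t)


def normalize(filters: Tuple[Filter, ...]) -> Tuple[Callable[[FieldInfo], bool], ...]:
    # Types become kind predicates; callables are passed through unchanged.
    return tuple(kind_filter(f) if isinstance(f, type) else f for f in filters)


def passes(info: FieldInfo, *filters: Filter) -> bool:
    # A field is kept only if every filter accepts it; no filters keeps everything.
    return all(f(info) for f in normalize(filters))


class Parameter: ...


class Trainable(Parameter): ...


w = FieldInfo(name="w", value=1.0, kind=Trainable)
b = FieldInfo(name="b", value=0.0, kind=Parameter)

print(passes(w, Parameter))  # True: Trainable is a subclass of Parameter
print(passes(b, Trainable))  # False: Parameter is not a subclass of Trainable
print(passes(w, Parameter, lambda info: info.name == "b"))  # False: the callable rejects "w"
print(passes(b))  # True: no filters at all
```

Because membership goes through `issubclass`, filtering by a base kind also selects fields tagged with any of its subclasses.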

Source code in treeo/api.py
```python
def filter(
    obj: A,
    *filters: Filter,
    flatten_mode: tp.Union[FlattenMode, str, None] = None,
) -> A:
    """
    The `filter` function allows you to select a subtree by filtering based on a predicate or `kind` type,
    leaves that pass all filters are kept, the rest are set to `Nothing`. For more information see
    [filter's user guide](https://cgarciae.github.io/treeo/user-guide/api/filter).



    Arguments:
        obj: A pytree (possibly containing `to.Tree`s) to be filtered.
        *filters: Types to filter by, membership is determined by `issubclass`, or
            callables that take in a `FieldInfo` and return a `bool`.
        flatten_mode: Sets a new `FlattenMode` context for the operation.
    Returns:
        A new pytree with the filtered fields.

    """

    input_obj = obj

    filters = tuple(
        _get_kind_filter(f) if isinstance(f, tp.Type) else f for f in filters
    )

    def apply_filters(info: tp.Any) -> tp.Any:
        if not isinstance(info, FieldInfo):
            info = FieldInfo(
                name=None,
                value=info,
                kind=type(None),
                module=None,
            )
        assert isinstance(info, FieldInfo)

        return info.value if all(f(info) for f in filters) else types.NOTHING

    with tree_m._CONTEXT.update(add_field_info=True), _flatten_context(flatten_mode):
        obj = jax.tree_map(apply_filters, obj)

    return obj
```
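The `apply_filters` closure in the source also handles leaves that arrive without field metadata: it wraps them in a `FieldInfo` whose `kind` is `type(None)`. The sketch below reproduces only that per-leaf step outside JAX (no `jax.tree_map`, no context manager), using a hypothetical `FieldInfo` dataclass and `NOTHING` sentinel rather than treeo's own classes. Assuming type filters check `issubclass` on the kind, it shows a consequence of that wrapping: a plain leaf fails a kind filter such as `Parameter`, while a callable that inspects `info.value` can still keep it.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


class _NothingType:
    """Hypothetical stand-in for treeo's `Nothing` placeholder."""

    def __repr__(self) -> str:
        return "Nothing"


NOTHING = _NothingType()


@dataclass(frozen=True)
class FieldInfo:
    """Hypothetical mirror of treeo's FieldInfo (name, value, kind, module)."""

    name: Optional[str]
    value: Any
    kind: type
    module: Any


def apply_filters(info: Any, *filters: Callable[[FieldInfo], bool]) -> Any:
    # Raw leaves (not produced from a to.Tree field) are wrapped so that
    # their kind is NoneType, mirroring the source listing above.
    if not isinstance(info, FieldInfo):
        info = FieldInfo(name=None, value=info, kind=type(None), module=None)
    return info.value if all(f(info) for f in filters) else NOTHING


class Parameter: ...


def is_param(info: FieldInfo) -> bool:
    # Kind filter, assumed to be an issubclass check on the field's kind.
    return issubclass(info.kind, Parameter)


def is_int(info: FieldInfo) -> bool:
    # Value-based callable filter.
    return isinstance(info.value, int)


field = FieldInfo(name="w", value=1.0, kind=Parameter, module=None)
print(apply_filters(field, is_param))  # 1.0
print(apply_filters(3, is_param))  # Nothing: a raw leaf's kind is NoneType
print(apply_filters(3, is_int))  # 3: a callable can still select raw leaves by value
```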