treex.filter
The `filter` function selects a subtree by filtering on a predicate or a `kind` type. Leaves that pass all filters are kept; the rest are set to `Nothing`. For more information, see [filter's user guide](https://cgarciae.github.io/treeo/user-guide/api/filter).
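The keep-or-`Nothing` behavior described above can be illustrated with a small pure-Python model. This is only a sketch under stated assumptions: `LeafInfo`, `NOTHING`, `filter_leaves`, and the `Parameter`/`BatchStat` kinds are hypothetical stand-ins written for this example, and a flat dict replaces a real pytree. It is not how the library implements `filter`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


class _Nothing:
    """Hypothetical stand-in for the `Nothing` placeholder (not the library's class)."""

    def __repr__(self) -> str:
        return "Nothing"


NOTHING = _Nothing()


@dataclass
class LeafInfo:
    """Assumed, simplified analogue of `FieldInfo`: a leaf value plus its kind."""

    name: Optional[str]
    value: Any
    kind: type


class Parameter:
    """Example kind used to tag trainable leaves."""


class BatchStat:
    """Example kind used to tag running statistics."""


def filter_leaves(
    leaves: Dict[str, LeafInfo], *predicates: Callable[[LeafInfo], bool]
) -> Dict[str, Any]:
    # A leaf is kept only if every predicate accepts it. Rejected leaves are
    # replaced by NOTHING, so the set of keys (the structure) is unchanged.
    return {
        name: info.value if all(p(info) for p in predicates) else NOTHING
        for name, info in leaves.items()
    }


leaves = {
    "w": LeafInfo("w", 2.0, Parameter),
    "b": LeafInfo("b", 0.5, Parameter),
    "mean": LeafInfo("mean", 0.1, BatchStat),
}


def is_parameter(info: LeafInfo) -> bool:
    return issubclass(info.kind, Parameter)


def is_large(info: LeafInfo) -> bool:
    return info.value > 1.0


params = filter_leaves(leaves, is_parameter)
large_params = filter_leaves(leaves, is_parameter, is_large)
```

In this sketch, `params` keeps `w` and `b` and replaces `mean` with the placeholder, while `large_params` keeps only `w` because a leaf must satisfy both predicates.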
Parameters:
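Name | Type | Description | Default |
---|---|---|---|
obj | ~A | A pytree (possibly containing `to.Tree`s) to be filtered. | required |
*filters | Union[Type[Any], Callable[[FieldInfo], bool]] | Types to filter by, membership is determined by `issubclass`, or callables that take in a `FieldInfo` and return a `bool`. | () |
inplace | bool | If `True`, the input `obj` is mutated and returned. | False |
flatten_mode | Union[treeo.tree.FlattenMode, str] | Sets a new `FlattenMode` context for the operation. | None |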
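Because `*filters` accepts both types and callables, each type has to be turned into a predicate before the filters can be combined. The sketch below shows one way this could work, assuming a kind check based on `issubclass`. `Info`, `as_predicate`, `passes`, and the `State`/`Parameter`/`Cache` kinds are hypothetical names for this illustration only; the library's own helper may differ.

```python
from typing import Any, Callable, NamedTuple, Union


class Info(NamedTuple):
    """Assumed minimal analogue of `FieldInfo` (only the fields used here)."""

    value: Any
    kind: type


Predicate = Callable[[Info], bool]


def as_predicate(f: Union[type, Predicate]) -> Predicate:
    # Classes are callable too, so the type check must come first.
    # A type becomes a kind check via `issubclass`; callables pass through.
    if isinstance(f, type):
        return lambda info: issubclass(info.kind, f)
    return f


class State:
    """Example base kind."""


class Parameter(State):
    """Example kind that is a subclass of State."""


class Cache(State):
    """Another example kind that is a subclass of State."""


def passes(info: Info, *filters: Union[type, Predicate]) -> bool:
    # All filters must accept the leaf; with no filters, every leaf passes.
    return all(as_predicate(f)(info) for f in filters)


weight = Info(value=3.0, kind=Parameter)
cached = Info(value=7.0, kind=Cache)

weight_is_state = passes(weight, State)          # subclass kinds match a base kind
cached_is_parameter = passes(cached, Parameter)  # sibling kinds do not match
big_state = passes(cached, State, lambda info: info.value > 5)
```

Mixing a type with a callable, as in the last line, combines a kind check with a value check on the same leaf.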
Returns:
Type | Description |
---|---|
~A | A new pytree with the filtered fields. If `inplace` is `True`, `obj` is returned. |
Source code in treeo/api.py
    def filter(
        obj: A,
        *filters: Filter,
        inplace: bool = False,
        flatten_mode: tp.Union[FlattenMode, str, None] = None,
    ) -> A:
        """
        The `filter` function allows you to select a subtree by filtering based on a predicate or `kind` type,
        leaves that pass all filters are kept, the rest are set to `Nothing`. For more information see
        [filter's user guide](https://cgarciae.github.io/treeo/user-guide/api/filter).

        Arguments:
            obj: A pytree (possibly containing `to.Tree`s) to be filtered.
            *filters: Types to filter by, membership is determined by `issubclass`, or
                callables that take in a `FieldInfo` and return a `bool`.
            inplace: If `True`, the input `obj` is mutated and returned.
            flatten_mode: Sets a new `FlattenMode` context for the operation.

        Returns:
            A new pytree with the filtered fields. If `inplace` is `True`, `obj` is returned.
        """
        if inplace and not hasattr(obj, "__dict__"):
            raise ValueError(
                f"Cannot filter inplace on objects with no __dict__ property, got {obj}"
            )

        input_obj = obj

        filters = tuple(
            _get_kind_filter(f) if isinstance(f, tp.Type) else f for f in filters
        )

        def apply_filters(info: tp.Any) -> tp.Any:
            if not isinstance(info, FieldInfo):
                info = FieldInfo(
                    name=None,
                    value=info,
                    kind=type(None),
                    module=None,
                )
            assert isinstance(info, FieldInfo)

            return info.value if all(f(info) for f in filters) else types.NOTHING

        with tree_m._CONTEXT.update(add_field_info=True), _flatten_context(flatten_mode):
            obj = jax.tree_map(apply_filters, obj)

        if inplace:
            input_obj.__dict__.update(obj.__dict__)
            return input_obj
        else:
            return obj
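The `inplace` branch in the listing above copies the filtered object's attributes onto the original object and rejects inputs that have no `__dict__`. The following standalone sketch isolates that mechanism. `Model` and `replace_contents` are hypothetical names for this example, and `None` merely stands in for a filtered-out leaf.

```python
class Model:
    """Hypothetical container with two attributes."""

    def __init__(self, w, mean):
        self.w = w
        self.mean = mean


def replace_contents(target, source):
    # Objects such as tuples or lists have no instance __dict__,
    # so there is nothing to update in place.
    if not hasattr(target, "__dict__"):
        raise ValueError(
            f"Cannot update inplace on objects with no __dict__ property, got {target}"
        )
    # Overwrite the target's attributes with the source's; the identity of
    # the target object is preserved, so existing references see the change.
    target.__dict__.update(source.__dict__)
    return target


original = Model(w=2.0, mean=0.1)
alias = original                     # a second reference to the same object
filtered = Model(w=2.0, mean=None)   # None stands in for a filtered-out leaf

result = replace_contents(original, filtered)

try:
    replace_contents((1.0, 2.0), filtered)
    rejected = False
except ValueError:
    rejected = True
```

Note that `dict.update` only adds or overwrites keys, so an attribute present on the target but absent from the source would be left untouched.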