treex.to_dict
Source code in treeo/api.py
def to_dict(
    obj: tp.Any,
    *,
    private_fields: bool = False,
    static_fields: bool = True,
    type_info: bool = False,
    field_info: bool = False,
) -> tp.Any:
    """Convert ``obj`` to its dictionary form by delegating to ``_to_dict``.

    Arguments:
        obj: The object to convert.
        private_fields: Forwarded unchanged to ``_to_dict``
            (presumably toggles inclusion of private fields -- confirm
            against ``_to_dict``).
        static_fields: Forwarded unchanged to ``_to_dict``
            (presumably toggles inclusion of static fields -- confirm
            against ``_to_dict``).
        type_info: Forwarded unchanged to ``_to_dict``.
        field_info: If ``True``, ``obj`` is first flattened while
            ``add_field_info()`` and ``flatten_mode(FlattenMode.all_fields)``
            are active, rebuilt from the resulting leaves, and then passed
            through ``apply(_remove_field_info_from_metadata, ...)`` before
            the conversion.

    Returns:
        Whatever ``_to_dict`` returns for the (possibly rebuilt) object.
    """
    if field_info:
        # The two context managers only need to be active while flattening.
        with add_field_info(), flatten_mode(FlattenMode.all_fields):
            # Use the `jax.tree_util` functions: the top-level
            # `jax.tree_flatten` / `jax.tree_unflatten` aliases were
            # deprecated and are absent from recent JAX releases.
            leaves, treedef = jax.tree_util.tree_flatten(obj)

        obj = jax.tree_util.tree_unflatten(treedef, leaves)
        obj = apply(_remove_field_info_from_metadata, obj)

    return _to_dict(obj, private_fields, static_fields, type_info)