Skip to content

treex.to_dict

Source code in treeo/api.py
def to_dict(
    obj: tp.Any,
    *,
    private_fields: bool = False,
    static_fields: bool = True,
    type_info: bool = False,
    field_info: bool = False,
) -> tp.Any:
    """Convert ``obj`` into a plain nested-dict style representation.

    The conversion itself is delegated to ``_to_dict``. This wrapper only
    optionally pre-processes ``obj`` when ``field_info`` is requested.

    Args:
        obj: The object to convert (presumably a pytree / ``Tree`` instance;
            TODO confirm against callers).
        private_fields: Forwarded to ``_to_dict``. Presumably controls whether
            private fields appear in the output; verify in ``_to_dict``.
        static_fields: Forwarded to ``_to_dict``. Presumably controls whether
            static (non-pytree) fields appear in the output; verify in
            ``_to_dict``.
        type_info: Forwarded to ``_to_dict``. Presumably controls whether
            type information is recorded; verify in ``_to_dict``.
        field_info: If True, ``obj`` is first pre-processed in three steps
            before conversion:
            1. flatten it while ``add_field_info()`` and
               ``flatten_mode(FlattenMode.all_fields)`` are active;
            2. rebuild it from those leaves;
            3. pass the result through
               ``apply(_remove_field_info_from_metadata, ...)``.

    Returns:
        Whatever ``_to_dict`` produces for the (possibly pre-processed) object.
    """
    if field_info:
        # Flatten with both contexts active. Presumably this makes the leaves
        # carry field metadata and visits all fields rather than the default
        # subset. Rebuilding happens after the contexts exit, as before.
        # Use ``jax.tree_util``: the top-level ``jax.tree_flatten`` /
        # ``jax.tree_unflatten`` aliases are deprecated and removed in recent
        # JAX releases, whereas ``jax.tree_util`` works across versions.
        with add_field_info(), flatten_mode(FlattenMode.all_fields):
            flat, treedef = jax.tree_util.tree_flatten(obj)

        obj = jax.tree_util.tree_unflatten(treedef, flat)
        obj = apply(_remove_field_info_from_metadata, obj)

    return _to_dict(obj, private_fields, static_fields, type_info)