From f4eb760780473dbc4faa0440ee2ed0b0c90a9876 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 10 Nov 2025 13:23:29 -0800 Subject: [PATCH] Support pytree classes with dunder-style flatten/unflatten methods --- jax/_src/tree_util.py | 46 ++++++++++++++++++++++++++---------- tests/tree_util_test.py | 52 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 4f439a770e69..57e99732289c 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -334,10 +334,10 @@ def register_pytree_node_class(cls: Typ) -> Typ: ... def __init__(self, x, y): ... self.x = x ... self.y = y - ... def tree_flatten(self): + ... def __tree_flatten__(self): ... return ((self.x, self.y), None) ... @classmethod - ... def tree_unflatten(cls, aux_data, children): + ... def __tree_unflatten__(cls, aux_data, children): ... return cls(*children) ... >>> m = MyContainer(jnp.zeros(4), jnp.arange(4)) @@ -346,7 +346,20 @@ def register_pytree_node_class(cls: Typ) -> Typ: >>> jax.jit(f)(m) Array([0., 2., 4., 6.], dtype=float32) """ - register_pytree_node(cls, op.methodcaller("tree_flatten"), cls.tree_unflatten) + has_dunder_flatten = hasattr(cls, "__tree_flatten__") + has_dunder_unflatten = hasattr(cls, "__tree_unflatten__") + if has_dunder_flatten != has_dunder_unflatten: + defined = "__tree_flatten__" if has_dunder_flatten else "__tree_unflatten__" + raise ValueError( + "register_pytree_node_class: if using dunder methods, the type must define" + " both a __tree_flatten__ method and a __tree_unflatten__ classmethod." + f" {cls=} only defines {defined}.") + if has_dunder_flatten and has_dunder_unflatten: + register_pytree_node( + cls, op.methodcaller("__tree_flatten__"), cls.__tree_unflatten__) + else: + register_pytree_node( + cls, op.methodcaller("tree_flatten"), cls.tree_unflatten) return cls @@ -966,19 +979,28 @@ class that defines how it could be flattened with keys. ... def __init__(self, x, y): ... self.x = x ... self.y = y - ... def tree_flatten_with_keys(self): + ... def __tree_flatten_with_keys__(self): ... return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None) ... @classmethod - ... def tree_unflatten(cls, aux_data, children): + ... def __tree_unflatten__(cls, aux_data, children): ... return cls(*children) """ - flatten_func = ( - op.methodcaller("tree_flatten") if hasattr(cls, "tree_flatten") else None - ) - register_pytree_with_keys( - cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten, - flatten_func - ) + has_dunder_flatten = hasattr(cls, "__tree_flatten_with_keys__") + has_dunder_unflatten = hasattr(cls, "__tree_unflatten__") + if has_dunder_flatten != has_dunder_unflatten: + defined = "__tree_flatten_with_keys__" if has_dunder_flatten else "__tree_unflatten__" + raise ValueError( + "register_pytree_with_keys_class: if using dunder methods, the type must" + " define both a __tree_flatten_with_keys__ method and a __tree_unflatten__" + f" classmethod. {cls=} only defines {defined}.") + if has_dunder_flatten and has_dunder_unflatten: + register_pytree_with_keys( + cls, op.methodcaller("__tree_flatten_with_keys__"), + cls.__tree_unflatten__, getattr(cls, "__tree_flatten__", None)) + else: + register_pytree_with_keys( + cls, op.methodcaller("tree_flatten_with_keys"), + cls.tree_unflatten, getattr(cls, "tree_flatten", None)) return cls diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index e148e6bd355e..2504bcba643a 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -115,6 +115,33 @@ def tree_flatten_with_keys(self): (tree_util.GetAttrKey('y'), self.y)), None) +@tree_util.register_pytree_node_class +class SpecialDunder: + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return f"SpecialDunder(x={self.x}, y={self.y})" + + def __tree_flatten__(self): + return ((self.x, self.y), None) + + @classmethod + def __tree_unflatten__(cls, aux_data, children): + return cls(*children) + + def __eq__(self, other): + return type(self) is type(other) and (self.x, self.y) == (other.x, other.y) + + +@tree_util.register_pytree_with_keys_class +class SpecialDunderWithKeys(SpecialDunder): + def __tree_flatten_with_keys__(self): + return (((tree_util.GetAttrKey('x'), self.x), + (tree_util.GetAttrKey('y'), self.y)), None) + + @tree_util.register_pytree_node_class class FlatCache: def __init__(self, structured, *, leaves=None, treedef=None): @@ -190,6 +217,7 @@ def __eq__(self, other): ([AnObject(3, None, [4, "foo"])],), ([AnObject2(3, None, [4, "foo"])],), (Special(2, 3.0),), + (SpecialDunder(2, 3.0),), ({"a": 1, "b": 2},), (StaticInt(1),), (StaticTuple((2, 3)),), @@ -210,6 +238,7 @@ def __eq__(self, other): "PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])", "PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])", "PyTreeDef(CustomNode(Special[None], [*, *]))", + "PyTreeDef(CustomNode(SpecialDunder[None], [*, *]))", "PyTreeDef({'a': *, 'b': *})", "PyTreeDef(CustomNode(StaticInt[1], []))", "PyTreeDef(CustomNode(StaticTuple[(2, 3)], []))", @@ -281,6 +310,7 @@ class ADataclassWithMeta: ([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],), ([AnObject2(3, None, [4, "foo"])],), (SpecialWithKeys(2, 3.),), + (SpecialDunderWithKeys(2, 3.),), ({"a": 1, "b": 0},), (collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),), (collections.defaultdict(dict, @@ -953,6 +983,28 @@ def testFlattenOneLevel(self): with self.assertRaisesRegex(ValueError, "can't tree-flatten type"): flatten_one_level(jnp.array((1, 2))) + def testSingleDunderError(self): + class SingleDunder: + def __tree_flatten__(self): + return (), '' + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls() + msg = "register_pytree_node_class.*SingleDunder.*only defines __tree_flatten__" + with self.assertRaisesRegex(ValueError, msg): + tree_util.register_pytree_node_class(SingleDunder) + + def testSingleDunderErrorWithKeys(self): + class SingleDunder: + @classmethod + def __tree_unflatten__(cls, aux_data, children): + return cls() + def tree_flatten_with_keys(cls): + return (), None + msg = "register_pytree_with_keys_class.*SingleDunder.*only defines __tree_unflatten__" + with self.assertRaisesRegex(ValueError, msg): + tree_util.register_pytree_with_keys_class(SingleDunder) + def testOptionalFlatten(self): @tree_util.register_pytree_with_keys_class class FooClass: