Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,10 @@ def register_pytree_node_class(cls: Typ) -> Typ:
... def __init__(self, x, y):
... self.x = x
... self.y = y
... def tree_flatten(self):
... def __tree_flatten__(self):
... return ((self.x, self.y), None)
... @classmethod
... def tree_unflatten(cls, aux_data, children):
... def __tree_unflatten__(cls, aux_data, children):
... return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
Expand All @@ -346,7 +346,20 @@ def register_pytree_node_class(cls: Typ) -> Typ:
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)
"""
register_pytree_node(cls, op.methodcaller("tree_flatten"), cls.tree_unflatten)
has_dunder_flatten = hasattr(cls, "__tree_flatten__")
has_dunder_unflatten = hasattr(cls, "__tree_unflatten__")
if has_dunder_flatten != has_dunder_unflatten:
defined = "__tree_flatten__" if has_dunder_flatten else "__tree_unflatten__"
raise ValueError(
"register_pytree_node_class: if using dunder methods, the type must define"
" both a __tree_flatten__ method and a __tree_unflatten__ classmethod."
f" {cls=} only defines {defined}.")
if has_dunder_flatten and has_dunder_unflatten:
register_pytree_node(
cls, op.methodcaller("__tree_flatten__"), cls.__tree_unflatten__)
else:
register_pytree_node(
cls, op.methodcaller("tree_flatten"), cls.tree_unflatten)
return cls


Expand Down Expand Up @@ -966,19 +979,28 @@ class that defines how it could be flattened with keys.
... def __init__(self, x, y):
... self.x = x
... self.y = y
... def tree_flatten_with_keys(self):
... def __tree_flatten_with_keys__(self):
... return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
... @classmethod
... def tree_unflatten(cls, aux_data, children):
... def __tree_unflatten__(cls, aux_data, children):
... return cls(*children)
"""
flatten_func = (
op.methodcaller("tree_flatten") if hasattr(cls, "tree_flatten") else None
)
register_pytree_with_keys(
cls, op.methodcaller("tree_flatten_with_keys"), cls.tree_unflatten,
flatten_func
)
has_dunder_flatten = hasattr(cls, "__tree_flatten_with_keys__")
has_dunder_unflatten = hasattr(cls, "__tree_unflatten__")
if has_dunder_flatten != has_dunder_unflatten:
defined = "__tree_flatten_with_keys__" if has_dunder_flatten else "__tree_unflatten__"
raise ValueError(
"register_pytree_with_keys_class: if using dunder methods, the type must"
" define both a __tree_flatten_with_keys__ method and a __tree_unflatten__"
f" classmethod. {cls=} only defines {defined}.")
if has_dunder_flatten and has_dunder_unflatten:
register_pytree_with_keys(
cls, op.methodcaller("__tree_flatten_with_keys__"),
cls.__tree_unflatten__, getattr(cls, "__tree_flatten__", None))
else:
register_pytree_with_keys(
cls, op.methodcaller("tree_flatten_with_keys"),
cls.tree_unflatten, getattr(cls, "tree_flatten", None))
return cls


Expand Down
52 changes: 52 additions & 0 deletions tests/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,33 @@ def tree_flatten_with_keys(self):
(tree_util.GetAttrKey('y'), self.y)), None)


@tree_util.register_pytree_node_class
class SpecialDunder:
def __init__(self, x, y):
self.x = x
self.y = y

def __repr__(self):
return f"SpecialDunder(x={self.x}, y={self.y})"

def __tree_flatten__(self):
return ((self.x, self.y), None)

@classmethod
def __tree_unflatten__(cls, aux_data, children):
return cls(*children)

def __eq__(self, other):
return type(self) is type(other) and (self.x, self.y) == (other.x, other.y)


@tree_util.register_pytree_with_keys_class
class SpecialDunderWithKeys(SpecialDunder):
def __tree_flatten_with_keys__(self):
return (((tree_util.GetAttrKey('x'), self.x),
(tree_util.GetAttrKey('y'), self.y)), None)


@tree_util.register_pytree_node_class
class FlatCache:
def __init__(self, structured, *, leaves=None, treedef=None):
Expand Down Expand Up @@ -190,6 +217,7 @@ def __eq__(self, other):
([AnObject(3, None, [4, "foo"])],),
([AnObject2(3, None, [4, "foo"])],),
(Special(2, 3.0),),
(SpecialDunder(2, 3.0),),
({"a": 1, "b": 2},),
(StaticInt(1),),
(StaticTuple((2, 3)),),
Expand All @@ -210,6 +238,7 @@ def __eq__(self, other):
"PyTreeDef([CustomNode(AnObject[[4, 'foo']], [*, None])])",
"PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])",
"PyTreeDef(CustomNode(Special[None], [*, *]))",
"PyTreeDef(CustomNode(SpecialDunder[None], [*, *]))",
"PyTreeDef({'a': *, 'b': *})",
"PyTreeDef(CustomNode(StaticInt[1], []))",
"PyTreeDef(CustomNode(StaticTuple[(2, 3)], []))",
Expand Down Expand Up @@ -281,6 +310,7 @@ class ADataclassWithMeta:
([3, ATuple(foo=(3, ATuple(foo=3, bar=None)), bar={"baz": 34})],),
([AnObject2(3, None, [4, "foo"])],),
(SpecialWithKeys(2, 3.),),
(SpecialDunderWithKeys(2, 3.),),
({"a": 1, "b": 0},),
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
(collections.defaultdict(dict,
Expand Down Expand Up @@ -953,6 +983,28 @@ def testFlattenOneLevel(self):
with self.assertRaisesRegex(ValueError, "can't tree-flatten type"):
flatten_one_level(jnp.array((1, 2)))

def testSingleDunderError(self):
class SingleDunder:
def __tree_flatten__(self):
return (), ''
@classmethod
def tree_unflatten(cls, aux_data, children):
return cls()
msg = "register_pytree_node_class.*SingleDunder.*only defines __tree_flatten__"
with self.assertRaisesRegex(ValueError, msg):
tree_util.register_pytree_node_class(SingleDunder)

def testSingleDunderErrorWithKeys(self):
class SingleDunder:
@classmethod
def __tree_unflatten__(cls, aux_data, children):
return cls()
def tree_flatten_with_keys(cls):
return (), None
msg = "register_pytree_with_keys_class.*SingleDunder.*only defines __tree_unflatten__"
with self.assertRaisesRegex(ValueError, msg):
tree_util.register_pytree_with_keys_class(SingleDunder)

def testOptionalFlatten(self):
@tree_util.register_pytree_with_keys_class
class FooClass:
Expand Down
Loading