Skip to content

Commit a4b2998

Browse files
Cristian GarciaFlax Authors
Cristian Garcia
authored and
Flax Authors
committed
add dataclass
PiperOrigin-RevId: 755993917
1 parent a7157cf commit a4b2998

File tree

5 files changed

+188
-11
lines changed

5 files changed

+188
-11
lines changed

flax/nnx/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,6 @@
175175
from .extract import NodeStates as NodeStates
176176
from .summary import tabulate as tabulate
177177
from . import traversals as traversals
178+
from .dataclasses import dataclass as dataclass
179+
from .dataclasses import Static as Static
180+
from .dataclasses import field as field

flax/nnx/dataclasses.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright 2024 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
from dataclasses import field
19+
import typing as tp
20+
import typing_extensions as tpe
21+
22+
from flax import config
23+
from flax.nnx import object as objectlib
24+
25+
A = tp.TypeVar('A')
26+
T = tp.TypeVar('T', bound=type[objectlib.Object])
27+
28+
29+
class StaticTag:
30+
...
31+
32+
33+
Static = tp.Annotated[A, StaticTag]
34+
35+
36+
def _is_static(annotation):
37+
return getattr(annotation, '__metadata__', None) == (StaticTag,)
38+
39+
40+
@tp.overload
41+
def dataclass(
42+
cls: T,
43+
/,
44+
*,
45+
init: bool = True,
46+
eq: bool = True,
47+
order: bool = False,
48+
unsafe_hash: bool = False,
49+
match_args: bool = True,
50+
kw_only: bool = False,
51+
slots: bool = False,
52+
) -> T:
53+
...
54+
55+
56+
@tp.overload
57+
def dataclass(
58+
*,
59+
init: bool = True,
60+
eq: bool = True,
61+
order: bool = False,
62+
unsafe_hash: bool = False,
63+
match_args: bool = True,
64+
kw_only: bool = False,
65+
slots: bool = False,
66+
) -> tp.Callable[[T], T]:
67+
...
68+
69+
70+
@tpe.dataclass_transform(
71+
field_specifiers=(field,),
72+
)
73+
def dataclass(
74+
cls: T | None = None,
75+
/,
76+
*,
77+
init: bool = True,
78+
eq: bool = True,
79+
order: bool = False,
80+
unsafe_hash: bool = False,
81+
match_args: bool = True,
82+
kw_only: bool = False,
83+
slots: bool = False,
84+
) -> T | tp.Callable[[T], T]:
85+
"""Makes an nnx.Object type as a dataclass and defines its pytree node attributes using type hints.
86+
87+
``nnx.dataclass`` can be used to create pytree dataclass types using type
88+
hints instead of the ``__data__`` attribute. By default, all fields are
89+
considered to be nodes, to mark a field as static annotate it with
90+
``nnx.Static[T]``.
91+
92+
Example::
93+
94+
>>> from flax import nnx
95+
>>> import jax
96+
...
97+
>>> @nnx.dataclass
98+
... class Foo(nnx.Object):
99+
... a: int
100+
... b: jax.Array
101+
... c: nnx.Static[int]
102+
...
103+
>>> tree = Foo(a=1, b=jax.numpy.array(1), c=1)
104+
>>> assert len(jax.tree.leaves(tree)) == 2 # a and b
105+
106+
``dataclass`` will raise a ``ValueError`` if the class does not derive from
107+
``nnx.Object``, if the parent Object has ``pytree`` set to anything other than
108+
``'strict'``, or if the class has a ``__data__`` attribute.
109+
110+
``nnx.dataclass`` doesn't accept ``repr`` and defines it as ``False`` to avoid
111+
overwriting the default ``__repr__`` method from ``Object``.
112+
"""
113+
114+
def _dataclass(cls: T):
115+
if not issubclass(cls, objectlib.Object):
116+
raise ValueError(
117+
'dataclass can only be used with a class derived from nnx.Object'
118+
)
119+
if '__data__' in vars(cls):
120+
raise ValueError(
121+
'dataclass can only be used with a class without a __data__ attribute'
122+
)
123+
if config.flax_mutable_array:
124+
if cls._object__pytree_mode != 'strict':
125+
raise ValueError(
126+
"dataclass can only be used with a class with pytree='strict', "
127+
f'got {cls._object__pytree_mode}'
128+
)
129+
130+
# here we redefine _object__nodes using the type hints
131+
hints = cls.__annotations__
132+
all_nodes = list(cls._object__nodes)
133+
all_nodes.extend(name for name, typ in hints.items() if not _is_static(typ))
134+
cls._object__nodes = frozenset(all_nodes)
135+
136+
cls = dataclasses.dataclass( # type: ignore
137+
cls,
138+
init=init,
139+
repr=False,
140+
eq=eq,
141+
order=order,
142+
unsafe_hash=unsafe_hash,
143+
match_args=match_args,
144+
kw_only=kw_only,
145+
slots=slots,
146+
)
147+
return cls
148+
149+
if cls is None:
150+
return _dataclass
151+
else:
152+
return _dataclass(cls)

flax/nnx/object.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
O = tp.TypeVar('O', bound='Object')
4343

4444
BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ
45-
DEFAULT_PYTREE_MODE = 'strict' if config.flax_mutable_array else None
4645

4746
def _collect_stats(
4847
node: tp.Any, node_stats: dict[int, dict[type[Variable], SizeBytes]]
@@ -192,9 +191,8 @@ class Object(reprlib.Representable, metaclass=ObjectMeta):
192191
_object__state: ObjectState
193192

194193
def __init_subclass__(
195-
cls, pytree: tp.Literal['strict', 'auto', 'all'] | None = DEFAULT_PYTREE_MODE, **kwargs
194+
cls, pytree: tp.Literal['strict', 'auto', 'all'] = 'strict', **kwargs
196195
) -> None:
197-
cls._object__pytree_mode = pytree
198196
super().__init_subclass__(**kwargs)
199197

200198
graph.register_graph_node_type(
@@ -207,6 +205,7 @@ def __init_subclass__(
207205
init=cls._graph_node_init, # type: ignore
208206
)
209207
if config.flax_mutable_array and pytree is not None:
208+
cls._object__pytree_mode = pytree
210209
parent_pytree_mode = getattr(cls, '_object__pytree_mode', None)
211210
if (
212211
parent_pytree_mode is not None
@@ -463,7 +462,7 @@ def _object__strict_unflatten(
463462
vars_obj.update(zip(node_names, node_attrs, strict=True))
464463
vars_obj.update(static_attrs)
465464
return obj
466-
465+
467466
# all
468467
def _object__all_flatten_with_paths(self):
469468
obj_vars = vars(self)

tests/nnx/graph_utils_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
class List(nnx.Module):
2828
if config.flax_mutable_array:
2929
__data__ = ('items',)
30-
30+
3131
def __init__(self, items):
3232
self.items = list(items)
3333

tests/nnx/mutable_array_test.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pytest
16-
17-
import flax.errors
15+
from absl.testing import absltest
1816
from flax import config
1917
from flax import nnx
18+
import flax.errors
2019
import jax
2120
import jax.numpy as jnp
22-
from absl.testing import absltest
21+
import pytest
22+
2323

2424
@pytest.mark.skipif(
2525
not config.flax_mutable_array, reason='MutableArray not enabled'
2626
)
27-
class TestMutableArray(absltest.TestCase):
28-
def test_tree_map(self):
27+
class TestObject(absltest.TestCase):
28+
def test_pytree(self):
2929
class Foo(nnx.Module):
3030
__data__ = ('node',)
3131

@@ -40,6 +40,25 @@ def __init__(self):
4040
assert m.node == 2
4141
assert m.meta == 1
4242

43+
def test_pytree_dataclass(self):
44+
@nnx.dataclass
45+
class Foo(nnx.Module):
46+
node: jax.Array
47+
meta: nnx.Static[int]
48+
49+
m = Foo(node=jnp.array(1), meta=1)
50+
51+
m: Foo = jax.tree.map(lambda x: x + 1, m)
52+
53+
assert m.node == 2
54+
assert m.meta == 1
55+
56+
57+
@pytest.mark.skipif(
58+
not config.flax_mutable_array, reason='MutableArray not enabled'
59+
)
60+
class TestMutableArray(absltest.TestCase):
61+
4362
def test_static(self):
4463
class C(nnx.Module):
4564
def __init__(self, meta):
@@ -185,3 +204,7 @@ def test_rngs_call(self):
185204
rngs = nnx.Rngs(0)
186205
key = rngs()
187206
self.assertIsInstance(key, jax.Array)
207+
208+
209+
if __name__ == '__main__':
210+
absltest.main()

0 commit comments

Comments
 (0)