|
| 1 | +# Copyright 2024 The Flax Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import dataclasses |
| 18 | +from dataclasses import field |
| 19 | +import typing as tp |
| 20 | +import typing_extensions as tpe |
| 21 | + |
| 22 | +from flax import config |
| 23 | +from flax.nnx import object as objectlib |
| 24 | + |
| 25 | +A = tp.TypeVar('A') |
| 26 | +T = tp.TypeVar('T', bound=type[objectlib.Object]) |
| 27 | + |
| 28 | + |
| 29 | +class StaticTag: |
| 30 | + ... |
| 31 | + |
| 32 | + |
| 33 | +Static = tp.Annotated[A, StaticTag] |
| 34 | + |
| 35 | + |
| 36 | +def _is_static(annotation): |
| 37 | + return getattr(annotation, '__metadata__', None) == (StaticTag,) |
| 38 | + |
| 39 | + |
| 40 | +@tp.overload |
| 41 | +def dataclass( |
| 42 | + cls: T, |
| 43 | + /, |
| 44 | + *, |
| 45 | + init: bool = True, |
| 46 | + eq: bool = True, |
| 47 | + order: bool = False, |
| 48 | + unsafe_hash: bool = False, |
| 49 | + match_args: bool = True, |
| 50 | + kw_only: bool = False, |
| 51 | + slots: bool = False, |
| 52 | +) -> T: |
| 53 | + ... |
| 54 | + |
| 55 | + |
| 56 | +@tp.overload |
| 57 | +def dataclass( |
| 58 | + *, |
| 59 | + init: bool = True, |
| 60 | + eq: bool = True, |
| 61 | + order: bool = False, |
| 62 | + unsafe_hash: bool = False, |
| 63 | + match_args: bool = True, |
| 64 | + kw_only: bool = False, |
| 65 | + slots: bool = False, |
| 66 | +) -> tp.Callable[[T], T]: |
| 67 | + ... |
| 68 | + |
| 69 | + |
| 70 | +@tpe.dataclass_transform( |
| 71 | + field_specifiers=(field,), |
| 72 | +) |
| 73 | +def dataclass( |
| 74 | + cls: T | None = None, |
| 75 | + /, |
| 76 | + *, |
| 77 | + init: bool = True, |
| 78 | + eq: bool = True, |
| 79 | + order: bool = False, |
| 80 | + unsafe_hash: bool = False, |
| 81 | + match_args: bool = True, |
| 82 | + kw_only: bool = False, |
| 83 | + slots: bool = False, |
| 84 | +) -> T | tp.Callable[[T], T]: |
| 85 | + """Makes an nnx.Object type as a dataclass and defines its pytree node attributes using type hints. |
| 86 | +
|
| 87 | + ``nnx.dataclass`` can be used to create pytree dataclass types using type |
| 88 | + hints instead of the ``__data__`` attribute. By default, all fields are |
| 89 | + considered to be nodes, to mark a field as static annotate it with |
| 90 | + ``nnx.Static[T]``. |
| 91 | +
|
| 92 | + Example:: |
| 93 | +
|
| 94 | + >>> from flax import nnx |
| 95 | + >>> import jax |
| 96 | + ... |
| 97 | + >>> @nnx.dataclass |
| 98 | + ... class Foo(nnx.Object): |
| 99 | + ... a: int |
| 100 | + ... b: jax.Array |
| 101 | + ... c: nnx.Static[int] |
| 102 | + ... |
| 103 | + >>> tree = Foo(a=1, b=jax.numpy.array(1), c=1) |
| 104 | + >>> assert len(jax.tree.leaves(tree)) == 2 # a and b |
| 105 | +
|
| 106 | + ``dataclass`` will raise a ``ValueError`` if the class does not derive from |
| 107 | + ``nnx.Object``, if the parent Object has ``pytree`` set to anything other than |
| 108 | + ``'strict'``, or if the class has a ``__data__`` attribute. |
| 109 | +
|
| 110 | + ``nnx.dataclass`` doesn't accept ``repr`` and defines it as ``False`` to avoid |
| 111 | + overwriting the default ``__repr__`` method from ``Object``. |
| 112 | + """ |
| 113 | + |
| 114 | + def _dataclass(cls: T): |
| 115 | + if not issubclass(cls, objectlib.Object): |
| 116 | + raise ValueError( |
| 117 | + 'dataclass can only be used with a class derived from nnx.Object' |
| 118 | + ) |
| 119 | + if '__data__' in vars(cls): |
| 120 | + raise ValueError( |
| 121 | + 'dataclass can only be used with a class without a __data__ attribute' |
| 122 | + ) |
| 123 | + if config.flax_mutable_array: |
| 124 | + if cls._object__pytree_mode != 'strict': |
| 125 | + raise ValueError( |
| 126 | + "dataclass can only be used with a class with pytree='strict', " |
| 127 | + f'got {cls._object__pytree_mode}' |
| 128 | + ) |
| 129 | + |
| 130 | + # here we redefine _object__nodes using the type hints |
| 131 | + hints = cls.__annotations__ |
| 132 | + all_nodes = list(cls._object__nodes) |
| 133 | + all_nodes.extend(name for name, typ in hints.items() if not _is_static(typ)) |
| 134 | + cls._object__nodes = frozenset(all_nodes) |
| 135 | + |
| 136 | + cls = dataclasses.dataclass( # type: ignore |
| 137 | + cls, |
| 138 | + init=init, |
| 139 | + repr=False, |
| 140 | + eq=eq, |
| 141 | + order=order, |
| 142 | + unsafe_hash=unsafe_hash, |
| 143 | + match_args=match_args, |
| 144 | + kw_only=kw_only, |
| 145 | + slots=slots, |
| 146 | + ) |
| 147 | + return cls |
| 148 | + |
| 149 | + if cls is None: |
| 150 | + return _dataclass |
| 151 | + else: |
| 152 | + return _dataclass(cls) |
0 commit comments