From f0f813c0d9d9793dc1b96cfa9a54e172060f64ae Mon Sep 17 00:00:00 2001
From: Orbax Authors
Date: Mon, 1 Dec 2025 18:15:42 -0800
Subject: [PATCH] Internal change

PiperOrigin-RevId: 839010273
---
 Reviewer notes (commentary only; not part of the commit message or the diff):

  * Renames the classes `Fragment` -> `_Fragment` and `Fragments` ->
    `_Fragments`, and updates the references visible in these hunks
    (`__eq__`, `offset_by`, `slice`, `slice_of_value`, `__post_init__`,
    `_is_full`, `abstract_fragments`).
  * `_Fragments` sets `__qualname__ = 'Fragments'`; per the in-code comment
    this keeps the printed representation the same as before the rename.
    It also gains `FRAGMENT_T: ClassVar[type[_Fragment]] = _Fragment`.
  * Adds module-level aliases. `Fragment` / `Fragments` are kept for existing
    users (see the TODO in the hunk). `AbstractFragment(s)` and
    `ConcreteFragment(s)` are new names that currently point at the same two
    classes, so at runtime they are interchangeable.
  * Behaviour change in `abstract_fragments()`: a fragments collection used
    to be returned unchanged. It is now returned unchanged only when every
    fragment has `value is None`; otherwise a new collection with the same
    shape and dtype is built from each fragment's `index` alone.
    NOTE(review): no `value` argument is passed in that branch, so it
    presumably relies on the default in `_Fragment.__init__`, which is
    outside this patch - confirm.
  * `validate_fragments_can_be_stacked` and `stack_fragments` change only
    their parameter annotations (`Fragments` -> `ConcreteFragments`).
  * The `TypeError` message in `__post_init__` is left untouched by the
    patch and still reads 'Fragments must contain Fragment, ...'.
  * NOTE(review): this copy of the patch arrived with its whitespace
    collapsed. Line breaks, indentation and blank context lines were
    restored so that every hunk matches the line counts in its header and
    the totals match the diffstat below. Exact context whitespace (for
    example the spacing before inline `#` comments, and the position of the
    blank context line in the `-210,14 +210,20` hunk) should be verified
    against the upstream file before applying. The `From:` header carries
    no e-mail address in this copy.

 .../orbax/checkpoint/_src/arrays/fragments.py | 73 +++++++++++++------
 1 file changed, 51 insertions(+), 22 deletions(-)

diff --git a/checkpoint/orbax/checkpoint/_src/arrays/fragments.py b/checkpoint/orbax/checkpoint/_src/arrays/fragments.py
index 0e7ce84bd..49ffb09d2 100644
--- a/checkpoint/orbax/checkpoint/_src/arrays/fragments.py
+++ b/checkpoint/orbax/checkpoint/_src/arrays/fragments.py
@@ -19,7 +19,7 @@
 """
 
 import dataclasses
-from typing import Optional, Sequence, TypeAlias
+from typing import ClassVar, Optional, Sequence, TypeAlias
 
 import jax
 import numpy as np
@@ -43,7 +43,7 @@ def _index_from_ndarray(a: NpIndex) -> Index:
 
 
 @dataclasses.dataclass(frozen=True, init=False)
-class Fragment:
+class _Fragment:
   """One of a collection of slices into the same (abstract or concrete) array.
 
   Fields:
@@ -113,8 +113,8 @@ def shape(self) -> Shape:
   def size(self) -> int:
     return np.prod(self.shape)
 
-  def __eq__(self, other: 'Fragment'):
-    if not isinstance(other, Fragment):
+  def __eq__(self, other: '_Fragment'):
+    if not isinstance(other, _Fragment):
       return False
     if not np.array_equal(self.np_index, other.np_index):
       return False
@@ -159,15 +159,15 @@ def nbytes_astype(self, dtype: np.dtype) -> int:
   def offset_by(
       self,
       delta: np.ndarray,  # shape=[{rank}], dtype=int
-  ) -> 'Fragment':
+  ) -> '_Fragment':
     out_idx = self.np_index.copy()
     out_idx[:, :2] += np.expand_dims(delta, axis=1)
-    return Fragment(np_index=out_idx, value=self.value)
+    return _Fragment(np_index=out_idx, value=self.value)
 
   def slice(
       self,
       np_index: NpIndex,  # shape=[{rank}, 3], dtype=int
-  ) -> Optional['Fragment']:
+  ) -> Optional['_Fragment']:
     """Slices this fragment to find the part that overlaps the given NpIndex."""
     if (self.step != 1).any() or (np_index[:, 2] != 1).any():
       raise NotImplementedError('Coming ... soon?')
@@ -190,7 +190,7 @@ def slice_of_value(
     start = self.start
     stop = self.stop
     # This is just a convenient way to construct the required tuple of slices.
-    f = Fragment(
+    f = _Fragment(
         np_index=np.stack([
             np.maximum(start, new_np_idx[:, 0]),
             np.minimum(stop, new_np_idx[:, 1]),
@@ -201,7 +201,7 @@ def slice_of_value(
 
 
 @dataclasses.dataclass(frozen=True)
-class Fragments:
+class _Fragments:
   """An abstract or concrete collection of fragments.
 
   A `Fragments` is a lot like a `jax.Array` (or a `jax.ShapeDtypeStruct`) but
@@ -210,14 +210,20 @@ class Fragments:
   of a `jax.Array` (fragments are not required to have the same shape, or to map
   to a device mesh).
   """
 
+  # Keep printed representation the same as before the leading underscore
+  # was added. TODO(b/465183318): Remove this once there are separate
+  # classes for abstract and concrete fragments.
+  __qualname__ = 'Fragments'
+
+  FRAGMENT_T: ClassVar[type[_Fragment]] = _Fragment
   shape: Shape
   dtype: np.dtype
-  fragments: Sequence[Fragment]
+  fragments: Sequence[_Fragment]
 
   def __post_init__(self):
     for fragment in self.fragments:
-      if not isinstance(fragment, Fragment):
+      if not isinstance(fragment, _Fragment):
         raise TypeError(
             f'Fragments must contain Fragment, not {type(fragment)}.'
         )
@@ -265,7 +271,7 @@ def __array__(self) -> np.ndarray:
   def slice(
       self,
       index: NpIndex | Index,  # shape=[{rank}, 3], dtype=int
-  ) -> 'Fragments':
+  ) -> '_Fragments':
     """Returns a slice of this object."""
     if not isinstance(index, np.ndarray):
       index = np_utils.resolve_slice(index, self.shape)
@@ -280,7 +286,7 @@ def slice(
           f'with out-of-bounds index {_index_from_ndarray(index)}'
       )
 
-    return Fragments(
+    return _Fragments(
         tuple(d.item() for d in sliced_shape),
         self.dtype,
         [
@@ -291,7 +297,18 @@ def slice(
     )
 
 
-def _is_full(fragments: Fragments) -> bool:
+# TODO(b/465188418): Remove these two aliases once all users have been migrated
+# to the more specific ones.
+Fragment: TypeAlias = _Fragment
+Fragments: TypeAlias = _Fragments
+
+AbstractFragment: TypeAlias = _Fragment
+AbstractFragments: TypeAlias = _Fragments
+ConcreteFragment: TypeAlias = _Fragment
+ConcreteFragments: TypeAlias = _Fragments
+
+
+def _is_full(fragments: _Fragments) -> bool:
   """True iff every array element is covered by some fragment."""
   present = np.zeros(fragments.shape, dtype=bool)
   for f in fragments.fragments:
@@ -323,19 +340,31 @@ def addressable_shards(x: jax.Array | jax.ShapeDtypeStruct) -> list[Index]:
 
 
 def abstract_fragments(
-    x: jax.Array | jax.ShapeDtypeStruct | Fragments,
-) -> Fragments:
+    x: jax.Array | jax.ShapeDtypeStruct | AbstractFragments | ConcreteFragments,
+) -> AbstractFragments:
   """Returns abstract fragments matching the given array."""
-  if isinstance(x, Fragments):
-    return x
-  return Fragments(
+  if isinstance(x, _Fragments):
+    # TODO(b/465183318): Replace this condition with an instance check
+    # once AbstractFragments and ConcreteFragments are separate classes.
+    if all(f.value is None for f in x.fragments):
+      return x
+    else:
+      return AbstractFragments(
+          x.shape,
+          x.dtype,
+          [AbstractFragment(index=f.index) for f in x.fragments],
+      )
+  return AbstractFragments(
       x.shape,
       x.dtype,
-      [Fragment(index=index, value=None) for index in addressable_shards(x)],
+      [
+          AbstractFragment(index=index, value=None)
+          for index in addressable_shards(x)
+      ],
   )
 
 
-def validate_fragments_can_be_stacked(fragments: Fragments) -> None:
+def validate_fragments_can_be_stacked(fragments: ConcreteFragments) -> None:
   """Validates that the given fragments can be stacked."""
   if not fragments.fragments:
     raise ValueError('No fragments to stack.')
@@ -353,7 +382,7 @@ def validate_fragments_can_be_stacked(fragments: Fragments) -> None:
       )
 
 
-def stack_fragments(fragments: Fragments | None) -> np.ndarray | None:
+def stack_fragments(fragments: ConcreteFragments | None) -> np.ndarray | None:
   """Stacks the given fragments, which must all have the same shape."""
   if fragments is None:
     return fragments