Skip to content

Commit ad7f746

Browse files
author
Orbax Authors
committed
Internal change.
PiperOrigin-RevId: 839869322
1 parent 8d7de0f commit ad7f746

File tree

2 files changed

+165
-48
lines changed

2 files changed

+165
-48
lines changed

checkpoint/orbax/checkpoint/_src/arrays/fragments.py

Lines changed: 71 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,66 @@ def offset_by(
142142
out_idx[:, :2] += np.expand_dims(delta, axis=1)
143143
return type(self)(np_index=out_idx, value=self.value)
144144

145+
def intersect(
146+
self,
147+
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
148+
) -> _GenericFragment[A] | None:
149+
"""Intersects this fragment with the given NpIndex.
150+
151+
The result is in this fragment's coordinate space. For example,
152+
intersecting a fragment with its own index gives an identical fragment.
153+
154+
Args:
155+
np_index: The NpIndex to intersect with.
156+
157+
Returns:
158+
A new fragment representing the intersection, or None if there is no
159+
overlap.
160+
"""
161+
if (self.step != 1).any() or (np_index[:, 2] != 1).any():
162+
raise NotImplementedError('index steps other than 1 are not supported.')
163+
164+
out_np_index = np_index.copy()
165+
start = out_np_index[:, 0] = np.maximum(out_np_index[:, 0], self.start)
166+
stop = out_np_index[:, 1] = np.minimum(out_np_index[:, 1], self.stop)
167+
if not (start < stop).all():
168+
return None
169+
return type(self)(
170+
np_index=out_np_index, value=self.slice_of_value(out_np_index)
171+
)
172+
173+
def slice(
174+
self,
175+
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
176+
) -> _GenericFragment[A] | None: # Use typing.Self once 3.11 is minimum.
177+
"""Slices this fragment by the given NpIndex.
178+
179+
The result is in the slice's coordinate space. For example, slicing a
180+
fragment by its own index gives a fragment whose start is zero.
181+
182+
Args:
183+
np_index: The NpIndex to slice by.
184+
185+
Returns:
186+
A new fragment representing the slice, or None if there is no overlap.
187+
"""
188+
intersection = self.intersect(np_index)
189+
return intersection.offset_by(-np_index[:, 0]) if intersection else None
190+
191+
def slice_of_value(self, np_index: NpIndex) -> A:
192+
"""Takes a slice of the value of this fragment.
193+
194+
It is required that `np_index` has already been clamped to the fragment's
195+
bounds; otherwise a ValueError will result.
196+
197+
Args:
198+
np_index: The NpIndex to slice by.
199+
200+
Returns:
201+
A slice of the fragment's value.
202+
"""
203+
raise NotImplementedError()
204+
145205

146206
@dataclasses.dataclass(frozen=True, init=False, eq=False, repr=False)
147207
class AbstractFragment(_GenericFragment[type(None)]):
@@ -178,21 +238,9 @@ def offset_by(
178238
out_idx[:, :2] += np.expand_dims(delta, axis=1)
179239
return type(self)(np_index=out_idx)
180240

181-
def slice(
182-
self,
183-
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
184-
) -> AbstractFragment | None: # Use typing.Self once 3.11 is minimum.
185-
"""Slices this fragment to find the part that overlaps the given NpIndex."""
186-
if (self.step != 1).any() or (np_index[:, 2] != 1).any():
187-
raise NotImplementedError('Coming ... soon?')
188-
189-
slice_shape = np_index[:, 1] - np_index[:, 0]
190-
out = self.offset_by(-np_index[:, 0])
191-
start = out.start[:] = np.maximum(out.start, 0)
192-
stop = out.stop[:] = np.minimum(out.stop, slice_shape)
193-
if not (start < stop).all():
194-
return None
195-
return out
241+
def slice_of_value(self, np_index: NpIndex) -> None:
242+
del np_index
243+
return None
196244

197245

198246
@dataclasses.dataclass(frozen=True, init=False)
@@ -230,39 +278,14 @@ def __array__(self) -> np.ndarray:
230278
def nbytes(self) -> int:
231279
return self.value.nbytes
232280

233-
def slice(
234-
self,
235-
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
236-
) -> _ConcreteFragment | None: # Use typing.Self once 3.11 is minimum.
237-
"""Slices this fragment to find the part that overlaps the given NpIndex."""
238-
if (self.step != 1).any() or (np_index[:, 2] != 1).any():
239-
raise NotImplementedError('Coming ... soon?')
240-
241-
slice_shape = np_index[:, 1] - np_index[:, 0]
242-
out = self.offset_by(-np_index[:, 0])
243-
start = out.start[:] = np.maximum(out.start, 0)
244-
stop = out.stop[:] = np.minimum(out.stop, slice_shape)
245-
if not (start < stop).all():
246-
return None
247-
return type(self)(
248-
np_index=out.np_index, value=self.slice_of_value(np_index)
249-
)
250-
251-
def slice_of_value(
252-
self,
253-
new_np_idx: NpIndex,
254-
) -> A:
255-
"""Returns a slice of `value`."""
256-
start = self.start
257-
stop = self.stop
281+
def slice_of_value(self, np_index: NpIndex) -> Aconcrete:
258282
# This is just a convenient way to construct the required tuple of slices.
259-
f = AbstractFragment(
260-
np_index=np.stack([
261-
np.maximum(start, new_np_idx[:, 0]),
262-
np.minimum(stop, new_np_idx[:, 1]),
263-
new_np_idx[:, 2],
264-
], axis=1)
265-
).offset_by(-start)
283+
f = AbstractFragment(np_index=np_index).offset_by(-self.start)
284+
if (f.start < 0).any() or (f.stop > self.value.shape).any():
285+
raise ValueError(
286+
f'Attempt to slice fragment value of shape {self.shape} with'
287+
f' out-of-bounds index {f}'
288+
)
266289
return self.value[f.index or ...]
267290

268291

@@ -353,7 +376,7 @@ def __array__(self) -> np.ndarray:
353376
def slice(
354377
self,
355378
index: NpIndex | Index, # shape=[{rank}, 3], dtype=int
356-
) -> '_GenericFragments[F]': # Use typing.Self once 3.11 is minimum.
379+
) -> '_GenericFragments[_GenericFragment[A]]': # Use typing.Self once >=3.11.
357380
"""Returns a slice of this object."""
358381
if not isinstance(index, np.ndarray):
359382
index = np_utils.resolve_slice(index, self.shape)

checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,66 @@ def test_nbytes_astype_of_abstract_fragment_uses_given_dtype(self):
214214
).nbytes_astype(np.dtype(jax.numpy.bfloat16)),
215215
)
216216

217+
@parameterized.named_parameters(
218+
('np_fragment', NpFragment),
219+
('jax_fragment', JaxFragment),
220+
)
221+
def test_intersect(
222+
self,
223+
fragment_t: ConcreteFragmentT,
224+
):
225+
np_api = fragment_t.NP_API
226+
full_value = np_api.arange(8 * 9).reshape((8, 9))
227+
fragment_index = np.s_[4:8:1, 3:9:1]
228+
229+
f = fragment_t(index=fragment_index, value=full_value[fragment_index])
230+
231+
with self.subTest('fully_within_fragment_index'):
232+
bounds = np.s_[5:7:1, 4:8:1]
233+
s = f.intersect(array_fragments._ndarray_from_index(bounds))
234+
self.assertEqual(
235+
fragment_t(index=np.s_[5:7:1, 4:8:1], value=full_value[bounds]),
236+
s,
237+
)
238+
239+
with self.subTest('fully_enclosing_fragment_index'):
240+
bounds = np.s_[2:10:1, 1:11:1]
241+
s = f.intersect(array_fragments._ndarray_from_index(bounds))
242+
self.assertEqual(fragment_t(index=np.s_[4:8:1, 3:9:1], value=f.value), s)
243+
244+
with self.subTest('spanning_fragment_start'):
245+
bounds = np.s_[2:6:1, 2:4:1]
246+
s = f.intersect(array_fragments._ndarray_from_index(bounds))
247+
self.assertEqual(
248+
fragment_t(index=np.s_[4:6:1, 3:4:1], value=f.value[:2, :1]), s
249+
)
250+
251+
with self.subTest('spanning_fragment_stop'):
252+
bounds = np.s_[6:10:1, 6:10:1]
253+
s = f.intersect(array_fragments._ndarray_from_index(bounds))
254+
self.assertEqual(
255+
fragment_t(index=np.s_[6:8:1, 6:9:1], value=f.value[2:, 3:]), s
256+
)
257+
258+
with self.subTest('with_no_overlap'):
259+
self.assertIsNone(
260+
f.intersect(
261+
array_fragments._ndarray_from_index(np.s_[10:12:1, 10:12:1])
262+
)
263+
)
264+
# This is within the bounds of the fragment but spans no elements.
265+
self.assertIsNone(
266+
f.intersect(array_fragments._ndarray_from_index(np.s_[6:6:1, 3:9:1]))
267+
)
268+
269+
with self.subTest('rank_0'):
270+
s = fragment_t(index=(), value=np_api.ones([])).intersect(
271+
np.zeros([0, 3], dtype=int)
272+
)
273+
self.assertIsNotNone(s)
274+
self.assertEqual((), s.index)
275+
self.assertIsInstance(s.value, np_api.ndarray)
276+
217277
@parameterized.named_parameters(
218278
('np_fragment', NpFragment),
219279
('jax_fragment', JaxFragment),
@@ -272,6 +332,40 @@ def test_slice(
272332
self.assertEqual((), s.index)
273333
self.assertIsInstance(s.value, np_api.ndarray)
274334

335+
@parameterized.named_parameters(
336+
('np_fragment', NpFragment),
337+
('jax_fragment', JaxFragment),
338+
)
339+
def test_slice_of_value(
340+
self,
341+
fragment_t: ConcreteFragmentT,
342+
):
343+
np_api = fragment_t.NP_API
344+
full_value = np_api.arange(8 * 9).reshape((8, 9))
345+
fragment_index = np.s_[4:8:1, 3:9:1]
346+
fragment = fragment_t(
347+
index=fragment_index, value=full_value[fragment_index]
348+
)
349+
350+
with self.subTest('returns_slice_of_value'):
351+
np.testing.assert_array_equal(
352+
full_value[np.s_[5:7:1, 4:8:1]],
353+
fragment.slice_of_value(
354+
array_fragments._ndarray_from_index(np.s_[5:7:1, 4:8:1])
355+
),
356+
)
357+
358+
with self.subTest('raises_if_slice_is_out_of_bounds'):
359+
with self.assertRaises(ValueError):
360+
fragment.slice_of_value(
361+
array_fragments._ndarray_from_index(np.s_[2:6:1, 3:9:1])
362+
)
363+
364+
with self.assertRaises(ValueError):
365+
fragment.slice_of_value(
366+
array_fragments._ndarray_from_index(np.s_[4:8:1, 8:12:1])
367+
)
368+
275369

276370
@parameterized.named_parameters(
277371
('abstract_fragments', AbstractFragments),

0 commit comments

Comments
 (0)