
Commit 117ce2a

chandrasekhard2 authored and recml authors committed
Integrate sparsecore into the DLRM HSTU implementation. Additionally, implement KV caching for STU layers.
Reverts changelist 793734230

PiperOrigin-RevId: 791891966
1 parent 847628b commit 117ce2a

24 files changed (+4608, −155 lines)
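The commit message mentions KV caching for the STU layers: incremental decoding reuses previously computed key/value projections instead of recomputing them at every step. A generic sketch of the idea in JAX (illustrative only; the helper, names, and shapes below are assumptions, not this commit's API):

import jax
import jax.numpy as jnp

def append_kv(cache_k, cache_v, new_k, new_v, pos):
  # Hypothetical helper: writes one step's K/V projections into a
  # preallocated cache at position `pos`.
  # cache_*: [max_len, heads, dim]; new_*: [1, heads, dim].
  cache_k = jax.lax.dynamic_update_slice_in_dim(cache_k, new_k, pos, axis=0)
  cache_v = jax.lax.dynamic_update_slice_in_dim(cache_v, new_v, pos, axis=0)
  return cache_k, cache_v

max_len, heads, dim = 16, 2, 4
k_cache = jnp.zeros((max_len, heads, dim))
v_cache = jnp.zeros((max_len, heads, dim))
k_cache, v_cache = append_kv(
    k_cache, v_cache,
    jnp.ones((1, heads, dim)), jnp.ones((1, heads, dim)),
    pos=0,
)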

recml/core/ops/hstu_ops.py

Lines changed: 8 additions & 8 deletions
@@ -125,9 +125,9 @@ def _apply_mask(
     masks = []
     if mask_ref is not None:
       if k_in_lanes:
-        mask = pl.load(mask_ref, (slice(None), k_slice))
+        mask = mask_ref[:, k_slice]
       else:
-        mask = pl.load(mask_ref, (k_slice, slice(None)))
+        mask = mask_ref[k_slice, :]

     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)

@@ -156,7 +156,7 @@ def _apply_mask(
       k_sequence = k_offset + jax.lax.broadcasted_iota(
           jnp.int32, (k_slice.size, bq), 0
       )
-      q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_sequence = q_sequence_ref[:1, :]  # [1, bq]
       q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

       assert q_sequence.shape == k_sequence.shape

@@ -170,7 +170,7 @@ def _apply_mask(

     if q_segment_ids_ref is not None:
       if k_in_lanes:
-        kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+        kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
         repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
         if rem:
           raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")

@@ -181,9 +181,9 @@ def _apply_mask(
         if rem:
           raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
         kv_ids = pltpu.repeat(
-            pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+            kv_segment_ids_ref[k_slice, :], repeats, axis=1
         )  # [k_slice, bq]
-        q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+        q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
       masks.append(q_ids == kv_ids)

     if masks:

@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)

     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )

@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )

     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]

     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
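Every change in this file swaps `pl.load(ref, idx)` for direct ref indexing, which accepts the same `pl.ds` dynamic slices. A minimal standalone sketch of the equivalence (illustrative shapes, not this kernel; `interpret=True` so it runs without a TPU):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _slice_kernel(x_ref, o_ref):
  k_slice = pl.ds(0, 8)  # a dynamic slice, like `slice_k` above
  # Legacy spelling: pl.load(x_ref, (slice(None), k_slice))
  o_ref[...] = x_ref[:, k_slice]  # direct ref indexing, as in this change

x = jnp.arange(8 * 16, dtype=jnp.float32).reshape(8, 16)
y = pl.pallas_call(
    _slice_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 8), jnp.float32),
    interpret=True,
)(x)
assert (y == x[:, :8]).all()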

recml/core/training/keras_trainer.py

Lines changed: 74 additions & 48 deletions
@@ -18,6 +18,7 @@
 import abc
 from collections.abc import Mapping
 import dataclasses
+import functools
 import gc
 import os
 import time

@@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str):
       model: The Keras model constructed by `create_model`.
       model_dir: The model directory passed to the trainer.
     """
-    model.save(os.path.join(model_dir, core.KERAS_MODEL_SAVEFILE))


 class KerasTrainer(core.Trainer[KerasTask]):

@@ -118,6 +118,7 @@ def __init__(
       max_checkpoints_to_keep: int = 5,
       checkpoint_save_interval_epochs: int = 1,
       rng_seed: int = core.DEFAULT_RNG_SEED,
+      legacy_checkpoint_format: bool = True,
   ):
     """Initializes the instance."""

@@ -143,60 +144,77 @@ def __init__(
     self._steps_per_eval = steps_per_eval
     self._continuous_eval_timeout = continuous_eval_timeout
     self._steps_per_loop = steps_per_loop
-    self._checkpoint_manager = None
     self._marker_path = os.path.join(
         model_dir, core.TRAINING_COMPLETE_MARKER_FILE
     )
     self._checkpoint_dir = os.path.join(model_dir, core.CHECKPOINT_DIR)
+    self._max_checkpoints_to_keep = max_checkpoints_to_keep
+    self._checkpoint_save_interval_epochs = checkpoint_save_interval_epochs
+    self._legacy_checkpoint_format = legacy_checkpoint_format

+  @functools.cached_property
+  def train_callbacks(self) -> list[keras.callbacks.Callback]:
+    """Returns the training callbacks."""
     if keras.backend.backend() == "jax":
-      self._checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
-          checkpoint_dir=self._checkpoint_dir,
-          max_to_keep=max_checkpoints_to_keep,
-          save_interval_epochs=checkpoint_save_interval_epochs,
-      )
-      self._train_callbacks = [
+      if self._legacy_checkpoint_format:
+        checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
+            checkpoint_dir=self._checkpoint_dir,
+            max_to_keep=self._max_checkpoints_to_keep,
+            save_interval_epochs=self._checkpoint_save_interval_epochs,
+        )
+      else:
+        checkpoint_manager = keras_utils.KerasOrbaxCheckpointManagerV2(
+            checkpoint_dir=self._checkpoint_dir,
+            max_to_keep=self._max_checkpoints_to_keep,
+            save_interval_epochs=self._checkpoint_save_interval_epochs,
+        )
+      return [
           keras_utils.EpochSummaryCallback(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              steps_per_epoch=steps_per_loop,
+              log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+              steps_per_epoch=self._steps_per_loop,
               write_steps_per_second=True,
           ),
           keras_utils.EpochOrbaxCheckpointAndRestoreCallback(
-              checkpoint_manager=self._checkpoint_manager,
+              checkpoint_manager=checkpoint_manager,
               marker_path=self._marker_path,
           ),
       ]
-      self._eval_callbacks = [
+    return [
+        keras.callbacks.TensorBoard(
+            log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+            write_steps_per_second=True,
+        ),
+        keras.callbacks.BackupAndRestore(
+            backup_dir=os.path.join(self._model_dir, core.BACKUP_DIR),
+        ),
+        keras.callbacks.ModelCheckpoint(
+            filepath=os.path.join(
+                self._model_dir,
+                core.CHECKPOINT_DIR,
+                "ckpt-{epoch:d}.weights.h5",
+            ),
+            save_weights_only=True,
+            verbose=1,
+        ),
+    ]
+
+  @functools.cached_property
+  def eval_callbacks(self) -> list[keras.callbacks.Callback]:
+    """Returns the evaluation callbacks."""
+    if keras.backend.backend() == "jax":
+      return [
           keras_utils.EpochSummaryCallback(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              steps_per_epoch=steps_per_loop,
+              log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+              steps_per_epoch=self._steps_per_loop,
               write_steps_per_second=False,
           ),
       ]
-    else:
-      self._checkpoint_manager = None
-      self._train_callbacks = [
-          keras.callbacks.TensorBoard(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              write_steps_per_second=True,
-          ),
-          keras.callbacks.BackupAndRestore(
-              backup_dir=os.path.join(model_dir, core.BACKUP_DIR),
-          ),
-          keras.callbacks.ModelCheckpoint(
-              filepath=os.path.join(
-                  model_dir, core.CHECKPOINT_DIR, "ckpt-{epoch:d}.weights.h5"
-              ),
-              save_weights_only=True,
-              verbose=1,
-          ),
-      ]
-      self._eval_callbacks = [
-          keras.callbacks.TensorBoard(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              write_steps_per_second=True,
-          ),
-      ]
+    return [
+        keras.callbacks.TensorBoard(
+            log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+            write_steps_per_second=True,
+        ),
+    ]

   def _maybe_get_model_kws(
       self, task: KerasTask, dataset: tf.data.Dataset

@@ -218,7 +236,7 @@ def train(self, task: KerasTask) -> core.Logs:
         dataset,
         epochs=self._train_epochs,
         steps_per_epoch=self._steps_per_loop,
-        callbacks=self._train_callbacks,
+        callbacks=self.train_callbacks,
     )
     model.summary(print_fn=logging.info)

@@ -237,14 +255,14 @@ def evaluate(self, task: KerasTask) -> core.Logs:
     if keras.backend.backend() == "jax":
       [tb_cbk] = [
           cbk
-          for cbk in self._eval_callbacks
+          for cbk in self.eval_callbacks
           if isinstance(cbk, keras_utils.EpochSummaryCallback)
       ]
       epoch_start_time = time.time()
       history = model.evaluate(
           dataset,
           steps=self._steps_per_eval,
-          callbacks=self._eval_callbacks,
+          callbacks=self.eval_callbacks,
           return_dict=True,
       )
       epoch_dt = time.time() - epoch_start_time

@@ -257,7 +275,7 @@ def evaluate(self, task: KerasTask) -> core.Logs:
     return model.evaluate(
         dataset,
         steps=self._steps_per_eval,
-        callbacks=self._eval_callbacks,
+        callbacks=self.eval_callbacks,
     )

   def train_and_evaluate(self, task: KerasTask) -> core.Logs:

@@ -277,7 +295,7 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs:
         steps_per_epoch=self._steps_per_loop,
         # Explicitly set to None for deterministic evaluation.
         validation_steps=None,
-        callbacks=self._train_callbacks,
+        callbacks=self.train_callbacks,
     )
     model.summary(print_fn=logging.info)

@@ -308,7 +326,10 @@ def timeout_fn() -> bool:
     else:
       steps_msg = "running complete evaluation..."

+    use_legacy_checkpoint_format = self._legacy_checkpoint_format
+
     class _RestoreCallback(keras.callbacks.Callback):
+      """Callback for restoring the model from the latest checkpoint."""

       def __init__(
           self,

@@ -319,9 +340,14 @@ def __init__(
         self._epoch = epoch

       def on_test_begin(self, logs: Mapping[str, Any] | None = None):
-        keras_utils.restore_keras_model(
-            model, self._checkpoint_dir, step=self._epoch
-        )
+        if use_legacy_checkpoint_format:
+          keras_utils.restore_keras_model(
+              model, self._checkpoint_dir, step=self._epoch
+          )
+        else:
+          keras_utils.restore_keras_checkpoint(
+              self._checkpoint_dir, model=model, epoch=self._epoch
+          )

     history = None
     for epoch in ocp.checkpoint_utils.checkpoints_iterator(

@@ -332,7 +358,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
     restore_callback = _RestoreCallback(self._checkpoint_dir, epoch)
     [tb_cbk] = [
         cbk
-        for cbk in self._eval_callbacks
+        for cbk in self.eval_callbacks
         if isinstance(cbk, keras_utils.EpochSummaryCallback)
     ]
     try:

@@ -346,7 +372,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
       history = model.evaluate(
           eval_dataset,
           steps=self._steps_per_eval,
-          callbacks=[restore_callback] + self._eval_callbacks,
+          callbacks=[restore_callback] + self.eval_callbacks,
           return_dict=True,
       )
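Note on the refactor above: callback construction moves out of `__init__` into `functools.cached_property` accessors, so the lists are built lazily on first use and then memoized on the instance. That identity stability matters because `evaluate` and the continuous-eval loop filter `self.eval_callbacks` for the `EpochSummaryCallback` instance. A minimal standalone sketch of the memoization behavior (hypothetical demo class, not the trainer itself):

import functools

class _CallbacksDemo:

  def __init__(self):
    self.builds = 0

  @functools.cached_property
  def eval_callbacks(self) -> list[str]:
    # A real trainer would construct Keras callbacks here.
    self.builds += 1
    return ["tensorboard"]

demo = _CallbacksDemo()
assert demo.eval_callbacks is demo.eval_callbacks  # same list object each time
assert demo.builds == 1  # the property body ran exactly once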

recml/core/training/keras_trainer_test.py

Lines changed: 15 additions & 2 deletions
@@ -59,11 +59,23 @@ def setUp(self):
           "mode": core.Experiment.Mode.TRAIN_AND_EVAL,
       },
       {
-          "testcase_name": "continuous_eval",
+          "testcase_name": "continuous_eval_",
           "mode": core.Experiment.Mode.CONTINUOUS_EVAL,
       },
+      {
+          "testcase_name": "train_and_eval_legacy_checkpoint_format",
+          "mode": core.Experiment.Mode.TRAIN_AND_EVAL,
+          "legacy_checkpoint_format": True,
+      },
+      {
+          "testcase_name": "continuous_eval_legacy_checkpoint_format",
+          "mode": core.Experiment.Mode.CONTINUOUS_EVAL,
+          "legacy_checkpoint_format": True,
+      },
   )
-  def test_keras_task_and_trainer(self, mode: str):
+  def test_keras_task_and_trainer(
+      self, mode: str, legacy_checkpoint_format: bool = False
+  ):
     if keras.backend.backend() == "jax":
       distribution = keras.distribution.DataParallel()
     else:

@@ -78,6 +90,7 @@ def test_keras_task_and_trainer(self, mode: str):
         steps_per_loop=2,
         model_dir=self.create_tempdir().full_path,
         continuous_eval_timeout=5,
+        legacy_checkpoint_format=legacy_checkpoint_format,
     )
     experiment = core.Experiment(_KerasTask(), trainer)
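The new cases work because `parameterized.named_parameters` passes only the keys a case defines, so the pre-existing cases that omit `legacy_checkpoint_format` fall back to the new keyword default of False. A minimal sketch of that mechanism (hypothetical test, assuming absl-py):

from absl.testing import absltest, parameterized

class _DefaultsDemo(parameterized.TestCase):

  @parameterized.named_parameters(
      {"testcase_name": "plain", "mode": "train_and_eval"},
      {"testcase_name": "legacy", "mode": "train_and_eval", "legacy": True},
  )
  def test_defaults(self, mode: str, legacy: bool = False):
    # The "plain" case never passes `legacy`, so it sees the default False.
    self.assertEqual(mode, "train_and_eval")
    self.assertIsInstance(legacy, bool)

if __name__ == "__main__":
  absltest.main()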

recml/core/training/partitioning.py

Lines changed: 5 additions & 5 deletions
@@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       if abstract_batch is not None:
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)

@@ -117,7 +117,7 @@ def partition_init(
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
       return state

@@ -130,15 +130,15 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)

-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
           **jit_kws,
       )

     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         return step_fn(batch, state)

     return _wrapped_step

@@ -217,7 +217,7 @@ def __init__(
   def mesh_context_manager(
       self,
   ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
+    return jax.set_mesh

   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
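All five call sites swap `jax.sharding.use_mesh` for `jax.set_mesh`, which recent JAX releases provide as its replacement and which, as this file uses it, also works as a context manager. A minimal sketch (assuming a JAX version that exports `jax.set_mesh` and at least one local device):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over whatever local devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

with jax.set_mesh(mesh):  # drop-in for the deprecated jax.sharding.use_mesh
  x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("data")))
  y = jax.jit(lambda a: a * 2)(x)
  print(y.sharding)  # NamedSharding over the "data" axis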
