18
18
import abc
19
19
from collections .abc import Mapping
20
20
import dataclasses
21
+ import functools
21
22
import gc
22
23
import os
23
24
import time
@@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str):
96
97
model: The Keras model constructed by `create_model`.
97
98
model_dir: The model directory passed to the trainer.
98
99
"""
99
- model .save (os .path .join (model_dir , core .KERAS_MODEL_SAVEFILE ))
100
100
101
101
102
102
class KerasTrainer (core .Trainer [KerasTask ]):
@@ -118,6 +118,7 @@ def __init__(
118
118
max_checkpoints_to_keep : int = 5 ,
119
119
checkpoint_save_interval_epochs : int = 1 ,
120
120
rng_seed : int = core .DEFAULT_RNG_SEED ,
121
+ legacy_checkpoint_format : bool = True ,
121
122
):
122
123
"""Initializes the instance."""
123
124
@@ -143,60 +144,77 @@ def __init__(
143
144
self ._steps_per_eval = steps_per_eval
144
145
self ._continuous_eval_timeout = continuous_eval_timeout
145
146
self ._steps_per_loop = steps_per_loop
146
- self ._checkpoint_manager = None
147
147
self ._marker_path = os .path .join (
148
148
model_dir , core .TRAINING_COMPLETE_MARKER_FILE
149
149
)
150
150
self ._checkpoint_dir = os .path .join (model_dir , core .CHECKPOINT_DIR )
151
+ self ._max_checkpoints_to_keep = max_checkpoints_to_keep
152
+ self ._checkpoint_save_interval_epochs = checkpoint_save_interval_epochs
153
+ self ._legacy_checkpoint_format = legacy_checkpoint_format
151
154
155
+ @functools .cached_property
156
+ def train_callbacks (self ) -> list [keras .callbacks .Callback ]:
157
+ """Returns the training callbacks."""
152
158
if keras .backend .backend () == "jax" :
153
- self ._checkpoint_manager = keras_utils .KerasOrbaxCheckpointManager (
154
- checkpoint_dir = self ._checkpoint_dir ,
155
- max_to_keep = max_checkpoints_to_keep ,
156
- save_interval_epochs = checkpoint_save_interval_epochs ,
157
- )
158
- self ._train_callbacks = [
159
+ if self ._legacy_checkpoint_format :
160
+ checkpoint_manager = keras_utils .KerasOrbaxCheckpointManager (
161
+ checkpoint_dir = self ._checkpoint_dir ,
162
+ max_to_keep = self ._max_checkpoints_to_keep ,
163
+ save_interval_epochs = self ._checkpoint_save_interval_epochs ,
164
+ )
165
+ else :
166
+ checkpoint_manager = keras_utils .KerasOrbaxCheckpointManagerV2 (
167
+ checkpoint_dir = self ._checkpoint_dir ,
168
+ max_to_keep = self ._max_checkpoints_to_keep ,
169
+ save_interval_epochs = self ._checkpoint_save_interval_epochs ,
170
+ )
171
+ return [
159
172
keras_utils .EpochSummaryCallback (
160
- log_dir = os .path .join (model_dir , core .LOG_DIR ),
161
- steps_per_epoch = steps_per_loop ,
173
+ log_dir = os .path .join (self . _model_dir , core .LOG_DIR ),
174
+ steps_per_epoch = self . _steps_per_loop ,
162
175
write_steps_per_second = True ,
163
176
),
164
177
keras_utils .EpochOrbaxCheckpointAndRestoreCallback (
165
- checkpoint_manager = self . _checkpoint_manager ,
178
+ checkpoint_manager = checkpoint_manager ,
166
179
marker_path = self ._marker_path ,
167
180
),
168
181
]
169
- self ._eval_callbacks = [
182
+ return [
183
+ keras .callbacks .TensorBoard (
184
+ log_dir = os .path .join (self ._model_dir , core .LOG_DIR ),
185
+ write_steps_per_second = True ,
186
+ ),
187
+ keras .callbacks .BackupAndRestore (
188
+ backup_dir = os .path .join (self ._model_dir , core .BACKUP_DIR ),
189
+ ),
190
+ keras .callbacks .ModelCheckpoint (
191
+ filepath = os .path .join (
192
+ self ._model_dir ,
193
+ core .CHECKPOINT_DIR ,
194
+ "ckpt-{epoch:d}.weights.h5" ,
195
+ ),
196
+ save_weights_only = True ,
197
+ verbose = 1 ,
198
+ ),
199
+ ]
200
+
201
+ @functools .cached_property
202
+ def eval_callbacks (self ) -> list [keras .callbacks .Callback ]:
203
+ """Returns the evaluation callbacks."""
204
+ if keras .backend .backend () == "jax" :
205
+ return [
170
206
keras_utils .EpochSummaryCallback (
171
- log_dir = os .path .join (model_dir , core .LOG_DIR ),
172
- steps_per_epoch = steps_per_loop ,
207
+ log_dir = os .path .join (self . _model_dir , core .LOG_DIR ),
208
+ steps_per_epoch = self . _steps_per_loop ,
173
209
write_steps_per_second = False ,
174
210
),
175
211
]
176
- else :
177
- self ._checkpoint_manager = None
178
- self ._train_callbacks = [
179
- keras .callbacks .TensorBoard (
180
- log_dir = os .path .join (model_dir , core .LOG_DIR ),
181
- write_steps_per_second = True ,
182
- ),
183
- keras .callbacks .BackupAndRestore (
184
- backup_dir = os .path .join (model_dir , core .BACKUP_DIR ),
185
- ),
186
- keras .callbacks .ModelCheckpoint (
187
- filepath = os .path .join (
188
- model_dir , core .CHECKPOINT_DIR , "ckpt-{epoch:d}.weights.h5"
189
- ),
190
- save_weights_only = True ,
191
- verbose = 1 ,
192
- ),
193
- ]
194
- self ._eval_callbacks = [
195
- keras .callbacks .TensorBoard (
196
- log_dir = os .path .join (model_dir , core .LOG_DIR ),
197
- write_steps_per_second = True ,
198
- ),
199
- ]
212
+ return [
213
+ keras .callbacks .TensorBoard (
214
+ log_dir = os .path .join (self ._model_dir , core .LOG_DIR ),
215
+ write_steps_per_second = True ,
216
+ ),
217
+ ]
200
218
201
219
def _maybe_get_model_kws (
202
220
self , task : KerasTask , dataset : tf .data .Dataset
@@ -218,7 +236,7 @@ def train(self, task: KerasTask) -> core.Logs:
218
236
dataset ,
219
237
epochs = self ._train_epochs ,
220
238
steps_per_epoch = self ._steps_per_loop ,
221
- callbacks = self ._train_callbacks ,
239
+ callbacks = self .train_callbacks ,
222
240
)
223
241
model .summary (print_fn = logging .info )
224
242
@@ -237,14 +255,14 @@ def evaluate(self, task: KerasTask) -> core.Logs:
237
255
if keras .backend .backend () == "jax" :
238
256
[tb_cbk ] = [
239
257
cbk
240
- for cbk in self ._eval_callbacks
258
+ for cbk in self .eval_callbacks
241
259
if isinstance (cbk , keras_utils .EpochSummaryCallback )
242
260
]
243
261
epoch_start_time = time .time ()
244
262
history = model .evaluate (
245
263
dataset ,
246
264
steps = self ._steps_per_eval ,
247
- callbacks = self ._eval_callbacks ,
265
+ callbacks = self .eval_callbacks ,
248
266
return_dict = True ,
249
267
)
250
268
epoch_dt = time .time () - epoch_start_time
@@ -257,7 +275,7 @@ def evaluate(self, task: KerasTask) -> core.Logs:
257
275
return model .evaluate (
258
276
dataset ,
259
277
steps = self ._steps_per_eval ,
260
- callbacks = self ._eval_callbacks ,
278
+ callbacks = self .eval_callbacks ,
261
279
)
262
280
263
281
def train_and_evaluate (self , task : KerasTask ) -> core .Logs :
@@ -277,7 +295,7 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs:
277
295
steps_per_epoch = self ._steps_per_loop ,
278
296
# Explicitly set to None for deterministic evaluation.
279
297
validation_steps = None ,
280
- callbacks = self ._train_callbacks ,
298
+ callbacks = self .train_callbacks ,
281
299
)
282
300
model .summary (print_fn = logging .info )
283
301
@@ -308,7 +326,10 @@ def timeout_fn() -> bool:
308
326
else :
309
327
steps_msg = "running complete evaluation..."
310
328
329
+ use_legacy_checkpoint_format = self ._legacy_checkpoint_format
330
+
311
331
class _RestoreCallback (keras .callbacks .Callback ):
332
+ """Callback for restoring the model from the latest checkpoint."""
312
333
313
334
def __init__ (
314
335
self ,
@@ -319,9 +340,14 @@ def __init__(
319
340
self ._epoch = epoch
320
341
321
342
def on_test_begin (self , logs : Mapping [str , Any ] | None = None ):
322
- keras_utils .restore_keras_model (
323
- model , self ._checkpoint_dir , step = self ._epoch
324
- )
343
+ if use_legacy_checkpoint_format :
344
+ keras_utils .restore_keras_model (
345
+ model , self ._checkpoint_dir , step = self ._epoch
346
+ )
347
+ else :
348
+ keras_utils .restore_keras_checkpoint (
349
+ self ._checkpoint_dir , model = model , epoch = self ._epoch
350
+ )
325
351
326
352
history = None
327
353
for epoch in ocp .checkpoint_utils .checkpoints_iterator (
@@ -332,7 +358,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
332
358
restore_callback = _RestoreCallback (self ._checkpoint_dir , epoch )
333
359
[tb_cbk ] = [
334
360
cbk
335
- for cbk in self ._eval_callbacks
361
+ for cbk in self .eval_callbacks
336
362
if isinstance (cbk , keras_utils .EpochSummaryCallback )
337
363
]
338
364
try :
@@ -346,7 +372,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
346
372
history = model .evaluate (
347
373
eval_dataset ,
348
374
steps = self ._steps_per_eval ,
349
- callbacks = [restore_callback ] + self ._eval_callbacks ,
375
+ callbacks = [restore_callback ] + self .eval_callbacks ,
350
376
return_dict = True ,
351
377
)
352
378
0 commit comments