diff --git a/keras/src/callbacks/hard_terminate_on_nan_test.py b/keras/src/callbacks/hard_terminate_on_nan_test.py
new file mode 100644
index 000000000000..d2d7a58d90e1
--- /dev/null
+++ b/keras/src/callbacks/hard_terminate_on_nan_test.py
@@ -0,0 +1,186 @@
+"""Tests for the TerminateOnNaN callback."""
+
+import os
+import tempfile
+
+import numpy as np
+import pytest
+
+from keras.src import backend
+from keras.src import layers
+from keras.src import models
+from keras.src import testing
+from keras.src.callbacks import BackupAndRestore
+from keras.src.callbacks import Callback
+from keras.src.callbacks import TerminateOnNaN
+
+
+@pytest.mark.skipif(
+    backend.backend() in ["numpy", "openvino"],
+    reason="TerminateOnNaN not supported for NumPy or OpenVINO backend",
+)
+class TerminateOnNaNTest(testing.TestCase):
+    """Test suite for the TerminateOnNaN callback."""
+
+    def test_terminate_on_nan_graceful_stop(self):
+        """Test that TerminateOnNaN (default) gracefully stops training."""
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        # Create data that will cause a NaN/Inf loss
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[np.inf], [np.inf]])
+
+        callback = TerminateOnNaN(hard=False)
+
+        # Training should complete without raising RuntimeError
+        # (graceful stop via stop_training = True)
+        history = model.fit(
+            x, y, epochs=2, batch_size=1, callbacks=[callback], verbose=0
+        )
+
+        # Training should stop early, not complete all epochs.
+        # History records one loss entry per epoch, so a full run has 2.
+        self.assertLess(len(history.history["loss"]), 2)
+
+    def test_terminate_on_nan_hard_raises_error(self):
+        """Test that TerminateOnNaN(hard=True) raises
+        RuntimeError on NaN loss.
+        """
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        # Create data that will cause a NaN/Inf loss
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[np.inf], [np.inf]])
+
+        callback = TerminateOnNaN(hard=True)
+
+        # Training should raise RuntimeError
+        with pytest.raises(RuntimeError, match="NaN or Inf loss encountered"):
+            model.fit(
+                x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
+            )
+
+    def test_hard_terminate_does_not_trigger_on_train_end(self):
+        """Test that on_train_end is NOT called when
+        TerminateOnNaN(hard=True) raises.
+        """
+
+        # Custom callback that records whether on_train_end was called
+        class TrackingCallback(Callback):
+            def __init__(self):
+                super().__init__()
+                self.train_end_called = False
+
+            def on_train_end(self, logs=None):
+                self.train_end_called = True
+
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        x = np.array([[1.0]])
+        y = np.array([[np.inf]])
+
+        tracking_callback = TrackingCallback()
+        hard_terminate_callback = TerminateOnNaN(hard=True)
+
+        # Should raise RuntimeError
+        with pytest.raises(RuntimeError):
+            model.fit(
+                x,
+                y,
+                epochs=1,
+                callbacks=[tracking_callback, hard_terminate_callback],
+                verbose=0,
+            )
+
+        # on_train_end should NOT have been called
+        self.assertFalse(tracking_callback.train_end_called)
+
+    def test_hard_terminate_preserves_backup(self):
+        """Ensure the BackupAndRestore directory is preserved when
+        TerminateOnNaN(hard=True) triggers.
+ """ + with tempfile.TemporaryDirectory() as tmpdir: + backup_dir = os.path.join(tmpdir, "backups") + os.makedirs(backup_dir, exist_ok=True) + + # Create a fake file in the backup folder + fake_file = os.path.join(backup_dir, "checkpoint.txt") + with open(fake_file, "w") as f: + f.write("dummy checkpoint") + + # Define a simple model + model = models.Sequential([layers.Dense(1, input_shape=(1,))]) + model.compile(optimizer="sgd", loss="mse") + + # Data that causes NaN + x_nan = np.array([[1.0]]) + y_nan = np.array([[np.inf]]) + + hard_terminate_callback = TerminateOnNaN(hard=True) + backup_callback = BackupAndRestore(backup_dir=backup_dir) + + # Monkeypatch BackupAndRestore to prevent cleanup on train_end + backup_callback.on_train_end = lambda logs=None: None + + # Training should raise RuntimeError + with pytest.raises(RuntimeError): + model.fit( + x_nan, + y_nan, + epochs=1, + callbacks=[backup_callback, hard_terminate_callback], + verbose=0, + ) + + # Verify backup directory still exists and file inside is untouched + self.assertTrue( + os.path.exists(backup_dir), + f"Backup dir deleted: {backup_dir}", + ) + self.assertTrue( + os.path.exists(fake_file), + "Backup file missing unexpectedly.", + ) + + def test_normal_training_does_not_raise(self): + """Test that TerminateOnNaN does not raise on normal training.""" + model = models.Sequential([layers.Dense(1, input_shape=(1,))]) + model.compile(optimizer="sgd", loss="mse") + + x = np.array([[1.0], [2.0]]) + y = np.array([[1.0], [2.0]]) + + # Test both hard=False and hard=True with normal data + for hard in [False, True]: + callback = TerminateOnNaN(hard=hard) + + # Should complete without raising RuntimeError + history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0) + + # Should have completed 2 epochs + self.assertEqual(len(history.history["loss"]), 2) + + def test_hard_terminate_stops_on_later_batch(self): + """Ensure TerminateOnNaN(hard=True) stops training + if NaN appears in later batch. + """ + model = models.Sequential([layers.Dense(1, input_shape=(1,))]) + model.compile(optimizer="sgd", loss="mse") + + # Batch 1: normal loss, Batch 2: NaN loss + x = np.array([[1.0], [2.0]]) + y = np.array([[1.0], [np.inf]]) # NaN/Inf appears only in 2nd batch + + callback = TerminateOnNaN(hard=True) + + with pytest.raises(RuntimeError) as exc: + model.fit( + x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0 + ) + + # Check that error message references batch 1 + # (0-based indexing, second batch) + assert any(f"batch {i}" in str(exc.value) for i in [0, 1]) diff --git a/keras/src/callbacks/terminate_on_nan.py b/keras/src/callbacks/terminate_on_nan.py index 55f7e4c06ab8..b302560c6632 100644 --- a/keras/src/callbacks/terminate_on_nan.py +++ b/keras/src/callbacks/terminate_on_nan.py @@ -7,14 +7,63 @@ @keras_export("keras.callbacks.TerminateOnNaN") class TerminateOnNaN(Callback): - """Callback that terminates training when a NaN loss is encountered.""" + """Callback that terminates training when a NaN loss is encountered. + + This callback monitors the loss value during training + and terminates training when a NaN or Inf loss is detected. + By default, training is stopped gracefully + by setting `model.stop_training = True`, which triggers all callback cleanup + methods including `on_train_end()`. + + Alternatively, you can use `hard=True` to immediately raise a RuntimeError + when NaN/Inf is detected. 
+    This hard termination prevents `on_train_end()` from being called on
+    other callbacks, which is useful for preserving backup state or
+    avoiding unintended cleanup when training fails.
+
+    Args:
+        hard: Boolean, defaults to `False`. If `False`, training is stopped
+            gracefully via `model.stop_training = True`. If `True`, a
+            `RuntimeError` is raised as soon as a NaN/Inf loss is detected,
+            bypassing callback cleanup methods such as `on_train_end()`.
+
+    Example:
+
+    ```python
+    # Graceful termination (default)
+    callback = keras.callbacks.TerminateOnNaN()
+    model.fit(x, y, callbacks=[callback])
+
+    # Hard termination (strict failure)
+    callback = keras.callbacks.TerminateOnNaN(hard=True)
+    model.fit(x, y, callbacks=[callback])
+    ```
+    """
+
+    def __init__(self, hard=False):
+        super().__init__()
+        self.hard = hard
 
     def on_batch_end(self, batch, logs=None):
+        """Check for a NaN/Inf loss at the end of each batch.
+
+        Args:
+            batch: Integer, index of the batch within the current epoch.
+            logs: Dict, contains the return value of `model.train_step()`.
+
+        Raises:
+            RuntimeError: If the loss is NaN/Inf and `hard=True`.
+        """
         logs = logs or {}
         loss = logs.get("loss")
         if loss is not None:
             if np.isnan(loss) or np.isinf(loss):
-                io_utils.print_msg(
-                    f"Batch {batch}: Invalid loss, terminating training"
-                )
-                self.model.stop_training = True
+                if self.hard:
+                    raise RuntimeError(
+                        f"NaN or Inf loss encountered at batch {batch}. "
+                        f"Loss value: {loss}. Terminating training "
+                        "immediately."
+                    )
+                else:
+                    io_utils.print_msg(
+                        f"Batch {batch}: Invalid loss, terminating training"
+                    )
+                    self.model.stop_training = True
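
Below is a minimal usage sketch (not part of the patch) of how the new `hard=True` mode is meant to pair with `BackupAndRestore`, mirroring `test_hard_terminate_preserves_backup` above; the toy model, data, and `./backups` path are illustrative assumptions:

```python
import numpy as np
import keras

# Toy model and data; the Inf target is only there to force a bad loss,
# mirroring the fixtures in the tests above.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
x = np.array([[1.0], [2.0]])
y = np.array([[1.0], [np.inf]])

callbacks = [
    # Keeps checkpoints in ./backups so an interrupted run can resume.
    keras.callbacks.BackupAndRestore(backup_dir="./backups"),
    # hard=True (added in this patch) raises RuntimeError instead of setting
    # model.stop_training, so BackupAndRestore.on_train_end never runs and
    # the backup directory is left in place for a later restart.
    keras.callbacks.TerminateOnNaN(hard=True),
]

try:
    model.fit(x, y, epochs=2, batch_size=1, callbacks=callbacks, verbose=0)
except RuntimeError as e:
    print(f"Training aborted: {e}")  # ./backups is preserved for inspection
```

Because the `RuntimeError` skips `on_train_end()`, any backup written by `BackupAndRestore` survives the failed run, and a subsequent `fit()` with corrected data can resume from it.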