From 5af1b573be9ba19b6aa3bdeb13fffd99533c99ed Mon Sep 17 00:00:00 2001
From: Malyala Karthik
Date: Mon, 10 Nov 2025 23:07:17 +0530
Subject: [PATCH 1/4] Add HardTerminateOnNaN callback for immediate training
 termination on NaN loss

---
 keras/src/callbacks/__init__.py               |   1 +
 .../callbacks/hard_terminate_on_nan_test.py   | 132 ++++++++++++++++++
 keras/src/callbacks/terminate_on_nan.py       |  46 ++++++
 3 files changed, 179 insertions(+)
 create mode 100644 keras/src/callbacks/hard_terminate_on_nan_test.py

diff --git a/keras/src/callbacks/__init__.py b/keras/src/callbacks/__init__.py
index 427c4f6da95f..948f40859a02 100644
--- a/keras/src/callbacks/__init__.py
+++ b/keras/src/callbacks/__init__.py
@@ -13,4 +13,5 @@ from keras.src.callbacks.remote_monitor import RemoteMonitor
 from keras.src.callbacks.swap_ema_weights import SwapEMAWeights
 from keras.src.callbacks.tensorboard import TensorBoard
+from keras.src.callbacks.terminate_on_nan import HardTerminateOnNaN
 from keras.src.callbacks.terminate_on_nan import TerminateOnNaN
diff --git a/keras/src/callbacks/hard_terminate_on_nan_test.py b/keras/src/callbacks/hard_terminate_on_nan_test.py
new file mode 100644
index 000000000000..dd21d6e69b50
--- /dev/null
+++ b/keras/src/callbacks/hard_terminate_on_nan_test.py
@@ -0,0 +1,132 @@
+"""Tests for HardTerminateOnNaN callback."""
+
+import os
+import tempfile
+
+import numpy as np
+import pytest
+
+import keras
+from keras.src import layers
+from keras.src import models
+from keras.src import testing
+from keras.src.callbacks import BackupAndRestore
+from keras.src.callbacks import HardTerminateOnNaN
+
+
+class HardTerminateOnNaNTest(testing.TestCase):
+    """Test suite for HardTerminateOnNaN callback."""
+
+    def test_hard_terminate_on_nan_raises_error(self):
+        """Test that HardTerminateOnNaN raises RuntimeError on NaN loss."""
+        # Create a simple model
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        # Create data that will cause NaN (extreme values)
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[np.inf], [np.inf]])  # This should cause NaN
+
+        callback = HardTerminateOnNaN()
+
+        # Training should raise RuntimeError
+        with pytest.raises(RuntimeError, match="NaN or Inf loss encountered"):
+            model.fit(
+                x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
+            )
+
+    def test_hard_terminate_does_not_trigger_on_train_end(self):
+        """Test that on_train_end is NOT called when
+        HardTerminateOnNaN raises.
+        """
+
+        # Create a custom callback to track if on_train_end was called
+        class TrackingCallback(keras.src.callbacks.Callback):
+            def __init__(self):
+                super().__init__()
+                self.train_end_called = False
+
+            def on_train_end(self, logs=None):
+                self.train_end_called = True
+
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        x = np.array([[1.0]])
+        y = np.array([[np.inf]])
+
+        tracking_callback = TrackingCallback()
+        hard_terminate_callback = HardTerminateOnNaN()
+
+        # Should raise RuntimeError
+        with pytest.raises(RuntimeError):
+            model.fit(
+                x,
+                y,
+                epochs=1,
+                callbacks=[tracking_callback, hard_terminate_callback],
+                verbose=0,
+            )
+
+        # on_train_end should NOT have been called
+        self.assertFalse(tracking_callback.train_end_called)
+
+    def test_hard_terminate_preserves_backup(self):
+        """Ensure BackupAndRestore directory is preserved when
+        HardTerminateOnNaN triggers.
+        """
+        with tempfile.TemporaryDirectory() as tmpdir:
+            backup_dir = os.path.join(tmpdir, "backups")
+            os.makedirs(backup_dir, exist_ok=True)
+
+            # Create a fake file in the backup folder
+            fake_file = os.path.join(backup_dir, "checkpoint.txt")
+            open(fake_file, "w").write("dummy checkpoint")
+
+            # Define a simple model
+            model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+            model.compile(optimizer="sgd", loss="mse")
+
+            # Data that causes NaN
+            x_nan = np.array([[1.0]])
+            y_nan = np.array([[np.inf]])
+
+            hard_terminate_callback = HardTerminateOnNaN()
+            backup_callback = BackupAndRestore(backup_dir=backup_dir)
+
+            # Monkeypatch BackupAndRestore to prevent cleanup on train_end
+            backup_callback.on_train_end = lambda logs=None: None
+
+            # Training should raise RuntimeError
+            with pytest.raises(RuntimeError):
+                model.fit(
+                    x_nan,
+                    y_nan,
+                    epochs=1,
+                    callbacks=[backup_callback, hard_terminate_callback],
+                    verbose=0,
+                )
+
+            # Verify backup directory still exists and file inside is untouched
+            assert os.path.exists(backup_dir), (
+                f"Backup dir deleted: {backup_dir}"
+            )
+            assert os.path.exists(fake_file), (
+                "Backup file missing unexpectedly."
+            )
+
+    def test_normal_training_does_not_raise(self):
+        """Test that HardTerminateOnNaN does not raise on normal training."""
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[1.0], [2.0]])
+
+        callback = HardTerminateOnNaN()
+
+        # Should complete without raising
+        history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+        # Should have completed 2 epochs
+        self.assertEqual(len(history.history["loss"]), 2)
diff --git a/keras/src/callbacks/terminate_on_nan.py b/keras/src/callbacks/terminate_on_nan.py
index 55f7e4c06ab8..a28327644139 100644
--- a/keras/src/callbacks/terminate_on_nan.py
+++ b/keras/src/callbacks/terminate_on_nan.py
@@ -18,3 +18,49 @@ def on_batch_end(self, batch, logs=None):
                 f"Batch {batch}: Invalid loss, terminating training"
             )
             self.model.stop_training = True
+
+
+@keras_export("keras.callbacks.HardTerminateOnNaN")
+class HardTerminateOnNaN(Callback):
+    """Callback that terminates training immediately
+    when NaN/Inf loss is detected.
+
+    This callback raises a RuntimeError when a NaN or Inf loss is encountered,
+    which immediately stops training without triggering `on_train_end()` hooks
+    for other callbacks. This is useful when you want to preserve backup states
+    or prevent early stopping from restoring weights after a NaN failure.
+
+    Unlike `TerminateOnNaN`, which gracefully stops training using
+    `model.stop_training = True` and triggers all callback cleanup methods,
+    `HardTerminateOnNaN` crashes the training loop immediately.
+
+    Example:
+
+    ```
+    callback = keras.callbacks.HardTerminateOnNaN()
+    model.fit(x, y, callbacks=[callback])
+    ```
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._supports_tf_logs = True
+
+    def on_batch_end(self, batch, logs=None):
+        """Check for NaN/Inf loss at the end of each batch.
+
+        Args:
+            batch: Integer, index of batch within the current epoch.
+            logs: Dict, contains the return value of `model.train_step()`.
+
+        Raises:
+            RuntimeError: If loss is NaN or Inf.
+        """
+        logs = logs or {}
+        loss = logs.get("loss")
+        if loss is not None:
+            if np.isnan(loss) or np.isinf(loss):
+                raise RuntimeError(
+                    f"NaN or Inf loss encountered at batch {batch}. "
+                    f"Loss value: {loss}. Terminating training immediately."
+                )
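A minimal sketch of how the callback introduced in this patch is meant to interact with `BackupAndRestore` (illustrative only; the backup path and the single-sample data are assumptions, not part of the patch). Because `HardTerminateOnNaN` raises instead of setting `model.stop_training`, the `RuntimeError` propagates out of `fit()` before `BackupAndRestore`'s `on_train_end()` cleanup can run:

```python
import numpy as np
import keras

# Tiny model; Inf targets drive the first batch loss to Inf.
model = keras.Sequential([keras.layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")
x = np.array([[1.0]])
y = np.array([[np.inf]])

callbacks = [
    keras.callbacks.BackupAndRestore(backup_dir="/tmp/nan_backup"),
    keras.callbacks.HardTerminateOnNaN(),
]

try:
    model.fit(x, y, epochs=1, callbacks=callbacks, verbose=0)
except RuntimeError as e:
    # With the graceful TerminateOnNaN, training would end "normally"
    # and BackupAndRestore would delete its checkpoints in on_train_end();
    # the hard failure leaves them on disk for inspection or resumption.
    print(f"Training aborted: {e}")
```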
From 617eb608688c47dd7a7b548b2d5ea6d269a7c4 Mon Sep 17 00:00:00 2001
From: Malyala Karthik
Date: Mon, 10 Nov 2025 23:47:09 +0530
Subject: [PATCH 2/4] Add HardTerminateOnNaN callback for immediate training
 termination on NaN loss

---
 keras/src/callbacks/hard_terminate_on_nan_test.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/keras/src/callbacks/hard_terminate_on_nan_test.py b/keras/src/callbacks/hard_terminate_on_nan_test.py
index dd21d6e69b50..e92618f58173 100644
--- a/keras/src/callbacks/hard_terminate_on_nan_test.py
+++ b/keras/src/callbacks/hard_terminate_on_nan_test.py
@@ -7,6 +7,7 @@
 import pytest
 
 import keras
+from keras.src import backend
 from keras.src import layers
 from keras.src import models
 from keras.src import testing
@@ -14,6 +15,10 @@
 from keras.src.callbacks import HardTerminateOnNaN
 
 
+@pytest.mark.skipif(
+    backend.backend() in ["numpy", "openvino"],
+    reason="HardTerminateOnNaN not supported for NumPy or OpenVINO backend",
+)
 class HardTerminateOnNaNTest(testing.TestCase):
     """Test suite for HardTerminateOnNaN callback."""

From 3c2270abb979569f86c2ed282a3fb847bbc46995 Mon Sep 17 00:00:00 2001
From: Malyala Karthik
Date: Fri, 14 Nov 2025 01:26:11 +0530
Subject: [PATCH 3/4] Add HardTerminateOnNaN callback for immediate training
 termination on NaN loss

---
 keras/src/callbacks/hard_terminate_on_nan_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras/src/callbacks/hard_terminate_on_nan_test.py b/keras/src/callbacks/hard_terminate_on_nan_test.py
index e92618f58173..ccef154ac9e7 100644
--- a/keras/src/callbacks/hard_terminate_on_nan_test.py
+++ b/keras/src/callbacks/hard_terminate_on_nan_test.py
@@ -130,7 +130,7 @@ def test_normal_training_does_not_raise(self):
 
         callback = HardTerminateOnNaN()
 
-        # Should complete without raising
+        # Should complete without raising RuntimeError
         history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
 
         # Should have completed 2 epochs

From 7eb295bff58808e165ee639ea34dd628baa05781 Mon Sep 17 00:00:00 2001
From: Malyala Karthik
Date: Fri, 14 Nov 2025 18:41:28 +0530
Subject: [PATCH 4/4] Add hard option to TerminateOnNaN for immediate
 termination on NaN/Inf loss

---
 keras/src/callbacks/__init__.py                   |   1 -
 .../callbacks/hard_terminate_on_nan_test.py       | 101 +++++++++++++-----
 keras/src/callbacks/terminate_on_nan.py           |  63 +++++------
 3 files changed, 108 insertions(+), 57 deletions(-)

diff --git a/keras/src/callbacks/__init__.py b/keras/src/callbacks/__init__.py
index 948f40859a02..427c4f6da95f 100644
--- a/keras/src/callbacks/__init__.py
+++ b/keras/src/callbacks/__init__.py
@@ -13,5 +13,4 @@ from keras.src.callbacks.remote_monitor import RemoteMonitor
 from keras.src.callbacks.swap_ema_weights import SwapEMAWeights
 from keras.src.callbacks.tensorboard import TensorBoard
-from keras.src.callbacks.terminate_on_nan import HardTerminateOnNaN
 from keras.src.callbacks.terminate_on_nan import TerminateOnNaN
diff --git a/keras/src/callbacks/hard_terminate_on_nan_test.py b/keras/src/callbacks/hard_terminate_on_nan_test.py
index ccef154ac9e7..d2d7a58d90e1 100644
--- a/keras/src/callbacks/hard_terminate_on_nan_test.py
+++ b/keras/src/callbacks/hard_terminate_on_nan_test.py
@@ -1,4 +1,4 @@
-"""Tests for HardTerminateOnNaN callback."""
+"""Tests for TerminateOnNaN callback."""
 
 import os
 import tempfile
@@ -12,27 +12,49 @@
 from keras.src import models
 from keras.src import testing
 from keras.src.callbacks import BackupAndRestore
-from keras.src.callbacks import HardTerminateOnNaN
+from keras.src.callbacks import TerminateOnNaN
 
 
 @pytest.mark.skipif(
     backend.backend() in ["numpy", "openvino"],
-    reason="HardTerminateOnNaN not supported for NumPy or OpenVINO backend",
+    reason="TerminateOnNaN not supported for NumPy or OpenVINO backend",
 )
-class HardTerminateOnNaNTest(testing.TestCase):
-    """Test suite for HardTerminateOnNaN callback."""
+class TerminateOnNaNTest(testing.TestCase):
+    """Test suite for TerminateOnNaN callback."""
 
-    def test_hard_terminate_on_nan_raises_error(self):
-        """Test that HardTerminateOnNaN raises RuntimeError on NaN loss."""
-        # Create a simple model
+    def test_terminate_on_nan_graceful_stop(self):
+        """Test that TerminateOnNaN (default) gracefully stops training."""
         model = models.Sequential([layers.Dense(1, input_shape=(1,))])
         model.compile(optimizer="sgd", loss="mse")
 
-        # Create data that will cause NaN (extreme values)
+        # Create data that will cause NaN
         x = np.array([[1.0], [2.0]])
-        y = np.array([[np.inf], [np.inf]])  # This should cause NaN
+        y = np.array([[np.inf], [np.inf]])
 
-        callback = HardTerminateOnNaN()
+        callback = TerminateOnNaN(hard=False)
+
+        # Training should complete without raising RuntimeError
+        # (graceful stop via stop_training = True)
+        history = model.fit(
+            x, y, epochs=2, batch_size=1, callbacks=[callback], verbose=0
+        )
+
+        # Training should stop early: history records loss once per
+        # epoch, so a graceful stop during epoch 1 leaves a single entry
+        self.assertLess(len(history.history["loss"]), 2)
+
+    def test_terminate_on_nan_hard_raises_error(self):
+        """Test that TerminateOnNaN(hard=True) raises
+        RuntimeError on NaN loss.
+        """
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        # Create data that will cause NaN
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[np.inf], [np.inf]])
+
+        callback = TerminateOnNaN(hard=True)
 
         # Training should raise RuntimeError
         with pytest.raises(RuntimeError, match="NaN or Inf loss encountered"):
@@ -42,7 +64,7 @@ def test_hard_terminate_on_nan_raises_error(self):
 
     def test_hard_terminate_does_not_trigger_on_train_end(self):
         """Test that on_train_end is NOT called when
-        HardTerminateOnNaN raises.
+        TerminateOnNaN(hard=True) raises.
         """
 
         # Create a custom callback to track if on_train_end was called
@@ -61,7 +83,7 @@ def on_train_end(self, logs=None):
         y = np.array([[np.inf]])
 
         tracking_callback = TrackingCallback()
-        hard_terminate_callback = HardTerminateOnNaN()
+        hard_terminate_callback = TerminateOnNaN(hard=True)
 
         # Should raise RuntimeError
         with pytest.raises(RuntimeError):
@@ -78,7 +100,7 @@ def on_train_end(self, logs=None):
 
     def test_hard_terminate_preserves_backup(self):
         """Ensure BackupAndRestore directory is preserved when
-        HardTerminateOnNaN triggers.
+        TerminateOnNaN(hard=True) triggers.
         """
         with tempfile.TemporaryDirectory() as tmpdir:
             backup_dir = os.path.join(tmpdir, "backups")
@@ -86,7 +108,8 @@ def test_hard_terminate_preserves_backup(self):
             os.makedirs(backup_dir, exist_ok=True)
 
             # Create a fake file in the backup folder
             fake_file = os.path.join(backup_dir, "checkpoint.txt")
-            open(fake_file, "w").write("dummy checkpoint")
+            with open(fake_file, "w") as f:
+                f.write("dummy checkpoint")
 
             # Define a simple model
@@ -96,7 +119,7 @@ def test_hard_terminate_preserves_backup(self):
             x_nan = np.array([[1.0]])
             y_nan = np.array([[np.inf]])
 
-            hard_terminate_callback = HardTerminateOnNaN()
+            hard_terminate_callback = TerminateOnNaN(hard=True)
             backup_callback = BackupAndRestore(backup_dir=backup_dir)
 
             # Monkeypatch BackupAndRestore to prevent cleanup on train_end
@@ -113,25 +136,51 @@ def test_hard_terminate_preserves_backup(self):
                 )
 
             # Verify backup directory still exists and file inside is untouched
-            assert os.path.exists(backup_dir), (
-                f"Backup dir deleted: {backup_dir}"
+            self.assertTrue(
+                os.path.exists(backup_dir),
+                f"Backup dir deleted: {backup_dir}",
             )
-            assert os.path.exists(fake_file), (
-                "Backup file missing unexpectedly."
+            self.assertTrue(
+                os.path.exists(fake_file),
+                "Backup file missing unexpectedly.",
             )
 
     def test_normal_training_does_not_raise(self):
-        """Test that HardTerminateOnNaN does not raise on normal training."""
+        """Test that TerminateOnNaN does not raise on normal training."""
         model = models.Sequential([layers.Dense(1, input_shape=(1,))])
         model.compile(optimizer="sgd", loss="mse")
 
         x = np.array([[1.0], [2.0]])
         y = np.array([[1.0], [2.0]])
 
-        callback = HardTerminateOnNaN()
+        # Test both hard=False and hard=True with normal data
+        for hard in [False, True]:
+            callback = TerminateOnNaN(hard=hard)
+
+            # Should complete without raising RuntimeError
+            history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+
+            # Should have completed 2 epochs
+            self.assertEqual(len(history.history["loss"]), 2)
 
-        # Should complete without raising RuntimeError
-        history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)
+    def test_hard_terminate_stops_on_later_batch(self):
+        """Ensure TerminateOnNaN(hard=True) stops training
+        if NaN appears in later batch.
+        """
+        model = models.Sequential([layers.Dense(1, input_shape=(1,))])
+        model.compile(optimizer="sgd", loss="mse")
+
+        # Batch 1: normal loss, Batch 2: NaN loss
+        x = np.array([[1.0], [2.0]])
+        y = np.array([[1.0], [np.inf]])  # NaN/Inf appears only in 2nd batch
+
+        callback = TerminateOnNaN(hard=True)
+
+        with pytest.raises(RuntimeError) as exc:
+            model.fit(
+                x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
+            )
 
-        # Should have completed 2 epochs
-        self.assertEqual(len(history.history["loss"]), 2)
+        # fit() shuffles by default, so the Inf sample may land in
+        # batch 0 or batch 1; accept either index in the error message
+        assert any(f"batch {i}" in str(exc.value) for i in [0, 1])
diff --git a/keras/src/callbacks/terminate_on_nan.py b/keras/src/callbacks/terminate_on_nan.py
index a28327644139..b302560c6632 100644
--- a/keras/src/callbacks/terminate_on_nan.py
+++ b/keras/src/callbacks/terminate_on_nan.py
@@ -7,43 +7,40 @@
 @keras_export("keras.callbacks.TerminateOnNaN")
 class TerminateOnNaN(Callback):
-    """Callback that terminates training when a NaN loss is encountered."""
+    """Callback that terminates training when a NaN loss is encountered.
 
-    def on_batch_end(self, batch, logs=None):
-        logs = logs or {}
-        loss = logs.get("loss")
-        if loss is not None:
-            if np.isnan(loss) or np.isinf(loss):
-                io_utils.print_msg(
-                    f"Batch {batch}: Invalid loss, terminating training"
-                )
-                self.model.stop_training = True
-
-
-@keras_export("keras.callbacks.HardTerminateOnNaN")
-class HardTerminateOnNaN(Callback):
-    """Callback that terminates training immediately
-    when NaN/Inf loss is detected.
-
-    This callback raises a RuntimeError when a NaN or Inf loss is encountered,
-    which immediately stops training without triggering `on_train_end()` hooks
-    for other callbacks. This is useful when you want to preserve backup states
-    or prevent early stopping from restoring weights after a NaN failure.
-
-    Unlike `TerminateOnNaN`, which gracefully stops training using
-    `model.stop_training = True` and triggers all callback cleanup methods,
-    `HardTerminateOnNaN` crashes the training loop immediately.
+    This callback monitors the training loss and stops the run when a NaN
+    or Inf loss value is detected. By default, training is stopped
+    gracefully by setting `model.stop_training = True`, which still
+    triggers all callback cleanup methods, including `on_train_end()`.
+
+    Alternatively, pass `hard=True` to raise a `RuntimeError` immediately
+    when a NaN/Inf loss is detected. This hard termination prevents
+    `on_train_end()` from being called on other callbacks, which is useful
+    for preserving backup states or preventing unintended cleanup when
+    training fails.
+
+    Args:
+        hard: Boolean, defaults to `False`. If `False`, training stops
+            gracefully via `model.stop_training = True`. If `True`, a
+            `RuntimeError` is raised on NaN/Inf loss, bypassing callback
+            cleanup methods.
 
     Example:
 
     ```
-    callback = keras.callbacks.HardTerminateOnNaN()
+    # Graceful termination (default)
+    callback = keras.callbacks.TerminateOnNaN()
+    model.fit(x, y, callbacks=[callback])
+
+    # Hard termination (strict failure)
+    callback = keras.callbacks.TerminateOnNaN(hard=True)
     model.fit(x, y, callbacks=[callback])
     ```
     """
 
-    def __init__(self):
+    def __init__(self, hard: bool = False):
         super().__init__()
+        self.hard = hard
         self._supports_tf_logs = True
 
     def on_batch_end(self, batch, logs=None):
         """Check for NaN/Inf loss at the end of each batch.
@@ -54,13 +51,19 @@ def on_batch_end(self, batch, logs=None):
             logs: Dict, contains the return value of `model.train_step()`.
 
         Raises:
-            RuntimeError: If loss is NaN or Inf.
+            RuntimeError: If loss is NaN/Inf and `hard=True`.
         """
         logs = logs or {}
         loss = logs.get("loss")
         if loss is not None:
             if np.isnan(loss) or np.isinf(loss):
-                raise RuntimeError(
-                    f"NaN or Inf loss encountered at batch {batch}. "
-                    f"Loss value: {loss}. Terminating training immediately."
-                )
+                if self.hard:
+                    raise RuntimeError(
+                        f"NaN or Inf loss encountered at batch {batch}. "
+                        f"Loss value: {loss}. Terminating training immediately."
+                    )
+                else:
+                    io_utils.print_msg(
+                        f"Batch {batch}: Invalid loss, terminating training"
+                    )
+                    self.model.stop_training = True
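With the final API, hard termination is requested via `TerminateOnNaN(hard=True)`. The sketch below shows one recover-and-retry pattern the docstring's rationale suggests; the backup directory, the random data, and the lowered retry learning rate are illustrative assumptions, not part of the patch:

```python
import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1, input_shape=(1,))])
model.compile(optimizer=keras.optimizers.SGD(learning_rate=1e-2), loss="mse")

x = np.random.rand(64, 1).astype("float32")
y = np.random.rand(64, 1).astype("float32")

callbacks = [
    keras.callbacks.BackupAndRestore(backup_dir="/tmp/nan_backup"),
    keras.callbacks.TerminateOnNaN(hard=True),
]

try:
    model.fit(x, y, epochs=5, callbacks=callbacks, verbose=0)
except RuntimeError:
    # The hard failure skipped on_train_end(), so the backup survived.
    # BackupAndRestore restores the last completed epoch on the next
    # fit() call; retry from there with a smaller learning rate.
    model.optimizer.learning_rate = 1e-3
    model.fit(x, y, epochs=5, callbacks=callbacks, verbose=0)
```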