diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index 4c312be8b45b..bb9b50b9fb82 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -11,6 +11,7 @@ Complete list of generic handlers :toctree: generated checkpoint.Checkpoint + checkpoint.CheckpointEvents DiskSaver checkpoint.ModelCheckpoint ema_handler.EMAHandler diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py index aa408a478ef0..f8648f783a06 100644 --- a/ignite/handlers/checkpoint.py +++ b/ignite/handlers/checkpoint.py @@ -21,10 +21,21 @@ import ignite.distributed as idist from ignite.base import Serializable -from ignite.engine import Engine, Events +from ignite.engine import Engine, Events, EventEnum from ignite.utils import _tree_apply2, _tree_map -__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"] +__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"] + + +class CheckpointEvents(EventEnum): + """Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint` + + - SAVED_CHECKPOINT : triggered when checkpoint handler has saved objects + + .. versionadded:: 0.5.3 + """ + + SAVED_CHECKPOINT = "saved_checkpoint" class BaseSaveHandler(metaclass=ABCMeta): @@ -264,6 +275,29 @@ class Checkpoint(Serializable): to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2 ) + Respond to checkpoint events: + + .. code-block:: python + + from ignite.handlers import Checkpoint + from ignite.engine import Engine, Events + + checkpoint_handler = Checkpoint( + {'model': model, 'optimizer': optimizer}, + save_dir, + n_saved=2 + ) + + @trainer.on(Checkpoint.SAVED_CHECKPOINT) + def on_checkpoint_saved(engine): + print(f"Checkpoint saved at epoch {engine.state.epoch}") + + trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler) + + Attributes: + SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from + :class:`~ignite.handlers.checkpoint.CheckpointEvents`. + .. versionchanged:: 0.4.3 - Checkpoint can save model with same filename. @@ -274,8 +308,13 @@ class Checkpoint(Serializable): - `score_name` can be used to define `score_function` automatically without providing `score_function`. - `save_handler` automatically saves to disk if path to directory is provided. - `save_on_rank` saves objects on this rank in a distributed configuration. + + .. versionchanged:: 0.5.3 + + - Added ``SAVED_CHECKPOINT`` class attribute. """ + SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT Item = NamedTuple("Item", [("priority", int), ("filename", str)]) _state_dict_all_req_keys = ("_saved",) @@ -400,6 +439,8 @@ def _compare_fn(self, new: Union[int, float]) -> bool: return new > self._saved[0].priority def __call__(self, engine: Engine) -> None: + if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT): + engine.register_events(*CheckpointEvents) global_step = None if self.global_step_transform is not None: global_step = self.global_step_transform(engine, engine.last_event_name) @@ -460,11 +501,11 @@ def __call__(self, engine: Engine) -> None: if self.include_self: # Now that we've updated _saved, we can add our own state_dict. checkpoint["checkpointer"] = self.state_dict() - try: self.save_handler(checkpoint, filename, metadata) except TypeError: self.save_handler(checkpoint, filename) + engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT) def _setup_checkpoint(self) -> Dict[str, Any]: if self.to_save is not None: diff --git a/tests/ignite/handlers/test_checkpoint.py b/tests/ignite/handlers/test_checkpoint.py index 445c84d7205b..762c1733abde 100644 --- a/tests/ignite/handlers/test_checkpoint.py +++ b/tests/ignite/handlers/test_checkpoint.py @@ -1850,6 +1850,39 @@ def test_load_single_object(obj_to_save, dirname): Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp)) +def test_checkpoint_saved_event(): + """Test that SAVED_CHECKPOINT event is fired correctly.""" + save_handler = MagicMock(spec=BaseSaveHandler) + to_save = {"model": DummyModel()} + + checkpointer = Checkpoint(to_save, save_handler=save_handler, n_saved=2) + + trainer = Engine(lambda e, b: None) + trainer.state = State(epoch=0, iteration=0) + + # Track event firing + event_count = 0 + + # First, call the checkpoint handler to trigger automatic event registration + checkpointer(trainer) + + @trainer.on(Checkpoint.SAVED_CHECKPOINT) + def on_checkpoint_saved(engine): + nonlocal event_count + event_count += 1 + + # Verify the first checkpoint didn't trigger our handler (attached after) + assert event_count == 0 + + # Second checkpoint - should fire event and trigger our handler + trainer.state.iteration = 1 + checkpointer(trainer) + assert event_count == 1 + + # Verify save handler was called twice + assert save_handler.call_count == 2 + + @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.parametrize("atomic", [False, True])