Skip to content
1 change: 1 addition & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Complete list of generic handlers
:toctree: generated

checkpoint.Checkpoint
checkpoint.CheckpointEvents
DiskSaver
checkpoint.ModelCheckpoint
ema_handler.EMAHandler
Expand Down
47 changes: 44 additions & 3 deletions ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,21 @@

import ignite.distributed as idist
from ignite.base import Serializable
from ignite.engine import Engine, Events
from ignite.engine import Engine, Events, EventEnum
from ignite.utils import _tree_apply2, _tree_map

__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler"]
__all__ = ["Checkpoint", "DiskSaver", "ModelCheckpoint", "BaseSaveHandler", "CheckpointEvents"]


class CheckpointEvents(EventEnum):
    """Events fired by :class:`~ignite.handlers.checkpoint.Checkpoint`.

    - SAVED_CHECKPOINT : triggered when the checkpoint handler has saved objects.

    The checkpoint handler registers these events on the engine it is called with,
    if they are not registered yet, and fires ``SAVED_CHECKPOINT`` on that engine
    once its save handler has been invoked.

    .. versionadded:: 0.5.3
    """

    # Fired on the engine after the checkpoint's save handler has been called.
    SAVED_CHECKPOINT = "saved_checkpoint"


class BaseSaveHandler(metaclass=ABCMeta):
Expand Down Expand Up @@ -264,6 +275,29 @@ class Checkpoint(Serializable):
to_save, save_handler=DiskSaver('/tmp/models', create_dir=True, **kwargs), n_saved=2
)

Respond to checkpoint events:

.. code-block:: python

from ignite.handlers import Checkpoint
from ignite.engine import Engine, Events

checkpoint_handler = Checkpoint(
{'model': model, 'optimizer': optimizer},
save_dir,
n_saved=2
)

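# SAVED_CHECKPOINT is registered on the engine only when the checkpoint handler first runs,
# so register it explicitly before attaching a handler to it.
trainer.register_events(Checkpoint.SAVED_CHECKPOINT)

@trainer.on(Checkpoint.SAVED_CHECKPOINT)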
def on_checkpoint_saved(engine):
print(f"Checkpoint saved at epoch {engine.state.epoch}")

trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)

Attributes:
SAVED_CHECKPOINT: Alias of ``SAVED_CHECKPOINT`` from
:class:`~ignite.handlers.checkpoint.CheckpointEvents`.

.. versionchanged:: 0.4.3

- Checkpoint can save model with same filename.
Expand All @@ -274,8 +308,13 @@ class Checkpoint(Serializable):
- `score_name` can be used to define `score_function` automatically without providing `score_function`.
- `save_handler` automatically saves to disk if path to directory is provided.
- `save_on_rank` saves objects on this rank in a distributed configuration.

.. versionchanged:: 0.5.3

- Added ``SAVED_CHECKPOINT`` class attribute.
"""

SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
_state_dict_all_req_keys = ("_saved",)

Expand Down Expand Up @@ -400,6 +439,8 @@ def _compare_fn(self, new: Union[int, float]) -> bool:
return new > self._saved[0].priority

def __call__(self, engine: Engine) -> None:
if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
engine.register_events(*CheckpointEvents)
global_step = None
if self.global_step_transform is not None:
global_step = self.global_step_transform(engine, engine.last_event_name)
Expand Down Expand Up @@ -460,11 +501,11 @@ def __call__(self, engine: Engine) -> None:
if self.include_self:
# Now that we've updated _saved, we can add our own state_dict.
checkpoint["checkpointer"] = self.state_dict()

try:
self.save_handler(checkpoint, filename, metadata)
except TypeError:
self.save_handler(checkpoint, filename)
engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)

def _setup_checkpoint(self) -> Dict[str, Any]:
if self.to_save is not None:
Expand Down
33 changes: 33 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,39 @@ def test_load_single_object(obj_to_save, dirname):
Checkpoint.load_objects(to_load=to_save, checkpoint=str(checkpoint_fp))


def test_checkpoint_saved_event():
    """Check that ``Checkpoint`` fires ``SAVED_CHECKPOINT`` after saving.

    The event is registered on the engine by the checkpoint handler itself, so a
    listener can only be attached once the handler has run at least once.
    """
    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    save_handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint({"model": DummyModel()}, save_handler=save_handler, n_saved=2)

    # One entry is appended per SAVED_CHECKPOINT notification received.
    notifications = []

    def on_checkpoint_saved(engine):
        notifications.append(engine)

    # Initial save: registers the checkpoint events on the engine as a side effect.
    checkpointer(trainer)

    trainer.add_event_handler(Checkpoint.SAVED_CHECKPOINT, on_checkpoint_saved)

    # The listener was attached after the first save, so nothing has reached it yet.
    assert len(notifications) == 0

    # A further save must notify the listener exactly once.
    trainer.state.iteration = 1
    checkpointer(trainer)
    assert len(notifications) == 1

    # Both saves went through the save handler.
    assert save_handler.call_count == 2


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.parametrize("atomic", [False, True])
Expand Down
Loading