Skip to content

Commit a3c1528

Browse files
committed
Use engine.has_registered_events method instead of private attribute
1 parent 02313d2 commit a3c1528

File tree

2 files changed

+9
-17
lines changed

2 files changed

+9
-17
lines changed

ignite/handlers/checkpoint.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,8 @@ def _compare_fn(self, new: Union[int, float]) -> bool:
407407
return new > self._saved[0].priority
408408

409409
def __call__(self, engine: Engine) -> None:
410-
# Register the custom event if not already registered
411-
if not hasattr(engine, "_checkpoint_events_registered"):
410+
if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
412411
engine.register_events(CheckpointEvents.SAVED_CHECKPOINT)
413-
engine._checkpoint_events_registered = True
414412
global_step = None
415413
if self.global_step_transform is not None:
416414
global_step = self.global_step_transform(engine, engine.last_event_name)

tests/ignite/handlers/test_checkpoint.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,7 +1852,6 @@ def test_load_single_object(obj_to_save, dirname):
18521852

18531853
def test_checkpoint_saved_event():
18541854
"""Test that SAVED_CHECKPOINT event is fired correctly."""
1855-
from ignite.handlers.checkpoint import CheckpointEvents
18561855

18571856
save_handler = MagicMock(spec=BaseSaveHandler)
18581857
to_save = {"model": DummyModel()}
@@ -1862,31 +1861,26 @@ def test_checkpoint_saved_event():
18621861
trainer = Engine(lambda e, b: None)
18631862
trainer.state = State(epoch=0, iteration=0)
18641863

1865-
# Register the event first
1866-
trainer.register_events(CheckpointEvents.SAVED_CHECKPOINT)
1867-
18681864
# Track event firing
18691865
event_count = 0
1870-
received_handlers = []
1866+
1867+
# First, call the checkpoint handler to trigger automatic event registration
1868+
checkpointer(trainer)
18711869

18721870
@trainer.on(Checkpoint.SAVED_CHECKPOINT)
18731871
def on_checkpoint_saved(engine):
18741872
nonlocal event_count
18751873
event_count += 1
1876-
received_handlers.append(engine._current_checkpoint_handler)
18771874

1878-
# First checkpoint - should fire event
1879-
checkpointer(trainer)
1880-
assert event_count == 1
1881-
assert received_handlers[0] is checkpointer
1875+
# Verify the first checkpoint didn't trigger our handler (attached after)
1876+
assert event_count == 0
18821877

1883-
# Second checkpoint - should fire event
1878+
# Second checkpoint - should fire event and trigger our handler
18841879
trainer.state.iteration = 1
18851880
checkpointer(trainer)
1886-
assert event_count == 2
1887-
assert received_handlers[1] is checkpointer
1881+
assert event_count == 1
18881882

1889-
# Verify save handler was called
1883+
# Verify save handler was called twice
18901884
assert save_handler.call_count == 2
18911885

18921886

0 commit comments

Comments
 (0)