@@ -1852,7 +1852,6 @@ def test_load_single_object(obj_to_save, dirname):
1852
1852
1853
1853
def test_checkpoint_saved_event ():
1854
1854
"""Test that SAVED_CHECKPOINT event is fired correctly."""
1855
- from ignite .handlers .checkpoint import CheckpointEvents
1856
1855
1857
1856
save_handler = MagicMock (spec = BaseSaveHandler )
1858
1857
to_save = {"model" : DummyModel ()}
@@ -1862,31 +1861,26 @@ def test_checkpoint_saved_event():
1862
1861
trainer = Engine (lambda e , b : None )
1863
1862
trainer .state = State (epoch = 0 , iteration = 0 )
1864
1863
1865
- # Register the event first
1866
- trainer .register_events (CheckpointEvents .SAVED_CHECKPOINT )
1867
-
1868
1864
# Track event firing
1869
1865
event_count = 0
1870
- received_handlers = []
1866
+
1867
+ # First, call the checkpoint handler to trigger automatic event registration
1868
+ checkpointer (trainer )
1871
1869
1872
1870
@trainer .on (Checkpoint .SAVED_CHECKPOINT )
1873
1871
def on_checkpoint_saved (engine ):
1874
1872
nonlocal event_count
1875
1873
event_count += 1
1876
- received_handlers .append (engine ._current_checkpoint_handler )
1877
1874
1878
- # First checkpoint - should fire event
1879
- checkpointer (trainer )
1880
- assert event_count == 1
1881
- assert received_handlers [0 ] is checkpointer
1875
+ # Verify the first checkpoint didn't trigger our handler (attached after)
1876
+ assert event_count == 0
1882
1877
1883
- # Second checkpoint - should fire event
1878
+ # Second checkpoint - should fire event and trigger our handler
1884
1879
trainer .state .iteration = 1
1885
1880
checkpointer (trainer )
1886
- assert event_count == 2
1887
- assert received_handlers [1 ] is checkpointer
1881
+ assert event_count == 1
1888
1882
1889
- # Verify save handler was called
1883
+ # Verify save handler was called twice
1890
1884
assert save_handler .call_count == 2
1891
1885
1892
1886
0 commit comments