
@spencerirving

Description

Adds optional checkpoint saving functionality to the Trainer class, enabling users to save model state during training without modifying existing behavior.

This fixes issue #48

Changes Made

Added checkpoint parameters to Trainer.__init__() (a sketch of the resulting signature follows this list):

  • checkpoint_dir: Directory for saving checkpoints
  • checkpoint_epochs: When to save (int for a fixed interval, list for specific epochs)
  • checkpoint_name: Base filename for checkpoints
  • save_best: Whether to track and save best model
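
A minimal sketch of how the new constructor parameters could fit together; default values and attribute names here are assumptions for illustration, not copied from the PR:

import os

class Trainer:
    def __init__(self, model, loss_fn, optimizer, epochs,
                 checkpoint_dir=None,           # directory for saving checkpoints
                 checkpoint_epochs=None,        # int for an interval, list for specific epochs
                 checkpoint_name="checkpoint",  # base filename for checkpoint files
                 save_best=False):              # track and save the best model seen so far
        # Sketch only: the existing Trainer arguments are stored as before.
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.epochs = epochs
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_epochs = checkpoint_epochs
        self.checkpoint_name = checkpoint_name
        self.save_best = save_best
        self.best_loss = float("inf")
        if checkpoint_dir is not None:
            os.makedirs(checkpoint_dir, exist_ok=True)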

Added private helper methods:

  • _should_save_checkpoint(): Determines when to save based on the configured checkpoint settings (see the sketch after this list)
  • _save_checkpoint(): Handles checkpoint file creation
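
A rough sketch of the two helpers, assuming an int checkpoint_epochs means "every N epochs" and a list means "exactly these epochs"; the on-disk dictionary layout is an assumption, not taken from the PR:

import os
import torch

class Trainer:
    # (continuing the sketch above)

    def _should_save_checkpoint(self, epoch):
        """Decide whether to save at the end of a (1-based) epoch."""
        if self.checkpoint_dir is None or self.checkpoint_epochs is None:
            return False
        if isinstance(self.checkpoint_epochs, int):
            return epoch % self.checkpoint_epochs == 0  # every N epochs
        return epoch in self.checkpoint_epochs          # only the listed epochs

    def _save_checkpoint(self, epoch, loss, filename):
        """Serialize model/optimizer state plus training metadata to disk."""
        path = os.path.join(self.checkpoint_dir, filename)
        torch.save(
            {
                "epoch": epoch,
                "loss": loss,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            path,
        )
        return path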

Added public method:

  • load_checkpoint(): Restores model and optimizer state from a saved checkpoint (sketched below)
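
One plausible shape for load_checkpoint(), assuming checkpoints use the dictionary layout sketched above:

import torch

class Trainer:
    # (continuing the sketch above)

    def load_checkpoint(self, path, map_location=None):
        """Restore model and optimizer state from a checkpoint file."""
        checkpoint = torch.load(path, map_location=map_location)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Return the stored metadata so callers can resume from the right epoch.
        return checkpoint.get("epoch"), checkpoint.get("loss")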

Integrated the checkpoint logic into the fit() method (sketched below).
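
Roughly, the hook into fit() could look like the following; the _train_epoch() call is a placeholder standing in for the existing training loop, not code from this PR:

class Trainer:
    # (continuing the sketch above)

    def fit(self, train_loader):
        for epoch in range(1, self.epochs + 1):
            epoch_loss = self._train_epoch(train_loader)  # placeholder for the existing epoch loop

            # Periodic checkpoints at the configured epochs.
            if self._should_save_checkpoint(epoch):
                self._save_checkpoint(epoch, epoch_loss,
                                      f"{self.checkpoint_name}_epoch{epoch}.pt")

            # Optionally track and save the best model seen so far.
            if self.save_best and self.checkpoint_dir is not None and epoch_loss < self.best_loss:
                self.best_loss = epoch_loss
                self._save_checkpoint(epoch, epoch_loss,
                                      f"{self.checkpoint_name}_best.pt")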

Usage

Basic Usage - Save Every 10 Epochs

trainer = Trainer(
    model, loss_fn, optimizer,
    epochs=100,
    checkpoint_dir="./checkpoints", checkpoint_epochs=10,
    checkpoint_name="my_model"
)
trainer.fit(train_loader)

Track Best Model

trainer = Trainer(
    model, loss_fn, optimizer,
    epochs=100,
    checkpoint_dir="./checkpoints", checkpoint_epochs=10,
    checkpoint_name="my_model",
    save_best=True  # Saves the best model as my_model_best.pt
)
trainer.fit(train_loader)
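
Resume From a Checkpoint

Assuming the load_checkpoint() sketch above, restoring a saved state before continuing training might look like this (the checkpoint path is illustrative):

trainer = Trainer(
    model, loss_fn, optimizer,
    epochs=100,
    checkpoint_dir="./checkpoints", checkpoint_epochs=10,
    checkpoint_name="my_model"
)
trainer.load_checkpoint("./checkpoints/my_model_epoch10.pt")
trainer.fit(train_loader)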
