Skip to content

Commit e9ce7f9

Browse files
authored
Support loading a checkpoint in Trainer (#108)
* Added `WithNetwork` abstraction in `layer.py` to streamline network creation and model loading. Updated imports in `__init__.py`. (#107)
* Integrated `WithNetwork` into `Predictor` and `Trainer` for unified model handling and removed redundant `build_network()` abstraction. (#107)
* Added recovery handling methods in `Trainer` for saving/loading checkpoints, metrics, and training arguments. (#107)
* Enhanced training recovery system in `Trainer` by adding `recover_from()` and `continue_training()` methods, tracking training arguments, and updating epoch handling for checkpoint recovery. (#107)
1 parent 3895058 commit e9ce7f9

File tree

4 files changed

+79
-26
lines changed

4 files changed

+79
-26
lines changed

mipcandy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from mipcandy.evaluation import EvalCase, EvalResult, Evaluator
66
from mipcandy.frontend import *
77
from mipcandy.inference import parse_predictant, Predictor
8-
from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, WithPaddingModule, auto_device
8+
from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \
9+
WithNetwork
910
from mipcandy.metrics import do_reduction, dice_similarity_coefficient_binary, \
1011
dice_similarity_coefficient_multiclass, soft_dice_coefficient, accuracy_binary, accuracy_multiclass, \
1112
precision_binary, precision_multiclass, recall_binary, recall_multiclass, iou_binary, iou_multiclass

mipcandy/inference.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from abc import ABCMeta, abstractmethod
1+
from abc import ABCMeta
22
from math import log, ceil
33
from os import PathLike, listdir
44
from os.path import isdir, basename, exists
5-
from typing import Sequence, Mapping, Any, override
5+
from typing import Sequence, override
66

77
import torch
88
from torch import nn
99

1010
from mipcandy.common import Pad2d, Pad3d, Restore2d, Restore3d
1111
from mipcandy.data import save_image, Loader, UnsupervisedDataset, PathBasedUnsupervisedDataset
12-
from mipcandy.layer import WithPaddingModule
12+
from mipcandy.layer import WithPaddingModule, WithNetwork
1313
from mipcandy.sliding_window import SlidingWindow
1414
from mipcandy.types import SupportedPredictant, Device
1515

@@ -39,24 +39,23 @@ def parse_predictant(x: SupportedPredictant, loader: type[Loader], *, as_label:
3939
return r, filenames
4040

4141

42-
class Predictor(WithPaddingModule, metaclass=ABCMeta):
43-
def __init__(self, experiment_folder: str | PathLike[str], *, checkpoint: str = "checkpoint_best.pth",
44-
device: Device = "cpu") -> None:
45-
super().__init__(device)
42+
class Predictor(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
43+
def __init__(self, experiment_folder: str | PathLike[str], example_shape: tuple[int, ...], *,
44+
checkpoint: str = "checkpoint_best.pth", device: Device = "cpu") -> None:
45+
WithPaddingModule.__init__(self, device)
46+
WithNetwork.__init__(self, device)
4647
self._experiment_folder: str = experiment_folder
48+
self._example_shape: tuple[int, ...] = example_shape
4749
self._checkpoint: str = checkpoint
4850
self._model: nn.Module | None = None
4951

5052
def lazy_load_model(self) -> None:
5153
if self._model:
5254
return
53-
self._model = self.build_network(torch.load(f"{self._experiment_folder}/{self._checkpoint}")).to(self._device)
55+
self._model = self.load_model(self._example_shape,
56+
checkpoint=torch.load(f"{self._experiment_folder}/{self._checkpoint}"))
5457
self._model.eval()
5558

56-
@abstractmethod
57-
def build_network(self, checkpoint: Mapping[str, Any]) -> nn.Module:
58-
raise NotImplementedError
59-
6059
def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Tensor:
6160
self.lazy_load_model()
6261
image = image.to(self._device)

mipcandy/layer.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Generator, Self
1+
from abc import ABCMeta, abstractmethod
2+
from typing import Any, Generator, Self, Mapping
23

34
import torch
45
from torch import nn
@@ -93,3 +94,22 @@ def get_padding_module(self) -> nn.Module | None:
9394
def get_restoring_module(self) -> nn.Module | None:
9495
self._lazy_load_padding_module()
9596
return self._restoring_module
97+
98+
99+
class WithNetwork(HasDevice, metaclass=ABCMeta):
100+
def __init__(self, device: Device) -> None:
101+
super().__init__(device)
102+
103+
@abstractmethod
104+
def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
105+
raise NotImplementedError
106+
107+
def build_network_from_checkpoint(self, example_shape: tuple[int, ...], checkpoint: Mapping[str, Any]) -> nn.Module:
108+
network = self.build_network(example_shape)
109+
network.load_state_dict(checkpoint)
110+
return network
111+
112+
def load_model(self, example_shape: tuple[int, ...], *, checkpoint: Mapping[str, Any] | None = None) -> nn.Module:
113+
if checkpoint:
114+
return self.build_network_from_checkpoint(example_shape, checkpoint).to(self._device)
115+
return self.build_network(example_shape).to(self._device)

mipcandy/training.py

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
from abc import ABCMeta, abstractmethod
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, asdict
33
from datetime import datetime
44
from hashlib import md5
5+
from json import load, dump
56
from os import PathLike, urandom, makedirs, environ
67
from os.path import exists
78
from random import seed as random_seed, randint
89
from shutil import copy
910
from threading import Lock
1011
from time import time
11-
from typing import Sequence, override, Callable
12+
from typing import Sequence, override, Callable, Self
1213

1314
import numpy as np
1415
import torch
1516
from matplotlib import pyplot as plt
16-
from pandas import DataFrame
17+
from pandas import DataFrame, read_csv
1718
from rich.console import Console
1819
from rich.progress import Progress, SpinnerColumn
1920
from rich.table import Table
@@ -23,7 +24,7 @@
2324
from mipcandy.common import Pad2d, Pad3d, quotient_regression, quotient_derivative, quotient_bounds
2425
from mipcandy.config import load_settings, load_secrets
2526
from mipcandy.frontend import Frontend
26-
from mipcandy.layer import WithPaddingModule
27+
from mipcandy.layer import WithPaddingModule, WithNetwork
2728
from mipcandy.sanity_check import sanity_check
2829
from mipcandy.sliding_window import SWMetadata, SlidingWindow
2930
from mipcandy.types import Params, Setting
@@ -57,11 +58,12 @@ class TrainerTracker(object):
5758
worst_case: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
5859

5960

60-
class Trainer(WithPaddingModule, metaclass=ABCMeta):
61+
class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
6162
def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
6263
validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *,
6364
device: torch.device | str = "cpu", console: Console = Console()) -> None:
64-
super().__init__(device)
65+
WithPaddingModule.__init__(self, device)
66+
WithNetwork.__init__(self, device)
6567
self._trainer_folder: str = trainer_folder
6668
self._trainer_variant: str = self.__class__.__name__
6769
self._experiment_id: str = "tbd"
@@ -74,6 +76,39 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t
7476
self._lock: Lock = Lock()
7577
self._tracker: TrainerTracker = TrainerTracker()
7678

79+
# Recovery methods (PR #108 at https://github.com/ProjectNeura/MIPCandy/pull/108)
80+
81+
def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker,
82+
**training_arguments) -> None:
83+
torch.save(toolbox.optimizer, f"{self.experiment_folder()}/optimizer.pth")
84+
torch.save(toolbox.scheduler, f"{self.experiment_folder()}/scheduler.pth")
85+
torch.save(toolbox.criterion, f"{self.experiment_folder()}/criterion.pth")
86+
with open(f"{self.experiment_folder()}/recovery_orbs.json", "w") as f:
87+
dump({"arguments": training_arguments, "tracker": asdict(tracker)}, f)
88+
89+
def load_recovery_orbs(self) -> dict[str, Setting]:
90+
with open(f"{self.experiment_folder()}/recovery_orbs.json") as f:
91+
return load(f)
92+
93+
def load_tracker(self) -> TrainerTracker:
94+
return TrainerTracker(**self.load_recovery_orbs()["tracker"])
95+
96+
def load_training_arguments(self) -> dict[str, Setting]:
97+
return self.filter_train_params(**self.load_recovery_orbs()["arguments"])
98+
99+
def load_metrics(self) -> dict[str, list[float]]:
100+
df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch")
101+
return {column: df[column].astype(float).tolist() for column in df.columns}
102+
103+
def recover_from(self, experiment_id: str) -> Self:
104+
self._experiment_id = experiment_id
105+
self._metrics = self.load_metrics()
106+
self._tracker = self.load_tracker()
107+
return self
108+
109+
def continue_training(self, num_epochs: int) -> None:
110+
self.train(num_epochs, **self.load_training_arguments())
111+
77112
# Getters
78113

79114
def trainer_folder(self) -> str:
@@ -262,10 +297,6 @@ def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = N
262297

263298
# Builder interfaces
264299

265-
@abstractmethod
266-
def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
267-
raise NotImplementedError
268-
269300
@abstractmethod
270301
def build_optimizer(self, params: Params) -> optim.Optimizer:
271302
raise NotImplementedError
@@ -279,7 +310,7 @@ def build_criterion(self) -> nn.Module:
279310
raise NotImplementedError
280311

281312
def build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox:
282-
model = self.build_network(example_shape).to(self._device)
313+
model = self.load_model(example_shape)
283314
optimizer = self.build_optimizer(model.parameters())
284315
scheduler = self.build_scheduler(optimizer, num_epochs)
285316
criterion = self.build_criterion().to(self._device)
@@ -323,6 +354,7 @@ def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None:
323354
def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, ema: bool = True,
324355
seed: int | None = None, early_stop_tolerance: int = 5, val_score_prediction: bool = True,
325356
val_score_prediction_degree: int = 5, save_preview: bool = True, preview_quality: float = .75) -> None:
357+
training_arguments = locals()
326358
self.init_experiment()
327359
if note:
328360
self.log(f"Note: {note}")
@@ -349,7 +381,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em
349381
sanity_check_result.num_macs, sanity_check_result.num_params, num_epochs,
350382
early_stop_tolerance)
351383
try:
352-
for epoch in range(1, num_epochs + 1):
384+
for epoch in range(self._tracker.epoch, self._tracker.epoch + num_epochs):
353385
if early_stop_tolerance == -1:
354386
epoch -= 1
355387
self.log(f"Early stopping triggered because the validation score has not improved for {
@@ -400,6 +432,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em
400432
self.save_metrics()
401433
self.save_progress()
402434
self.save_metric_curves()
435+
self.save_everything_for_recovery(toolbox, self._tracker, **training_arguments)
403436
self._frontend.on_experiment_updated(self._experiment_id, epoch, self._metrics, early_stop_tolerance)
404437
except Exception as e:
405438
self.log("Training interrupted")

0 commit comments

Comments (0)