From bb49be5d11721c8256d85b65439fcf2e5858811b Mon Sep 17 00:00:00 2001 From: ATATC Date: Wed, 12 Nov 2025 20:48:34 -0500 Subject: [PATCH 1/4] Added `WithNetwork` abstraction in `layer.py` to streamline network creation and model loading. Updated imports in `__init__.py`. (#107) --- mipcandy/__init__.py | 3 ++- mipcandy/layer.py | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/mipcandy/__init__.py b/mipcandy/__init__.py index 80fde57..e39bfe2 100644 --- a/mipcandy/__init__.py +++ b/mipcandy/__init__.py @@ -5,7 +5,8 @@ from mipcandy.evaluation import EvalCase, EvalResult, Evaluator from mipcandy.frontend import * from mipcandy.inference import parse_predictant, Predictor -from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, WithPaddingModule, auto_device +from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \ + WithNetwork from mipcandy.metrics import do_reduction, dice_similarity_coefficient_binary, \ dice_similarity_coefficient_multiclass, soft_dice_coefficient, accuracy_binary, accuracy_multiclass, \ precision_binary, precision_multiclass, recall_binary, recall_multiclass, iou_binary, iou_multiclass diff --git a/mipcandy/layer.py b/mipcandy/layer.py index 4baf9f9..6393ea8 100644 --- a/mipcandy/layer.py +++ b/mipcandy/layer.py @@ -1,4 +1,5 @@ -from typing import Any, Generator, Self +from abc import ABCMeta, abstractmethod +from typing import Any, Generator, Self, Mapping import torch from torch import nn @@ -93,3 +94,22 @@ def get_padding_module(self) -> nn.Module | None: def get_restoring_module(self) -> nn.Module | None: self._lazy_load_padding_module() return self._restoring_module + + +class WithNetwork(HasDevice, metaclass=ABCMeta): + def __init__(self, device: Device) -> None: + super().__init__(device) + + @abstractmethod + def build_network(self, example_shape: tuple[int, ...]) -> nn.Module: + raise NotImplementedError + + def build_network_from_checkpoint(self, example_shape: tuple[int, ...], checkpoint: Mapping[str, Any]) -> nn.Module: + network = self.build_network(example_shape) + network.load_state_dict(checkpoint) + return network + + def load_model(self, example_shape: tuple[int, ...], *, checkpoint: Mapping[str, Any] | None = None) -> nn.Module: + if checkpoint: + return self.build_network_from_checkpoint(example_shape, checkpoint).to(self._device) + return self.build_network(example_shape).to(self._device) From 8620d9dc431fc362f91fa05da2512f0be21c9685 Mon Sep 17 00:00:00 2001 From: ATATC Date: Wed, 12 Nov 2025 20:49:17 -0500 Subject: [PATCH 2/4] Integrated `WithNetwork` into `Predictor` and `Trainer` for unified model handling and removed redundant `build_network()` abstraction. (#107) --- mipcandy/inference.py | 23 +++++++++++------------ mipcandy/training.py | 13 +++++-------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/mipcandy/inference.py b/mipcandy/inference.py index 07c447e..882fd30 100644 --- a/mipcandy/inference.py +++ b/mipcandy/inference.py @@ -1,15 +1,15 @@ -from abc import ABCMeta, abstractmethod +from abc import ABCMeta from math import log, ceil from os import PathLike, listdir from os.path import isdir, basename, exists -from typing import Sequence, Mapping, Any, override +from typing import Sequence, override import torch from torch import nn from mipcandy.common import Pad2d, Pad3d, Restore2d, Restore3d from mipcandy.data import save_image, Loader, UnsupervisedDataset, PathBasedUnsupervisedDataset -from mipcandy.layer import WithPaddingModule +from mipcandy.layer import WithPaddingModule, WithNetwork from mipcandy.sliding_window import SlidingWindow from mipcandy.types import SupportedPredictant, Device @@ -39,24 +39,23 @@ def parse_predictant(x: SupportedPredictant, loader: type[Loader], *, as_label: return r, filenames -class Predictor(WithPaddingModule, metaclass=ABCMeta): - def __init__(self, experiment_folder: str | PathLike[str], *, checkpoint: str = "checkpoint_best.pth", - device: Device = "cpu") -> None: - super().__init__(device) +class Predictor(WithPaddingModule, WithNetwork, metaclass=ABCMeta): + def __init__(self, experiment_folder: str | PathLike[str], example_shape: tuple[int, ...], *, + checkpoint: str = "checkpoint_best.pth", device: Device = "cpu") -> None: + WithPaddingModule.__init__(self, device) + WithNetwork.__init__(self, device) self._experiment_folder: str = experiment_folder + self._example_shape: tuple[int, ...] = example_shape self._checkpoint: str = checkpoint self._model: nn.Module | None = None def lazy_load_model(self) -> None: if self._model: return - self._model = self.build_network(torch.load(f"{self._experiment_folder}/{self._checkpoint}")).to(self._device) + self._model = self.load_model(self._example_shape, + checkpoint=torch.load(f"{self._experiment_folder}/{self._checkpoint}")) self._model.eval() - @abstractmethod - def build_network(self, checkpoint: Mapping[str, Any]) -> nn.Module: - raise NotImplementedError - def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Tensor: self.lazy_load_model() image = image.to(self._device) diff --git a/mipcandy/training.py b/mipcandy/training.py index f285868..590c858 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -23,7 +23,7 @@ from mipcandy.common import Pad2d, Pad3d, quotient_regression, quotient_derivative, quotient_bounds from mipcandy.config import load_settings, load_secrets from mipcandy.frontend import Frontend -from mipcandy.layer import WithPaddingModule +from mipcandy.layer import WithPaddingModule, WithNetwork from mipcandy.sanity_check import sanity_check from mipcandy.sliding_window import SWMetadata, SlidingWindow from mipcandy.types import Params, Setting @@ -57,11 +57,12 @@ class TrainerTracker(object): worst_case: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None -class Trainer(WithPaddingModule, metaclass=ABCMeta): +class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta): def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, device: torch.device | str = "cpu", console: Console = Console()) -> None: - super().__init__(device) + WithPaddingModule.__init__(self, device) + WithNetwork.__init__(self, device) self._trainer_folder: str = trainer_folder self._trainer_variant: str = self.__class__.__name__ self._experiment_id: str = "tbd" @@ -262,10 +263,6 @@ def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = N # Builder interfaces - @abstractmethod - def build_network(self, example_shape: tuple[int, ...]) -> nn.Module: - raise NotImplementedError - @abstractmethod def build_optimizer(self, params: Params) -> optim.Optimizer: raise NotImplementedError @@ -279,7 +276,7 @@ def build_criterion(self) -> nn.Module: raise NotImplementedError def build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox: - model = self.build_network(example_shape).to(self._device) + model = self.load_model(example_shape) optimizer = self.build_optimizer(model.parameters()) scheduler = self.build_scheduler(optimizer, num_epochs) criterion = self.build_criterion().to(self._device) From 6114e67b280f5832a1445aae5b9478d4b77b9c5a Mon Sep 17 00:00:00 2001 From: ATATC Date: Wed, 12 Nov 2025 22:00:29 -0500 Subject: [PATCH 3/4] Added recovery handling methods in `Trainer` for saving/loading checkpoints, metrics, and training arguments. (#107) --- mipcandy/training.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index 590c858..563529b 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -1,7 +1,8 @@ from abc import ABCMeta, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, asdict from datetime import datetime from hashlib import md5 +from json import load, dump from os import PathLike, urandom, makedirs, environ from os.path import exists from random import seed as random_seed, randint @@ -13,7 +14,7 @@ import numpy as np import torch from matplotlib import pyplot as plt -from pandas import DataFrame +from pandas import DataFrame, read_csv from rich.console import Console from rich.progress import Progress, SpinnerColumn from rich.table import Table @@ -75,6 +76,30 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t self._lock: Lock = Lock() self._tracker: TrainerTracker = TrainerTracker() + # Recovery methods (PR #108 at https://github.com/ProjectNeura/MIPCandy/pull/108) + + def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker, + **training_arguments) -> None: + torch.save(toolbox.optimizer, f"{self.experiment_folder()}/optimizer.pth") + torch.save(toolbox.scheduler, f"{self.experiment_folder()}/scheduler.pth") + torch.save(toolbox.criterion, f"{self.experiment_folder()}/criterion.pth") + with open(f"{self.experiment_folder()}/recovery_orbs.json", "w") as f: + dump({"arguments": training_arguments, "tracker": asdict(tracker)}, f) + + def load_recovery_orbs(self) -> dict[str, Setting]: + with open(f"{self.experiment_folder()}/recovery_orbs.json") as f: + return load(f) + + def load_tracker(self) -> TrainerTracker: + return TrainerTracker(**self.load_recovery_orbs()["tracker"]) + + def load_training_arguments(self) -> dict[str, Setting]: + return self.filter_train_params(**self.load_recovery_orbs()["arguments"]) + + def load_metrics(self) -> dict[str, list[float]]: + df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch") + return {column: df[column].astype(float).tolist() for column in df.columns} + # Getters def trainer_folder(self) -> str: From 245dad4c993e43277bd46b9b921d20cc48dfef4a Mon Sep 17 00:00:00 2001 From: ATATC Date: Sun, 30 Nov 2025 13:00:52 -0500 Subject: [PATCH 4/4] Enhanced training recovery system in `Trainer` by adding `recover_from()` and `continue_training()` methods, tracking training arguments, and updating epoch handling for checkpoint recovery (#107). --- mipcandy/training.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mipcandy/training.py b/mipcandy/training.py index 563529b..d6b64fe 100644 --- a/mipcandy/training.py +++ b/mipcandy/training.py @@ -9,7 +9,7 @@ from shutil import copy from threading import Lock from time import time -from typing import Sequence, override, Callable +from typing import Sequence, override, Callable, Self import numpy as np import torch @@ -100,6 +100,15 @@ def load_metrics(self) -> dict[str, list[float]]: df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch") return {column: df[column].astype(float).tolist() for column in df.columns} + def recover_from(self, experiment_id: str) -> Self: + self._experiment_id = experiment_id + self._metrics = self.load_metrics() + self._tracker = self.load_tracker() + return self + + def continue_training(self, num_epochs: int) -> None: + self.train(num_epochs, **self.load_training_arguments()) + # Getters def trainer_folder(self) -> str: @@ -345,6 +354,7 @@ def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None: def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, ema: bool = True, seed: int | None = None, early_stop_tolerance: int = 5, val_score_prediction: bool = True, val_score_prediction_degree: int = 5, save_preview: bool = True, preview_quality: float = .75) -> None: + training_arguments = locals() self.init_experiment() if note: self.log(f"Note: {note}") @@ -371,7 +381,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em sanity_check_result.num_macs, sanity_check_result.num_params, num_epochs, early_stop_tolerance) try: - for epoch in range(1, num_epochs + 1): + for epoch in range(self._tracker.epoch, self._tracker.epoch + num_epochs): if early_stop_tolerance == -1: epoch -= 1 self.log(f"Early stopping triggered because the validation score has not improved for { @@ -422,6 +432,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em self.save_metrics() self.save_progress() self.save_metric_curves() + self.save_everything_for_recovery(toolbox, self._tracker, **training_arguments) self._frontend.on_experiment_updated(self._experiment_id, epoch, self._metrics, early_stop_tolerance) except Exception as e: self.log("Training interrupted")