Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mipcandy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from mipcandy.evaluation import EvalCase, EvalResult, Evaluator
from mipcandy.frontend import *
from mipcandy.inference import parse_predictant, Predictor
from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, WithPaddingModule, auto_device
from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \
WithNetwork
from mipcandy.metrics import do_reduction, dice_similarity_coefficient_binary, \
dice_similarity_coefficient_multiclass, soft_dice_coefficient, accuracy_binary, accuracy_multiclass, \
precision_binary, precision_multiclass, recall_binary, recall_multiclass, iou_binary, iou_multiclass
Expand Down
23 changes: 11 additions & 12 deletions mipcandy/inference.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from math import log, ceil
from os import PathLike, listdir
from os.path import isdir, basename, exists
from typing import Sequence, Mapping, Any, override
from typing import Sequence, override

import torch
from torch import nn

from mipcandy.common import Pad2d, Pad3d, Restore2d, Restore3d
from mipcandy.data import save_image, Loader, UnsupervisedDataset, PathBasedUnsupervisedDataset
from mipcandy.layer import WithPaddingModule
from mipcandy.layer import WithPaddingModule, WithNetwork
from mipcandy.sliding_window import SlidingWindow
from mipcandy.types import SupportedPredictant, Device

Expand Down Expand Up @@ -39,24 +39,23 @@ def parse_predictant(x: SupportedPredictant, loader: type[Loader], *, as_label:
return r, filenames


class Predictor(WithPaddingModule, metaclass=ABCMeta):
def __init__(self, experiment_folder: str | PathLike[str], *, checkpoint: str = "checkpoint_best.pth",
device: Device = "cpu") -> None:
super().__init__(device)
class Predictor(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
def __init__(self, experiment_folder: str | PathLike[str], example_shape: tuple[int, ...], *,
checkpoint: str = "checkpoint_best.pth", device: Device = "cpu") -> None:
WithPaddingModule.__init__(self, device)
WithNetwork.__init__(self, device)
self._experiment_folder: str = experiment_folder
self._example_shape: tuple[int, ...] = example_shape
self._checkpoint: str = checkpoint
self._model: nn.Module | None = None

def lazy_load_model(self) -> None:
    """Load the network from the experiment's checkpoint on first use.

    The checkpoint file ``<experiment_folder>/<checkpoint>`` is read with
    ``torch.load`` and handed to ``load_model`` together with the stored
    example shape; the resulting module is cached in ``self._model`` and
    switched to evaluation mode. Subsequent calls return immediately.
    """
    # Compare against None explicitly: container modules such as nn.Sequential
    # define __len__, so an empty one is falsy and would be rebuilt on every call.
    if self._model is not None:
        return
    # map_location lets a checkpoint written on another device (e.g. a GPU)
    # be deserialized directly onto the device this predictor runs on.
    checkpoint = torch.load(f"{self._experiment_folder}/{self._checkpoint}", map_location=self._device)
    self._model = self.load_model(self._example_shape, checkpoint=checkpoint)
    self._model.eval()

@abstractmethod
def build_network(self, checkpoint: Mapping[str, Any]) -> nn.Module:
raise NotImplementedError

def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Tensor:
self.lazy_load_model()
image = image.to(self._device)
Expand Down
22 changes: 21 additions & 1 deletion mipcandy/layer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Generator, Self
from abc import ABCMeta, abstractmethod
from typing import Any, Generator, Self, Mapping

import torch
from torch import nn
Expand Down Expand Up @@ -93,3 +94,22 @@ def get_padding_module(self) -> nn.Module | None:
def get_restoring_module(self) -> nn.Module | None:
self._lazy_load_padding_module()
return self._restoring_module


class WithNetwork(HasDevice, metaclass=ABCMeta):
    """Mixin for objects that build a network and place it on their device.

    Subclasses implement :meth:`build_network`; :meth:`load_model` is the
    entry point that optionally restores weights and moves the result to
    ``self._device``.
    """

    def __init__(self, device: Device) -> None:
        super().__init__(device)

    @abstractmethod
    def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
        """Construct a fresh network for the given ``example_shape``."""
        raise NotImplementedError

    def build_network_from_checkpoint(self, example_shape: tuple[int, ...],
                                      checkpoint: Mapping[str, Any]) -> nn.Module:
        """Build a network and load ``checkpoint`` into it as a state dict.

        The network is not moved to any device here; :meth:`load_model` does that.
        """
        network = self.build_network(example_shape)
        network.load_state_dict(checkpoint)
        return network

    def load_model(self, example_shape: tuple[int, ...], *,
                   checkpoint: Mapping[str, Any] | None = None) -> nn.Module:
        """Return a network on ``self._device``, restoring weights when given.

        :param example_shape: forwarded unchanged to :meth:`build_network`
        :param checkpoint: state dict to load, or ``None`` for a freshly built network
        """
        # Only None means "no checkpoint". An empty mapping is passed through so
        # that load_state_dict can report it instead of silently yielding an
        # untrained network.
        if checkpoint is None:
            network = self.build_network(example_shape)
        else:
            network = self.build_network_from_checkpoint(example_shape, checkpoint)
        return network.to(self._device)
57 changes: 45 additions & 12 deletions mipcandy/training.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from datetime import datetime
from hashlib import md5
from json import load, dump
from os import PathLike, urandom, makedirs, environ
from os.path import exists
from random import seed as random_seed, randint
from shutil import copy
from threading import Lock
from time import time
from typing import Sequence, override, Callable
from typing import Sequence, override, Callable, Self

import numpy as np
import torch
from matplotlib import pyplot as plt
from pandas import DataFrame
from pandas import DataFrame, read_csv
from rich.console import Console
from rich.progress import Progress, SpinnerColumn
from rich.table import Table
Expand All @@ -23,7 +24,7 @@
from mipcandy.common import Pad2d, Pad3d, quotient_regression, quotient_derivative, quotient_bounds
from mipcandy.config import load_settings, load_secrets
from mipcandy.frontend import Frontend
from mipcandy.layer import WithPaddingModule
from mipcandy.layer import WithPaddingModule, WithNetwork
from mipcandy.sanity_check import sanity_check
from mipcandy.sliding_window import SWMetadata, SlidingWindow
from mipcandy.types import Params, Setting
Expand Down Expand Up @@ -57,11 +58,12 @@ class TrainerTracker(object):
worst_case: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None


class Trainer(WithPaddingModule, metaclass=ABCMeta):
class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *,
device: torch.device | str = "cpu", console: Console = Console()) -> None:
super().__init__(device)
WithPaddingModule.__init__(self, device)
WithNetwork.__init__(self, device)
self._trainer_folder: str = trainer_folder
self._trainer_variant: str = self.__class__.__name__
self._experiment_id: str = "tbd"
Expand All @@ -74,6 +76,39 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t
self._lock: Lock = Lock()
self._tracker: TrainerTracker = TrainerTracker()

# Recovery methods (PR #108 at https://github.com/ProjectNeura/MIPCandy/pull/108)

def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker,
                                 **training_arguments) -> None:
    """Write the files needed to resume this experiment into its folder.

    ``optimizer.pth``, ``scheduler.pth`` and ``criterion.pth`` are written with
    ``torch.save`` on the toolbox objects themselves. ``recovery_orbs.json``
    stores the training keyword arguments under ``"arguments"`` and the tracker
    fields under ``"tracker"``.

    :param toolbox: provides the ``optimizer``, ``scheduler`` and ``criterion`` to save
    :param tracker: dataclass whose fields are stored in the JSON file
    :param training_arguments: keyword arguments to store; they must be JSON-serializable

    NOTE(review): ``train`` builds these arguments from ``locals()``, which also
    contains ``self`` and ``num_epochs`` -- confirm the caller strips them, since a
    ``self`` keyword collides with the bound instance when calling this method.
    """
    folder = self.experiment_folder()
    torch.save(toolbox.optimizer, f"{folder}/optimizer.pth")
    torch.save(toolbox.scheduler, f"{folder}/scheduler.pth")
    torch.save(toolbox.criterion, f"{folder}/criterion.pth")
    tracker_state = asdict(tracker)
    # The tracker's worst_case field holds tensors, which json cannot encode.
    # Encoding them would raise after the file had already been truncated, so the
    # field is stored as None (its declared default).
    if "worst_case" in tracker_state:
        tracker_state["worst_case"] = None
    orbs = {"arguments": training_arguments, "tracker": tracker_state}
    with open(f"{folder}/recovery_orbs.json", "w") as f:
        dump(orbs, f)

def load_recovery_orbs(self) -> dict[str, Setting]:
    """Read and return the parsed contents of ``recovery_orbs.json``.

    The file is looked up inside the folder returned by ``experiment_folder()``.
    """
    orbs_path = f"{self.experiment_folder()}/recovery_orbs.json"
    with open(orbs_path) as orbs_file:
        orbs = load(orbs_file)
    return orbs

def load_tracker(self) -> TrainerTracker:
    """Rebuild a ``TrainerTracker`` from the ``"tracker"`` entry of the recovery file."""
    tracker_fields = self.load_recovery_orbs()["tracker"]
    return TrainerTracker(**tracker_fields)

def load_training_arguments(self) -> dict[str, Setting]:
    """Return the stored training arguments after ``filter_train_params``.

    The arguments are taken from the ``"arguments"`` entry of the recovery file
    and passed to ``filter_train_params`` as keyword arguments.
    """
    stored_arguments = self.load_recovery_orbs()["arguments"]
    return self.filter_train_params(**stored_arguments)

def load_metrics(self) -> dict[str, list[float]]:
    """Load ``metrics.csv`` from the experiment folder as per-metric float lists.

    The ``epoch`` column is used as the index; every remaining column becomes
    one entry mapping the column name to its values converted to ``float``.
    """
    table = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch")
    history = {}
    for metric_name in table.columns:
        history[metric_name] = table[metric_name].astype(float).tolist()
    return history

def recover_from(self, experiment_id: str) -> Self:
self._experiment_id = experiment_id
self._metrics = self.load_metrics()
self._tracker = self.load_tracker()
return self

def continue_training(self, num_epochs: int) -> None:
    """Run ``train`` for ``num_epochs`` using the stored training arguments.

    :param num_epochs: passed positionally to ``train``
    """
    arguments = dict(self.load_training_arguments())
    # The stored arguments are captured from train()'s locals, so they can
    # include names that are already bound here. Forwarding them again would
    # raise "got multiple values for argument".
    for bound_name in ("self", "num_epochs"):
        arguments.pop(bound_name, None)
    self.train(num_epochs, **arguments)

# Getters

def trainer_folder(self) -> str:
Expand Down Expand Up @@ -262,10 +297,6 @@ def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = N

# Builder interfaces

@abstractmethod
def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
raise NotImplementedError

@abstractmethod
def build_optimizer(self, params: Params) -> optim.Optimizer:
raise NotImplementedError
Expand All @@ -279,7 +310,7 @@ def build_criterion(self) -> nn.Module:
raise NotImplementedError

def build_toolbox(self, num_epochs: int, example_shape: tuple[int, ...]) -> TrainerToolbox:
model = self.build_network(example_shape).to(self._device)
model = self.load_model(example_shape)
optimizer = self.build_optimizer(model.parameters())
scheduler = self.build_scheduler(optimizer, num_epochs)
criterion = self.build_criterion().to(self._device)
Expand Down Expand Up @@ -323,6 +354,7 @@ def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None:
def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, ema: bool = True,
seed: int | None = None, early_stop_tolerance: int = 5, val_score_prediction: bool = True,
val_score_prediction_degree: int = 5, save_preview: bool = True, preview_quality: float = .75) -> None:
training_arguments = locals()
self.init_experiment()
if note:
self.log(f"Note: {note}")
Expand All @@ -349,7 +381,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em
sanity_check_result.num_macs, sanity_check_result.num_params, num_epochs,
early_stop_tolerance)
try:
for epoch in range(1, num_epochs + 1):
for epoch in range(self._tracker.epoch, self._tracker.epoch + num_epochs):
if early_stop_tolerance == -1:
epoch -= 1
self.log(f"Early stopping triggered because the validation score has not improved for {
Expand Down Expand Up @@ -400,6 +432,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em
self.save_metrics()
self.save_progress()
self.save_metric_curves()
self.save_everything_for_recovery(toolbox, self._tracker, **training_arguments)
self._frontend.on_experiment_updated(self._experiment_id, epoch, self._metrics, early_stop_tolerance)
except Exception as e:
self.log("Training interrupted")
Expand Down
Loading