11from abc import ABCMeta , abstractmethod
2- from dataclasses import dataclass
2+ from dataclasses import dataclass , asdict
33from datetime import datetime
44from hashlib import md5
5+ from json import load , dump
56from os import PathLike , urandom , makedirs , environ
67from os .path import exists
78from random import seed as random_seed , randint
89from shutil import copy
910from threading import Lock
1011from time import time
11- from typing import Sequence , override , Callable
12+ from typing import Sequence , override , Callable , Self
1213
1314import numpy as np
1415import torch
1516from matplotlib import pyplot as plt
16- from pandas import DataFrame
17+ from pandas import DataFrame , read_csv
1718from rich .console import Console
1819from rich .progress import Progress , SpinnerColumn
1920from rich .table import Table
2324from mipcandy .common import Pad2d , Pad3d , quotient_regression , quotient_derivative , quotient_bounds
2425from mipcandy .config import load_settings , load_secrets
2526from mipcandy .frontend import Frontend
26- from mipcandy .layer import WithPaddingModule
27+ from mipcandy .layer import WithPaddingModule , WithNetwork
2728from mipcandy .sanity_check import sanity_check
2829from mipcandy .sliding_window import SWMetadata , SlidingWindow
2930from mipcandy .types import Params , Setting
@@ -57,11 +58,12 @@ class TrainerTracker(object):
5758 worst_case : tuple [torch .Tensor , torch .Tensor , torch .Tensor ] | None = None
5859
5960
60- class Trainer (WithPaddingModule , metaclass = ABCMeta ):
61+ class Trainer (WithPaddingModule , WithNetwork , metaclass = ABCMeta ):
6162 def __init__ (self , trainer_folder : str | PathLike [str ], dataloader : DataLoader [tuple [torch .Tensor , torch .Tensor ]],
6263 validation_dataloader : DataLoader [tuple [torch .Tensor , torch .Tensor ]], * ,
6364 device : torch .device | str = "cpu" , console : Console = Console ()) -> None :
64- super ().__init__ (device )
65+ WithPaddingModule .__init__ (self , device )
66+ WithNetwork .__init__ (self , device )
6567 self ._trainer_folder : str = trainer_folder
6668 self ._trainer_variant : str = self .__class__ .__name__
6769 self ._experiment_id : str = "tbd"
@@ -74,6 +76,39 @@ def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[t
7476 self ._lock : Lock = Lock ()
7577 self ._tracker : TrainerTracker = TrainerTracker ()
7678
79+ # Recovery methods (PR #108 at https://github.com/ProjectNeura/MIPCandy/pull/108)
80+
81+ def save_everything_for_recovery (self , toolbox : TrainerToolbox , tracker : TrainerTracker ,
82+ ** training_arguments ) -> None :
83+ torch .save (toolbox .optimizer , f"{ self .experiment_folder ()} /optimizer.pth" )
84+ torch .save (toolbox .scheduler , f"{ self .experiment_folder ()} /scheduler.pth" )
85+ torch .save (toolbox .criterion , f"{ self .experiment_folder ()} /criterion.pth" )
86+ with open (f"{ self .experiment_folder ()} /recovery_orbs.json" , "w" ) as f :
87+ dump ({"arguments" : training_arguments , "tracker" : asdict (tracker )}, f )
88+
89+ def load_recovery_orbs (self ) -> dict [str , Setting ]:
90+ with open (f"{ self .experiment_folder ()} /recovery_orbs.json" ) as f :
91+ return load (f )
92+
93+ def load_tracker (self ) -> TrainerTracker :
94+ return TrainerTracker (** self .load_recovery_orbs ()["tracker" ])
95+
96+ def load_training_arguments (self ) -> dict [str , Setting ]:
97+ return self .filter_train_params (** self .load_recovery_orbs ()["arguments" ])
98+
99+ def load_metrics (self ) -> dict [str , list [float ]]:
100+ df = read_csv (f"{ self .experiment_folder ()} /metrics.csv" , index_col = "epoch" )
101+ return {column : df [column ].astype (float ).tolist () for column in df .columns }
102+
103+ def recover_from (self , experiment_id : str ) -> Self :
104+ self ._experiment_id = experiment_id
105+ self ._metrics = self .load_metrics ()
106+ self ._tracker = self .load_tracker ()
107+ return self
108+
109+ def continue_training (self , num_epochs : int ) -> None :
110+ self .train (num_epochs , ** self .load_training_arguments ())
111+
77112 # Getters
78113
79114 def trainer_folder (self ) -> str :
@@ -262,10 +297,6 @@ def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = N
262297
263298 # Builder interfaces
264299
265- @abstractmethod
266- def build_network (self , example_shape : tuple [int , ...]) -> nn .Module :
267- raise NotImplementedError
268-
    @abstractmethod
    def build_optimizer(self, params: Params) -> optim.Optimizer:
        """Construct the optimizer for this experiment's model.

        Subclasses must override this; called by ``build_toolbox`` with the freshly
        built model's parameters.

        :param params: the model parameters the optimizer will update
        :return: a configured optimizer instance
        :raises NotImplementedError: always, on the abstract base
        """
        raise NotImplementedError
@@ -279,7 +310,7 @@ def build_criterion(self) -> nn.Module:
279310 raise NotImplementedError
280311
281312 def build_toolbox (self , num_epochs : int , example_shape : tuple [int , ...]) -> TrainerToolbox :
282- model = self .build_network (example_shape ). to ( self . _device )
313+ model = self .load_model (example_shape )
283314 optimizer = self .build_optimizer (model .parameters ())
284315 scheduler = self .build_scheduler (optimizer , num_epochs )
285316 criterion = self .build_criterion ().to (self ._device )
@@ -323,6 +354,7 @@ def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None:
323354 def train (self , num_epochs : int , * , note : str = "" , num_checkpoints : int = 5 , ema : bool = True ,
324355 seed : int | None = None , early_stop_tolerance : int = 5 , val_score_prediction : bool = True ,
325356 val_score_prediction_degree : int = 5 , save_preview : bool = True , preview_quality : float = .75 ) -> None :
357+ training_arguments = locals ()
326358 self .init_experiment ()
327359 if note :
328360 self .log (f"Note: { note } " )
@@ -349,7 +381,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em
349381 sanity_check_result .num_macs , sanity_check_result .num_params , num_epochs ,
350382 early_stop_tolerance )
351383 try :
352- for epoch in range (1 , num_epochs + 1 ):
384+ for epoch in range (self . _tracker . epoch , self . _tracker . epoch + num_epochs ):
353385 if early_stop_tolerance == - 1 :
354386 epoch -= 1
355387 self .log (f"Early stopping triggered because the validation score has not improved for {
@@ -400,6 +432,7 @@ def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, em
400432 self .save_metrics ()
401433 self .save_progress ()
402434 self .save_metric_curves ()
435+ self .save_everything_for_recovery (toolbox , self ._tracker , ** training_arguments )
403436 self ._frontend .on_experiment_updated (self ._experiment_id , epoch , self ._metrics , early_stop_tolerance )
404437 except Exception as e :
405438 self .log ("Training interrupted" )
0 commit comments