
Commit 2720dde

Tony authored and rwightman committed
initial commit:
- Added `val-interval` argument; eval and checkpointing are only applied every `val-interval` epochs.
- Changed `float` to `Optional[float]` in the typing of the Scheduler `step` function parameter `metric`.
- Skipped the step of the base scheduler in the plateau scheduler to avoid a `TypeError` when converting `None` to `float`.
- Added `or last_batch` to the logging logic during training, to be consistent with validation.
1 parent fa2e3cc commit 2720dde

File tree: 3 files changed (+20, -6 lines)


timm/scheduler/plateau_lr.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -5,7 +5,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
-from typing import List
+from typing import List, Optional

 from .scheduler import Scheduler

@@ -86,12 +86,14 @@ def step(self, epoch, metric=None):
                 param_group['lr'] = self.restore_lr[i]
             self.restore_lr = None

-        self.lr_scheduler.step(metric, epoch)  # step the base scheduler
+        # step the base scheduler if metric given
+        if metric is not None:
+            self.lr_scheduler.step(metric, epoch)

         if self._is_apply_noise(epoch):
             self._apply_noise(epoch)

-    def step_update(self, num_updates: int, metric: float = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None):
         return None

     def _apply_noise(self, epoch):
```
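The guard matters because `PlateauLRScheduler` wraps `torch.optim.lr_scheduler.ReduceLROnPlateau`, whose `step()` converts the metric with `float()`, so passing `None` through raises the `TypeError` the commit message mentions. A minimal standalone sketch of the failure mode (illustrative, not timm code):

```python
import torch

# ReduceLROnPlateau is the base scheduler that timm's PlateauLRScheduler wraps.
optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
base = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

base.step(0.5)       # fine: the metric is converted via float(0.5)
try:
    base.step(None)  # float(None) raises -- the error the new guard avoids
except TypeError as e:
    print(f"stepping without a metric fails: {e}")
```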

timm/scheduler/scheduler.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -74,14 +74,14 @@ def _get_values(self, t: int, on_epoch: bool = True) -> Optional[List[float]]:
             return None
         return self._get_lr(t)

-    def step(self, epoch: int, metric: float = None) -> None:
+    def step(self, epoch: int, metric: Optional[float] = None) -> None:
         self.metric = metric
         values = self._get_values(epoch, on_epoch=True)
         if values is not None:
             values = self._add_noise(values, epoch)
             self.update_groups(values)

-    def step_update(self, num_updates: int, metric: float = None):
+    def step_update(self, num_updates: int, metric: Optional[float] = None):
         self.metric = metric
         values = self._get_values(num_updates, on_epoch=False)
         if values is not None:
```
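For context, `metric: float = None` is an implicit-`Optional` annotation, which PEP 484 deprecates and strict type checkers (e.g. mypy with implicit Optional disabled) reject; the explicit `Optional[float]` matches the `None` default. A minimal sketch of the corrected signature in isolation:

```python
from typing import Optional


def step(epoch: int, metric: Optional[float] = None) -> None:
    # The annotation now tells callers (and type checkers) that the
    # metric may legitimately be absent, so the body must handle None.
    if metric is not None:
        print(f"epoch {epoch}: metric={metric:.4f}")
```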

train.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -371,6 +371,8 @@
                    help='worker seed mode (default: all)')
 group.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
+group.add_argument('--val-interval', type=int, default=1, metavar='N',
+                   help='how many epochs between validation and checkpointing')
 group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
 group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
```
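The default of 1 keeps the previous behaviour (validate and checkpoint after every epoch); passing e.g. `--val-interval 5` runs eval and checkpointing only every fifth epoch.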
```diff
@@ -1034,6 +1036,16 @@ def main():
                     _logger.info("Distributing BatchNorm running means and vars")
                 utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

+            if (epoch + 1) % args.val_interval != 0:
+                if utils.is_primary(args):
+                    _logger.info("Skipping eval and checkpointing ")
+                if lr_scheduler is not None:
+                    # step LR for next epoch
+                    # careful when using metric dependent lr_scheduler
+                    lr_scheduler.step(epoch + 1, metric=None)
+                # skip validation and metric logic
+                continue
+
             if loader_eval is not None:
                 eval_metrics = validate(
                     model,
```
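A minimal sketch of the gating arithmetic in isolation (a hypothetical loop, not timm's actual training code): with `val_interval = 3`, eval and checkpointing run after epochs 2, 5, 8, ..., and skipped epochs still advance the LR schedule with `metric=None`, which is safe now that the plateau scheduler guards its base step:

```python
val_interval = 3  # assumed value for illustration

for epoch in range(10):
    # ... one epoch of training would run here ...
    if (epoch + 1) % val_interval != 0:
        # no eval metric this epoch; the schedule is still stepped for the
        # next epoch with metric=None
        print(f"epoch {epoch}: skip eval/checkpoint")
        continue
    print(f"epoch {epoch}: eval + checkpoint")
```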
```diff
@@ -1287,7 +1299,7 @@ def _backward(_loss):
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now

-        if update_idx % args.log_interval == 0:
+        if update_idx % args.log_interval == 0 or last_batch:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)

```
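A small sketch of what `or last_batch` buys (assuming, as elsewhere in timm's `train_one_epoch`, that `last_batch` is true for the final batch of the epoch): the final update is now always logged, consistent with how validation logs its last batch:

```python
log_interval = 50        # matches the --log-interval default
updates_per_epoch = 120  # assumed epoch length for illustration

for update_idx in range(updates_per_epoch):
    last_batch = update_idx == updates_per_epoch - 1
    if update_idx % log_interval == 0 or last_batch:
        print(f"logging at update {update_idx}")

# before the change: logs at updates 0, 50, 100
# after: also logs at 119, the last batch of the epoch
```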
