-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathutils.py
More file actions
49 lines (37 loc) · 1.71 KB
/
utils.py
File metadata and controls
49 lines (37 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import re
from lightning.pytorch import Callback, Trainer, LightningModule
class LrLogger(Callback):
    """Callback that records every optimizer's learning rate at the start of each epoch."""

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Emit one metric per (optimizer, param-group) pair, named e.g. "opt0_lr1".
        for opt_idx, opt in enumerate(trainer.optimizers):
            for group_idx, group in enumerate(opt.param_groups):
                metric_name = f"opt{opt_idx}_lr{group_idx}"
                lr = group["lr"]
                # Write straight to the experiment logger, keyed by global step.
                pl_module.logger.log_metrics({metric_name: lr}, step=trainer.global_step)
                # Also surface it through self.log (logger=False avoids double logging)
                # so the value lands in callback_metrics for other callbacks to read.
                # NOTE(review): `pl_module.distributed` is a project-specific attribute,
                # not a standard LightningModule field — confirm it is always set.
                pl_module.log(metric_name, lr, logger=False, sync_dist=pl_module.distributed)
class EarlyStoppingLR(Callback):
    """Stop training once the learning rate(s) fall to or below a threshold.

    Reads the ``opt{i}_lr{j}`` metrics that ``LrLogger`` publishes into
    ``callback_metrics``.

    Args:
        lr_threshold: Learning-rate value at or below which training stops.
        mode: ``"all"`` stops only when every LR is at/below the threshold;
            ``"any"`` stops as soon as a single LR is.

    Raises:
        ValueError: If ``mode`` is not ``"any"`` or ``"all"``.
    """

    # Compiled once; matches the metric names emitted by LrLogger.
    _LR_KEY = re.compile(r"opt\d+_lr\d+")

    def __init__(self, lr_threshold: float, mode: str = "all"):
        # Fix: the original skipped super().__init__(), dropping base-class setup.
        super().__init__()
        self.lr_threshold = lr_threshold
        if mode not in ("any", "all"):
            # Fix: was an f-string with no placeholder (F541); now reports the
            # offending value to make misconfiguration obvious.
            raise ValueError(f"mode must be one of ('any', 'all'), got {mode!r}")
        self.mode = mode

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._run_early_stop_checking(trainer)

    def _run_early_stop_checking(self, trainer: Trainer) -> None:
        # Fix: use the public Trainer.callback_metrics property instead of the
        # private trainer._logger_connector, which can break across versions.
        metrics = trainer.callback_metrics
        all_lr = [value for key, value in metrics.items() if self._LR_KEY.match(key)]
        if not all_lr:
            # No LR metrics logged (e.g. LrLogger absent) — nothing to check.
            return
        # "all": every LR must be at/below threshold; "any": one suffices.
        predicate = all if self.mode == "all" else any
        if predicate(lr <= self.lr_threshold for lr in all_lr):
            trainer.should_stop = True