19 changes: 11 additions & 8 deletions bigearthnet/configs/config.yaml
@@ -12,8 +12,11 @@ optimizer:
 
 logger:
   # This project uses tensorboard as a logger.
-  _target_: pytorch_lightning.loggers.TensorBoardLogger
-  save_dir: "." # actual save_dir will be set by hydra
+  _target_: aim.pytorch_lightning.AimLogger
+  repo: "${oc.env:HOME}/bigearthnet/bigearthnet/outputs/.aim" # where to gather aim runs
+  experiment: ${experiment.group}
+  train_metric_prefix: 'train_'
+  val_metric_prefix: 'val_'
 
 loss:
   class_weights: null # specify a path to the class_weights.json file if you want to re-balance the loss.
@@ -32,15 +35,15 @@ callbacks:
     filename: "last-model"
   - _target_: pytorch_lightning.callbacks.ModelCheckpoint # checkpoints the best model according to val metric
     save_top_k: 1
-    monitor: ${monitor.name}/val
+    monitor: ${logger.val_metric_prefix}${monitor.name} # e.g. val_f1_score
     mode: ${monitor.mode}
     filename: "best-model"
   - _target_: pytorch_lightning.callbacks.EarlyStopping # stops training after N epochs if val metric hasn't improved
-    monitor: ${monitor.name}/val
+    monitor: ${logger.val_metric_prefix}${monitor.name} # e.g. val_f1_score
     mode: ${monitor.mode}
     patience: ${monitor.patience}
   - _target_: bigearthnet.utils.callbacks.ReproducibilityLogging # logs useful system info
-  - _target_: bigearthnet.utils.callbacks.MonitorHyperParameters # monitors hyper parameters for tensorboard
+  - _target_: bigearthnet.utils.callbacks.MonitorBestValues # monitors the best metric values throughout training
 
 trainer:
   _target_: pytorch_lightning.Trainer
@@ -50,15 +53,15 @@ trainer:
   profiler: "pytorch" # Profiles GPU usage, can be viewed in tensorboard
 
 experiment:
-  group: ??? # Useful to group experiments when doing hyper-parameter tuning
+  group: "default_group" # Useful to group experiments when doing hyper-parameter tuning
   seed: ??? # Set for reproducible experiments
 
 hydra:
   run:
     # Specifies where to store all training artefacts (model checkpoints, logs, results, etc.)
-    dir: outputs/${datamodule.dataset_name}/${oc.select:experiment.group,default_group}/${now:%Y-%m-%dT%H:%M:%S}/${model.model_name}_lr_${optimizer.lr}_${optimizer.name}/
+    dir: outputs/${datamodule.dataset_name}/${experiment.group}/${now:%Y-%m-%dT%H:%M:%S}/${model.model_name}_lr_${optimizer.lr}_${optimizer.name}/
   sweep:
-    dir: outputs/${datamodule.dataset_name}/${oc.select:experiment.group,default_group}/${now:%Y-%m-%dT%H:%M:%S}/multirun/
+    dir: outputs/${datamodule.dataset_name}/${experiment.group}/${now:%Y-%m-%dT%H:%M:%S}/multirun/
     subdir: ${model.model_name}_lr_${optimizer.lr}_${optimizer.name}/
   job:
     chdir: true
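
The new monitor keys are plain OmegaConf string interpolations, so the checkpoint and early-stopping callbacks now watch names like val_f1_score. A minimal sketch of how the interpolation resolves, using a stand-in config that contains only the keys shown above (the callback node here is hypothetical, not the repo's real structure):

    from omegaconf import OmegaConf

    # Stand-in for the full Hydra config; only the keys from the diff above.
    cfg = OmegaConf.create(
        {
            "logger": {"val_metric_prefix": "val_"},
            "monitor": {"name": "f1_score", "mode": "max"},
            "callback": {"monitor": "${logger.val_metric_prefix}${monitor.name}"},
        }
    )
    print(cfg.callback.monitor)  # -> val_f1_score
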
23 changes: 16 additions & 7 deletions bigearthnet/models/bigearthnet_module.py
@@ -33,7 +33,7 @@ def __init__(self, cfg: DictConfig):
         super().__init__()
         self.cfg = cfg
         self.model = instantiate(cfg.model)
-        self.save_hyperparameters(cfg, logger=False)
+        self.save_hyperparameters(cfg)
         self.init_loss()
 
         self.class_names = get_class_names()
@@ -117,8 +117,9 @@ def _generic_epoch_end(self, step_outputs):
     def training_step(self, batch, batch_idx):
         """Runs a prediction step for training, returning the loss."""
         outputs = self._generic_step(batch, batch_idx)
+        prefix = self.cfg.logger.train_metric_prefix
         self.log(
-            "loss/train",
+            f"{prefix}loss",
             outputs["loss"],
             on_step=True,
             on_epoch=True,
@@ -134,8 +135,9 @@ def training_epoch_end(self, training_step_outputs):
     def validation_step(self, batch, batch_idx):
         """Runs a prediction step for validation, logging the loss."""
         outputs = self._generic_step(batch, batch_idx)
+        prefix = self.cfg.logger.val_metric_prefix
         self.log(
-            "loss/val",
+            f"{prefix}loss",
             outputs["loss"],
             on_step=True,
             on_epoch=True,
Expand Down Expand Up @@ -173,8 +175,15 @@ def log_metrics(self, metrics: typing.Dict, split: str):
         )
         log.info(metrics_summary)
 
+        if split == "test":
+            return
+
+        prefix = (
+            self.cfg.logger.train_metric_prefix
+            if split == "train"
+            else self.cfg.logger.val_metric_prefix
+        )
         # log metrics to tensorboard
-        if split in ["train", "val"]:
-            self.log(f"precision/{split}", metrics["precision"], on_epoch=True)
-            self.log(f"recall/{split}", metrics["recall"], on_epoch=True)
-            self.log(f"f1_score/{split}", metrics["f1_score"], on_epoch=True)
+        self.log(f"{prefix}precision", metrics["precision"], on_epoch=True)
+        self.log(f"{prefix}recall", metrics["recall"], on_epoch=True)
+        self.log(f"{prefix}f1_score", metrics["f1_score"], on_epoch=True)
130 changes: 66 additions & 64 deletions bigearthnet/utils/callbacks.py
@@ -5,6 +5,7 @@
 
 import matplotlib.pyplot as plt
 import numpy as np
+from aim import Image
 from hydra.utils import get_original_cwd
 from omegaconf import DictConfig, OmegaConf
 from pip._internal.operations import freeze
@@ -57,6 +58,28 @@ def _summarize_metrics(metrics, class_names, split, current_epoch=None):
"""


def check_requires_update(current_value, best_value, mode):
"""Returns True if a metric needs to be updated."""
assert mode in ["min", "max"]
if mode == "min" and current_value < best_value:
return True
if mode == "max" and current_value > best_value:
return True
return False


def get_output_dir(trainer):

# FIXME: Find a good way to check which logger is being used
# if using tensorboad
# output_dir = trainer.logger.log_dir

# if using aim
output_dir = trainer.logger.save_dir

return output_dir


class ReproducibilityLogging(Callback):
"""Log experiment details for reproducibility.

Expand Down Expand Up @@ -92,10 +115,12 @@ def log_exp_info(self, trainer, pl_module):
log.info("Experiment info:" + exp_details + "\n")

# dump the omegaconf config for reproducibility
output_dir = os.path.join(trainer.logger.log_dir) if trainer.logger else "."
output_dir = get_output_dir(trainer)
if not os.path.isdir(output_dir):
os.makedirs(output_dir)

# FIXME: Check if this is being saved in the correct experiment folder.
# It might just be overwriting current existing files right now.
OmegaConf.save(cfg, os.path.join(output_dir, "exp_config.yaml"))

def on_train_start(self, trainer, pl_module):
@@ -105,68 +130,44 @@ def on_test_start(self, trainer, pl_module):
         self.log_exp_info(trainer, pl_module)
 
 
-class MonitorHyperParameters(Callback):
-    """Keeps track of hyper parameters in tensorboard.
+class MonitorBestValues(Callback):
+    """Keeps track of the best metric values throughout training.
 
-    Useful for generating parallel coordinates view.
+    Only tested with the Aim logger.
     """
 
-    @staticmethod
-    def extract_hparams(cfg) -> typing.Dict:
-        """Select which of the config params to log in logger."""
-        hparams = {
-            "optimizer": cfg.optimizer,
-            "transforms": cfg.transforms.description,
-            "datamodule": {
-                k: cfg.datamodule[k] for k in ["batch_size", "dataset_name"]
-            },
-            "model": cfg.model,
-        }
-        if cfg.model.get("pretrained"):
-            # tensorboard doesn't log bool values, convert to int
-            hparams["model"]["pretrained"] = int(hparams["model"]["pretrained"])
-        return hparams
-
-    def init_hparams_metrics(self, trainer, pl_module):
+    def init_metrics(self, trainer, pl_module):
         """Set up initial metrics associated to hyper params before training.
 
         Initially, all metrics are set to +/- infinity depending on if the value is to be maximized/minimized.
         After every validation epoch, if the metric is better than the last, it is updated.
         The metric to monitor is the same as the metric used for EarlyStopping, ModelSelection, etc.
         """
+        cfg = pl_module.cfg
+        prefix = self.prefix = cfg.logger.val_metric_prefix
+        best_prefix = self.best_prefix = prefix + "best_"  # e.g. val_best_{metric}
         init_metrics = {
-            "val_best_metrics/loss": float("inf"),
-            "val_best_metrics/precision": float("-inf"),
-            "val_best_metrics/recall": float("-inf"),
-            "val_best_metrics/f1_score": float("-inf"),
+            "loss": float("inf"),
+            "precision": float("-inf"),
+            "recall": float("-inf"),
+            "f1_score": float("-inf"),
         }
 
         # verify that the value we want to monitor is valid
         monitor_name = pl_module.cfg.monitor.name
-        possible_monitor_names = ["loss", "precision", "recall", "f1_score"]
-        if monitor_name not in possible_monitor_names:
+        self.available_metrics = ["loss", "precision", "recall", "f1_score"]
+        if monitor_name not in self.available_metrics:
             raise ValueError(
-                f"Specified monitor.name as {monitor_name}. Value to monitor must be one of {possible_monitor_names}"
+                f"Specified monitor.name as {monitor_name}. Value to monitor must be in {self.available_metrics}"
             )
 
         # set the best value as the initial value
-        pl_module.val_best_metric = init_metrics[f"val_best_metrics/{monitor_name}"]
-
-        # Log the initialized values to tensorboard
-        trainer.logger.log_hyperparams(
-            self.extract_hparams(pl_module.cfg), metrics=init_metrics
-        )
-
-    @staticmethod
-    def requires_update(metrics, mode, name, best_value):
-        """Returns True if a metric needs to be updated."""
-        assert mode in ["min", "max"]
-        current_value = metrics[name]
-        if mode == "min" and current_value < best_value:
-            return True
-        if mode == "max" and current_value > best_value:
-            return True
-        return False
+        pl_module.val_best_metric = init_metrics[monitor_name]
+        best_metrics = {
+            f"{best_prefix}{metric}": init_metrics[metric]
+            for metric in self.available_metrics
+        }
+        trainer.logger.log_metrics(best_metrics)
 
     def save_best_metrics(self, trainer, pl_module, split):
         """Saves the best metrics information to disk.
@@ -176,7 +177,7 @@ def save_best_metrics(self, trainer, pl_module, split):
         * the best metrics in a numpy object
         * the best confusion matrices in a .png
 
-        It will also upload the .png to the tensorboard logger during validation.
+        It will also upload the .png to the logger during validation.
         """
         assert split in ["val", "test"]
         class_names = pl_module.class_names
@@ -194,7 +195,7 @@
         )
 
         if split == "val":
-            output_dir = os.path.join(trainer.logger.log_dir) if trainer.logger else "."
+            output_dir = get_output_dir(trainer)
         if split == "test":
            output_dir = "."
 
@@ -219,40 +220,41 @@
         conf_mat_figure = _plot_conf_mats(conf_mats, class_names, title=fig_title)
 
         # save confusion matrix figures
-        fname = os.path.join(output_dir, f"{split}_conf_mats.png")
-        plt.savefig(fname)
+        img_fname = os.path.join(output_dir, f"{split}_conf_mats.png")
+        plt.savefig(img_fname)
+        plt.close(conf_mat_figure)
 
         # log figure to tensorboard on validation splits only
         if split == "val":
-            trainer.logger.experiment.add_figure(
-                tag=f"best_confusion_matrix/{split}",
-                figure=conf_mat_figure,
-                global_step=pl_module.global_step,
+            # if aim
+            aim_image = Image(img_fname, format="jpeg", optimize=True, quality=50)
+            trainer.logger.experiment.track(
+                aim_image, name="conf_mat", step=current_epoch
             )
-        plt.close(conf_mat_figure)
 
     def update_best_metric(self, trainer, pl_module, split):
         """Updates the best hparams metrics after validation epoch."""
         val_metrics = pl_module.val_metrics  # the latest combined metrics
-        best_value = pl_module.val_best_metric  # the best value (so far)
+        best_prefix = self.best_prefix
         mode = pl_module.cfg.monitor.mode
         name = pl_module.cfg.monitor.name
+        current_value = val_metrics[name]
+        best_value = pl_module.val_best_metric  # the best value (so far)
 
-        if self.requires_update(val_metrics, mode, name, best_value):
-            trainer.logger.log_hyperparams(
-                self.extract_hparams(pl_module.cfg),
-                metrics={
-                    f"val_best_metrics/{k}": val_metrics[k]
-                    for k in ["loss", "precision", "recall", "f1_score"]
-                },
-            )
+        if check_requires_update(current_value, best_value, mode):
+
+            best_metrics = {
+                f"{best_prefix}{metric}": val_metrics[metric]
+                for metric in self.available_metrics
+            }
+            trainer.logger.log_metrics(best_metrics)
             # updates to the new best metric
             pl_module.val_best_metric = val_metrics[name]
 
             self.save_best_metrics(trainer, pl_module, split)
 
     def on_train_start(self, trainer, pl_module):
-        self.init_hparams_metrics(trainer, pl_module)
+        self.init_metrics(trainer, pl_module)
 
     def on_validation_epoch_end(self, trainer, pl_module):
         if not trainer.sanity_checking:
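
For context on the new logging calls: save_best_metrics now writes the confusion-matrix figure to disk and hands the file to Aim as a compressed image, and the best metrics go through the logger's standard log_metrics interface. A self-contained sketch of the Aim primitives used above, with a made-up experiment name and file name (it assumes a figure was already saved to val_conf_mats.png):

    from aim import Image, Run

    run = Run(experiment="default_group")  # stands in for trainer.logger.experiment

    # A plain scalar, as log_metrics ultimately records them.
    run.track(0.42, name="val_best_f1_score", step=0)

    # A saved figure tracked as a compressed JPEG, as in save_best_metrics().
    aim_image = Image("val_conf_mats.png", format="jpeg", optimize=True, quality=50)
    run.track(aim_image, name="conf_mat", step=0)
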
1 change: 1 addition & 0 deletions setup.py
@@ -7,6 +7,7 @@
     packages=find_packages(include=["bigearthnet", "bigearthnet.*"]),
     python_requires=">=3.8",
     install_requires=[
+        "aim",
         "gdown",
         "gitpython",
         "hub",
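
With aim declared as a dependency, runs gathered under the repo path configured in config.yaml can be browsed in Aim's web UI, typically by running aim up --repo <path-to-.aim> (command per Aim's documentation; the exact path depends on the repo value set above).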