diff --git a/bigearthnet/configs/config.yaml b/bigearthnet/configs/config.yaml
index 6e9423a..d18fa5c 100644
--- a/bigearthnet/configs/config.yaml
+++ b/bigearthnet/configs/config.yaml
@@ -12,8 +12,11 @@ optimizer:
 
-logger:  # This project uses tensorboard as a logger.
-  _target_: pytorch_lightning.loggers.TensorBoardLogger
-  save_dir: "."  # actual save_dir will be set by hydra
+logger:  # This project uses Aim as a logger.
+  _target_: aim.pytorch_lightning.AimLogger
+  repo: "${oc.env:HOME}/bigearthnet/bigearthnet/outputs/.aim"  # where to gather aim runs
+  experiment: ${experiment.group}
+  train_metric_prefix: 'train_'
+  val_metric_prefix: 'val_'
 
 loss:
   class_weights: null  # specify a path to the class_weights.json file if you want to re-balance the loss.
 
@@ -32,15 +35,15 @@ callbacks:
     filename: "last-model"
   - _target_: pytorch_lightning.callbacks.ModelCheckpoint  # checkpoints the best model according to val metric
     save_top_k: 1
-    monitor: ${monitor.name}/val
+    monitor: ${logger.val_metric_prefix}${monitor.name}  # e.g. val_f1_score
     mode: ${monitor.mode}
     filename: "best-model"
   - _target_: pytorch_lightning.callbacks.EarlyStopping  # stops training after N epochs if val metric hasn't improved
-    monitor: ${monitor.name}/val
+    monitor: ${logger.val_metric_prefix}${monitor.name}  # e.g. val_f1_score
     mode: ${monitor.mode}
     patience: ${monitor.patience}
   - _target_: bigearthnet.utils.callbacks.ReproducibilityLogging  # logs useful system info
-  - _target_: bigearthnet.utils.callbacks.MonitorHyperParameters  # monitors hyper parameters for tensorboard
+  - _target_: bigearthnet.utils.callbacks.MonitorBestValues  # tracks the best metric values throughout training
 
 trainer:
   _target_: pytorch_lightning.Trainer
 
@@ -50,15 +53,16 @@ trainer:
   profiler: "pytorch"  # Profiles GPU usage, can be viewed in tensorboard
 
 experiment:
-  group: ???  # Useful to group experiments when doing hyper-parameter tuning
+  group: "default_group"  # Useful to group experiments when doing hyper-parameter tuning
   seed: ???  # Set for reproducible experiments
 
 hydra:
-  run:  # Specifies where to store all training aretefacts (model checkpoints, logs, results, etc.)
+  run:  # Specifies where to store all training artefacts (model checkpoints, logs, results, etc.)
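+    # resolves to e.g. outputs/<dataset_name>/default_group/<timestamp>/<model_name>_lr_<lr>_<optimizer_name>/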
-    dir: outputs/${datamodule.dataset_name}/${oc.select:experiment.group,default_group}/${now:%Y-%m-%dT%H:%M:%S}/${model.model_name}_lr_${optimizer.lr}_${optimizer.name}/
+    dir: outputs/${datamodule.dataset_name}/${experiment.group}/${now:%Y-%m-%dT%H:%M:%S}/${model.model_name}_lr_${optimizer.lr}_${optimizer.name}/
   sweep:
-    dir: outputs/${datamodule.dataset_name}/${oc.select:experiment.group,default_group}/${now:%Y-%m-%dT%H:%M:%S}/multirun/
+    dir: outputs/${datamodule.dataset_name}/${experiment.group}/${now:%Y-%m-%dT%H:%M:%S}/multirun/
     subdir: ${model.model_name}_lr_${optimizer.lr}_${optimizer.name}/
   job:
     chdir: true
 
diff --git a/bigearthnet/models/bigearthnet_module.py b/bigearthnet/models/bigearthnet_module.py
index fe8e926..b6898ed 100644
--- a/bigearthnet/models/bigearthnet_module.py
+++ b/bigearthnet/models/bigearthnet_module.py
@@ -33,7 +33,7 @@ def __init__(self, cfg: DictConfig):
         super().__init__()
         self.cfg = cfg
         self.model = instantiate(cfg.model)
-        self.save_hyperparameters(cfg, logger=False)
+        self.save_hyperparameters(cfg)
         self.init_loss()
         self.class_names = get_class_names()
 
@@ -117,8 +117,9 @@ def _generic_epoch_end(self, step_outputs):
     def training_step(self, batch, batch_idx):
         """Runs a prediction step for training, returning the loss."""
         outputs = self._generic_step(batch, batch_idx)
+        prefix = self.cfg.logger.train_metric_prefix
         self.log(
-            "loss/train",
+            f"{prefix}loss",
             outputs["loss"],
             on_step=True,
             on_epoch=True,
@@ -134,8 +135,9 @@ def training_epoch_end(self, training_step_outputs):
     def validation_step(self, batch, batch_idx):
         """Runs a prediction step for validation, logging the loss."""
         outputs = self._generic_step(batch, batch_idx)
+        prefix = self.cfg.logger.val_metric_prefix
         self.log(
-            "loss/val",
+            f"{prefix}loss",
             outputs["loss"],
             on_step=True,
             on_epoch=True,
@@ -173,8 +175,15 @@ def log_metrics(self, metrics: typing.Dict, split: str):
         )
         log.info(metrics_summary)
 
+        if split == "test":
+            return
+
+        prefix = (
+            self.cfg.logger.train_metric_prefix
+            if split == "train"
+            else self.cfg.logger.val_metric_prefix
+        )
-        # log metrics to tensorboard
-        if split in ["train", "val"]:
-            self.log(f"precision/{split}", metrics["precision"], on_epoch=True)
-            self.log(f"recall/{split}", metrics["recall"], on_epoch=True)
-            self.log(f"f1_score/{split}", metrics["f1_score"], on_epoch=True)
+        # log metrics to the experiment logger
+        self.log(f"{prefix}precision", metrics["precision"], on_epoch=True)
+        self.log(f"{prefix}recall", metrics["recall"], on_epoch=True)
+        self.log(f"{prefix}f1_score", metrics["f1_score"], on_epoch=True)
diff --git a/bigearthnet/utils/callbacks.py b/bigearthnet/utils/callbacks.py
index 642cba4..3d292c5 100644
--- a/bigearthnet/utils/callbacks.py
+++ b/bigearthnet/utils/callbacks.py
@@ -5,6 +5,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
+from aim import Image
 from hydra.utils import get_original_cwd
 from omegaconf import DictConfig, OmegaConf
 from pip._internal.operations import freeze
@@ -57,6 +58,31 @@ def _summarize_metrics(metrics, class_names, split, current_epoch=None):
     """
 
 
+def check_requires_update(current_value, best_value, mode):
+    """Returns True if a metric needs to be updated."""
+    assert mode in ["min", "max"]
+    if mode == "min" and current_value < best_value:
+        return True
+    if mode == "max" and current_value > best_value:
+        return True
+    return False
+
+
+def get_output_dir(trainer):
+
+    # FIXME: Find a good way to check which logger is being used.
+    # If using tensorboard:
+    # output_dir = trainer.logger.log_dir
+
+    # If using aim:
+    output_dir = trainer.logger.save_dir
+
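+    # NOTE (assumption): aim's AimLogger exposes the repo location via save_dir;
+    # if that ever changes, os.getcwd() would be a reasonable fallback, since
+    # hydra (job.chdir=true) already runs from the job's output directory.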
+    return output_dir
+
 
 class ReproducibilityLogging(Callback):
     """Log experiment details for reproducibility.
 
@@ -92,10 +115,12 @@ def log_exp_info(self, trainer, pl_module):
         log.info("Experiment info:" + exp_details + "\n")
 
         # dump the omegaconf config for reproducibility
-        output_dir = os.path.join(trainer.logger.log_dir) if trainer.logger else "."
+        output_dir = get_output_dir(trainer)
         if not os.path.isdir(output_dir):
             os.makedirs(output_dir)
+        # FIXME: Check if this is being saved in the correct experiment folder.
+        # It might just be overwriting current existing files right now.
         OmegaConf.save(cfg, os.path.join(output_dir, "exp_config.yaml"))
 
     def on_train_start(self, trainer, pl_module):
@@ -105,68 +130,46 @@ def on_test_start(self, trainer, pl_module):
         self.log_exp_info(trainer, pl_module)
 
 
-class MonitorHyperParameters(Callback):
-    """Keeps track of hyper parameters in tensorboard.
+class MonitorBestValues(Callback):
+    """Keeps track of best values throughout training.
 
-    Useful for generating parallel coordinates view.
+    Only tested with the Aim logger.
     """
 
-    @staticmethod
-    def extract_hparams(cfg) -> typing.Dict:
-        """Select which of the config params to log in logger."""
-        hparams = {
-            "optimizer": cfg.optimizer,
-            "transforms": cfg.transforms.description,
-            "datamodule": {
-                k: cfg.datamodule[k] for k in ["batch_size", "dataset_name"]
-            },
-            "model": cfg.model,
-        }
-        if cfg.model.get("pretrained"):
-            # tensorboard doesn't log bool values, convert to int
-            hparams["model"]["pretrained"] = int(hparams["model"]["pretrained"])
-        return hparams
-
-    def init_hparams_metrics(self, trainer, pl_module):
+    def init_metrics(self, trainer, pl_module):
         """Set up initial metrics associated to hyper params before training.
 
         Initially, all metrics are set to +/- infinity depending on if the value
         is to be maximized/minimized. After every validation epoch, if the metric
         is better than the last, it is updated. The metric to monitor is the same
         as the metric used for EarlyStopping, ModelSelection, etc.
         """
+        cfg = pl_module.cfg
+        prefix = self.prefix = cfg.logger.val_metric_prefix
+        best_prefix = self.best_prefix = prefix + "best_"  # e.g. val_best_{metric}
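+        # Sentinel values: starting at +/-inf guarantees the first validation
+        # epoch is always recorded as the new best for every metric.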
         init_metrics = {
-            "val_best_metrics/loss": float("inf"),
-            "val_best_metrics/precision": float("-inf"),
-            "val_best_metrics/recall": float("-inf"),
-            "val_best_metrics/f1_score": float("-inf"),
+            "loss": float("inf"),
+            "precision": float("-inf"),
+            "recall": float("-inf"),
+            "f1_score": float("-inf"),
         }
 
         # verify that the value we want to monitor is valid
         monitor_name = pl_module.cfg.monitor.name
-        possible_monitor_names = ["loss", "precision", "recall", "f1_score"]
-        if monitor_name not in possible_monitor_names:
+        self.available_metrics = ["loss", "precision", "recall", "f1_score"]
+        if monitor_name not in self.available_metrics:
             raise ValueError(
-                f"Specified monitor.name as {monitor_name}. Value to monitor must be one of {possible_monitor_names}"
+                f"Specified monitor.name as {monitor_name}. Value to monitor must be in {self.available_metrics}"
             )
 
         # set the best value as the initial value
-        pl_module.val_best_metric = init_metrics[f"val_best_metrics/{monitor_name}"]
-
-        # Log the initialized values to tensorboard
-        trainer.logger.log_hyperparams(
-            self.extract_hparams(pl_module.cfg), metrics=init_metrics
-        )
-
-    @staticmethod
-    def requires_update(metrics, mode, name, best_value):
-        """Returns True if a metric needs to be updated."""
-        assert mode in ["min", "max"]
-        current_value = metrics[name]
-        if mode == "min" and current_value < best_value:
-            return True
-        if mode == "max" and current_value > best_value:
-            return True
-        return False
+        pl_module.val_best_metric = init_metrics[monitor_name]
+        best_metrics = {
+            f"{best_prefix}{metric}": init_metrics[metric]
+            for metric in self.available_metrics
+        }
+        trainer.logger.log_metrics(best_metrics)
 
     def save_best_metrics(self, trainer, pl_module, split):
         """Saves the best metrics information to disk.
@@ -176,7 +177,7 @@ def save_best_metrics(self, trainer, pl_module, split):
         * the best metrics in a numpy object
         * the best confusion matrices in a .png
 
-        It will also upload the .png to the tensorboard logger during validation.
+        It will also upload the .png to the logger during validation.
         """
         assert split in ["val", "test"]
         class_names = pl_module.class_names
@@ -194,7 +195,7 @@ def save_best_metrics(self, trainer, pl_module, split):
         )
 
         if split == "val":
-            output_dir = os.path.join(trainer.logger.log_dir) if trainer.logger else "."
+            output_dir = get_output_dir(trainer)
         if split == "test":
             output_dir = "."
 
@@ -219,40 +220,42 @@ def save_best_metrics(self, trainer, pl_module, split):
         conf_mat_figure = _plot_conf_mats(conf_mats, class_names, title=fig_title)
 
         # save confusion matrix figures
-        fname = os.path.join(output_dir, f"{split}_conf_mats.png")
-        plt.savefig(fname)
+        img_fname = os.path.join(output_dir, f"{split}_conf_mats.png")
+        plt.savefig(img_fname)
+        plt.close(conf_mat_figure)
 
-        # log figure to tensorboard on validation splits only
+        # log figure to the logger on validation splits only
         if split == "val":
-            trainer.logger.experiment.add_figure(
-                tag=f"best_confusion_matrix/{split}",
-                figure=conf_mat_figure,
-                global_step=pl_module.global_step,
+            # if aim
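+            # (assumption) jpeg at quality=50 is plenty for a confusion-matrix
+            # figure and keeps the tracked images in the aim repo small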
+            aim_image = Image(img_fname, format="jpeg", optimize=True, quality=50)
+            trainer.logger.experiment.track(
+                aim_image, name="conf_mat", step=current_epoch
             )
-            plt.close(conf_mat_figure)
 
     def update_best_metric(self, trainer, pl_module, split):
         """Updates the best hparams metrics after validation epoch."""
         val_metrics = pl_module.val_metrics  # the latest combined metrics
-        best_value = pl_module.val_best_metric  # the best value (so far)
+        best_prefix = self.best_prefix
         mode = pl_module.cfg.monitor.mode
         name = pl_module.cfg.monitor.name
+        current_value = val_metrics[name]
+        best_value = pl_module.val_best_metric  # the best value (so far)
 
-        if self.requires_update(val_metrics, mode, name, best_value):
-            trainer.logger.log_hyperparams(
-                self.extract_hparams(pl_module.cfg),
-                metrics={
-                    f"val_best_metrics/{k}": val_metrics[k]
-                    for k in ["loss", "precision", "recall", "f1_score"]
-                },
-            )
+        if check_requires_update(current_value, best_value, mode):
+            best_metrics = {
+                f"{best_prefix}{metric}": val_metrics[metric]
+                for metric in self.available_metrics
+            }
+            trainer.logger.log_metrics(best_metrics)
             # updates to the new best metric
             pl_module.val_best_metric = val_metrics[name]
             self.save_best_metrics(trainer, pl_module, split)
 
     def on_train_start(self, trainer, pl_module):
-        self.init_hparams_metrics(trainer, pl_module)
+        self.init_metrics(trainer, pl_module)
 
     def on_validation_epoch_end(self, trainer, pl_module):
         if not trainer.sanity_checking:
diff --git a/setup.py b/setup.py
index 1deec00..344ea60 100644
--- a/setup.py
+++ b/setup.py
@@ -7,6 +7,7 @@
     packages=find_packages(include=["bigearthnet", "bigearthnet.*"]),
     python_requires=">=3.8",
     install_requires=[
+        "aim",
         "gdown",
         "gitpython",
         "hub",