diff --git a/.github/workflows/mr_ci.yml b/.github/workflows/mr_ci.yml index 6d7a72a0..201dcf53 100644 --- a/.github/workflows/mr_ci.yml +++ b/.github/workflows/mr_ci.yml @@ -37,7 +37,7 @@ jobs: - name: Install dependencies run: | - python -m pip install ".[dev]" + python -m pip install ".[dev, lightning]" python -m pip install pytest-cov - name: Quality Assurance diff --git a/.github/workflows/mr_ci_text_spotting.yml b/.github/workflows/mr_ci_text_spotting.yml index 9bbffc55..3adc196e 100644 --- a/.github/workflows/mr_ci_text_spotting.yml +++ b/.github/workflows/mr_ci_text_spotting.yml @@ -58,7 +58,7 @@ jobs: run: | python -m pip install --no-cache-dir wheel python -m pip install --no-cache-dir numpy==1.26.4 torch==2.2.2 torchvision==0.17.2 -f https://download.pytorch.org/whl/torch_stable.html - python -m pip install --no-cache-dir ".[dev]" + python -m pip install --no-cache-dir ".[dev, lightning]" python -m pip install --no-cache-dir pytest-cov python -m pip install --no-cache-dir --no-build-isolation 'git+https://github.com/facebookresearch/detectron2.git' python -m pip install --no-cache-dir --no-build-isolation 'git+https://github.com/maps-as-data/DeepSolo.git' diff --git a/CHANGELOG.md b/CHANGELOG.md index e2984d46..77c93a8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ The following table shows which versions of MapReader are compatible with which ## Pre-release _Add new changes here_ +### Added + +- Added `LightningClassifierContainer` to support multi-GPU training via `lightning.pytorch.Trainer` + ### Fixed - Fixes the `model_summary` method in the `ClassifierContainer` class ([#574](https://github.com/maps-as-data/MapReader/pull/574)) diff --git a/docs/source/using-mapreader/step-by-step-guide/4-classify/train.rst b/docs/source/using-mapreader/step-by-step-guide/4-classify/train.rst index f76ae0ee..b4813315 100644 --- a/docs/source/using-mapreader/step-by-step-guide/4-classify/train.rst +++ 
b/docs/source/using-mapreader/step-by-step-guide/4-classify/train.rst @@ -349,8 +349,8 @@ There are a number of options for the ``model`` argument: MapReader will automatically download the model and its corresponding image processor from the Hugging Face Hub using the `transformers `__ library. e.g. `This model `__ is based on our `*gold standard* dataset `__. - It can be loaded directly like this: - + It can be loaded directly like this: + .. code-block:: python #EXAMPLE @@ -772,3 +772,99 @@ Or, if your maps are georeferenced, you can use the ``explore_patches`` method i ) Refer to the :doc:`Load ` user guidance for further details on how these methods work. + +---- + +Using multiple GPUs - ``LightningClassifierContainer`` +-------------------------------------------------------- + +.. note:: This is a new feature and so is currently in beta. Please let us know if you have any issues using this or if you have any suggestions for improvement! + +If you have access to multiple GPUs for training your classification model, you can use the ``LightningClassifierContainer`` class to take advantage of these. + +The ``LightningClassifierContainer`` mirrors the ``ClassifierContainer`` but delegates the training to `PyTorch Lightning `__. +This makes it straightforward to train on multiple GPUs. + +.. note:: + + You will need to install the ``lightning`` dependency group to use this class with ``pip install -e .[lightning]`` + +Initialize ``LightningClassifierContainer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Set up the classifier exactly as you would with ``ClassifierContainer``: + +.. code-block:: python + + from mapreader import LightningClassifierContainer + + my_classifier = LightningClassifierContainer( + "resnet18", + labels_map=annotated_images.labels_map, + dataloaders=dataloaders, + ) + +The same model options available for ``ClassifierContainer`` (torchvision model names, custom ``nn.Module``, ``load_path``, etc.) are supported. 
+ + +Define loss function, optimizer and scheduler +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Again, you can define your loss function, optimizer and scheduler using the same methods as for ``ClassifierContainer``: + +.. code-block:: python + + my_classifier.add_loss_fn("cross entropy") + my_classifier.initialize_optimizer("adam") + my_classifier.initialize_scheduler() + + +Train with a Lightning Trainer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You will then need to pass your classifier to a ``lightning.pytorch.Trainer`` and call ``fit``: + +.. code-block:: python + + from lightning.pytorch import Trainer + + trainer = Trainer(max_epochs=25) + trainer.fit( + my_classifier, + train_dataloaders=dataloaders["train"], + val_dataloaders=dataloaders["val"], + ) + +The Trainer handles device placement (i.e. will distribute your training across multiple GPUs), logging, and checkpointing. +Metrics are computed each epoch (loss, precision, recall, f-score, ROC AUC) and stored in ``my_classifier.metrics`` and can be plotted with ``my_classifier.plot_metric()``. + +You can also explicitly specify the device to train on by passing the ``devices`` and ``accelerator`` arguments to the Trainer: + +.. code-block:: python + + trainer = Trainer( + max_epochs=25, + devices=2, # number of GPUs to use + accelerator="cuda", # type of device to use (e.g. "cuda", "mps" or "cpu") + ) + trainer.fit( + my_classifier, + train_dataloaders=dataloaders["train"], + val_dataloaders=dataloaders["val"], + ) + + +Save and load +~~~~~~~~~~~~~~ + +As with the ``ClassifierContainer`` you can save your trained model and classifier container using the ``save`` method: + +.. code-block:: python + + my_classifier.save("my_lightning_classifier.pkl") + +And you can load this file back in using the ``load_path`` argument when initializing a new ``LightningClassifierContainer``: + +.. 
code-block:: python + + loaded = LightningClassifierContainer(None, None, None, load_path="my_lightning_classifier.pkl") diff --git a/mapreader/__init__.py b/mapreader/__init__.py index 879bdf93..e108e806 100644 --- a/mapreader/__init__.py +++ b/mapreader/__init__.py @@ -22,6 +22,10 @@ from mapreader.classify.datasets import PatchDataset from mapreader.classify.datasets import PatchContextDataset from mapreader.classify.classifier import ClassifierContainer +try: + from mapreader.classify.lightning_classifier import LightningClassifierContainer +except ImportError: + pass from mapreader.classify import custom_models # spot_text diff --git a/mapreader/classify/classifier.py b/mapreader/classify/classifier.py index 5be5a6cd..2ae1b8da 100644 --- a/mapreader/classify/classifier.py +++ b/mapreader/classify/classifier.py @@ -70,8 +70,6 @@ class ClassifierContainer: A dictionary to store dataloaders for the model. labels_map : dict A dictionary mapping label indices to their labels. - dataset_sizes : dict - A dictionary to store sizes of datasets for the model. model : torch.nn.Module The model. input_size : Tuple of int @@ -152,7 +150,10 @@ def __init__( elif isinstance(model, str): if huggingface: try: - from transformers import AutoModelForImageClassification, AutoImageProcessor + from transformers import ( + AutoImageProcessor, + AutoModelForImageClassification, + ) except ImportError: raise ImportError( "Hugging Face models require the 'transformers' library: 'pip install transformers'." 
@@ -160,9 +161,7 @@ def __init__( print(f"[INFO] Initializing Hugging Face model: {model}") num_labels = len(self.labels_map) self.model = AutoModelForImageClassification.from_pretrained( - model, - num_labels=num_labels, - ignore_mismatched_sizes=True + model, num_labels=num_labels, ignore_mismatched_sizes=True ).to(self.device) hf_processor = AutoImageProcessor.from_pretrained(model) size = getattr(hf_processor, "size", {}) @@ -289,7 +288,7 @@ def _initialize_model( num_ftrs = model_dw.fc.in_features model_dw.fc = nn.Linear(num_ftrs, last_layer_num_classes) is_inception = True - input_size = 299 + input_size = (299, 299) else: raise NotImplementedError( @@ -340,7 +339,7 @@ def generate_layerwise_lrs( elif spacing.lower() in ["log", "geomspace"]: lrs = np.geomspace(min_lr, max_lr, len(list(self.model.named_parameters()))) params2optimize = [ - {"params": params, "learning rate": lr} + {"params": params, "lr": lr} for (_, params), lr in zip(self.model.named_parameters(), lrs) ] @@ -587,10 +586,14 @@ def model_summary( if trainable_col: col_names = ["num_params", "output_size", "trainable"] else: - col_names = ["output_size", "output_size", "num_params"] + col_names = ["input_size", "output_size", "num_params"] model_summary = summary( - self.model, input_size=input_size, col_names=col_names, device=self.device, **kwargs + self.model, + input_size=input_size, + col_names=col_names, + device=self.device, + **kwargs, ) print(model_summary) @@ -1109,7 +1112,7 @@ def train_core( best_model_wts = copy.deepcopy(self.model.state_dict()) if phase.lower() in valid_phase_names: - if epoch % tmp_file_save_freq == 0: + if tmp_file_save_freq and epoch % tmp_file_save_freq == 0: tmp_str = f'[INFO] Checkpoint file saved to "{self.tmp_save_filename}".' 
# noqa print( self._print_colors["lgrey"] @@ -1149,7 +1152,7 @@ def _get_logits(out): try: out = out.logits except AttributeError as err: - raise AttributeError(err.message) + raise AttributeError(str(err)) return out def _gen_epoch_msg(self, phase: str, epoch_msg: str) -> str: @@ -1236,8 +1239,9 @@ def calculate_add_metrics( y_score = np.array(y_score) for average in [None, "micro", "macro", "weighted"]: + labels = list(range(y_score.shape[1])) if average is None else None precision, recall, fscore, support = precision_recall_fscore_support( - y_true, y_pred, average=average + y_true, y_pred, average=average, labels=labels ) if average is None: diff --git a/mapreader/classify/lightning_classifier.py b/mapreader/classify/lightning_classifier.py new file mode 100644 index 00000000..7d14659f --- /dev/null +++ b/mapreader/classify/lightning_classifier.py @@ -0,0 +1,1591 @@ +#!/usr/bin/env python +from __future__ import annotations + +import copy +import os +from collections.abc import Iterable +from typing import Any + +import joblib +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +from lightning.pytorch import LightningModule +from matplotlib.ticker import MaxNLocator +from sklearn.metrics import precision_recall_fscore_support, roc_auc_score +from torch import optim +from torch.utils.data import DataLoader, Sampler +from torchinfo import summary +from torchvision import models + +from .datasets import PatchDataset + + +class LightningClassifierContainer(LightningModule): + """ + A class to store and train a PyTorch model using PyTorch Lightning. + + Mirrors the API of ``ClassifierContainer``. 
Set up the model exactly as + you would with ``ClassifierContainer``, then pass it to a + ``lightning.pytorch.Trainer``:: + + classifier = LightningClassifierContainer(model="resnet18", labels_map={0: "no", 1: "yes"}) + classifier.add_loss_fn("cross entropy") + classifier.initialize_optimizer("adam") + classifier.initialize_scheduler("steplr") + + trainer = Trainer(max_epochs=25) + trainer.fit(classifier, train_dataloaders=dataloaders["train"], + val_dataloaders=dataloaders["val"]) + + For inference without a Trainer use ``classifier.inference(set_name)``. + For distributed / GPU inference use ``trainer.predict()``. + + Parameters + ---------- + model : str, nn.Module or None + The PyTorch model to add to the object. + + - If passed as a string, will run ``_initialize_model(model, **kwargs)``. See https://pytorch.org/vision/0.8/models.html for options. + - Must be ``None`` if ``load_path`` is specified as model will be loaded from file. + + labels_map: Dict or None + A dictionary containing the mapping of each label index to its label, with indices as keys and labels as values (i.e. idx: label). + Can only be ``None`` if ``load_path`` is specified as labels_map will be loaded from file. + dataloaders: Dict or None + A dictionary containing set names as keys and dataloaders as values (i.e. set_name: dataloader). + input_size : int, optional + The expected input size of the model. Default is ``(224,224)``. + is_inception : bool, optional + Whether the model is an Inception-style model. + Default is ``False``. + load_path : str, optional + The path to an ``.obj`` file containing a previously saved ``LightningClassifierContainer``. + force_device : bool, optional + Whether to force the use of a specific device. + If set to ``True``, the default device is used. + Defaults to ``False``. + kwargs : Dict + Keyword arguments to pass to the + :meth:`~.classify.lightning_classifier.LightningClassifierContainer._initialize_model` + method (if passing ``model`` as a string). 
+ + Attributes + ---------- + dataloaders : dict + A dictionary to store dataloaders for the model. + labels_map : dict + A dictionary mapping label indices to their labels. + model : torch.nn.Module + The model. + input_size : Tuple of int + The size of the input to the model. + is_inception : bool + A flag indicating if the model is an Inception model. + optimizer : None or torch.optim.Optimizer + The optimizer being used for training the model. + scheduler : None or torch.optim.lr_scheduler._LRScheduler + The learning rate scheduler being used for training the model. + loss_fn : None or nn.modules.loss._Loss + The loss function to use for training the model. + metrics : dict + A dictionary to store the metrics computed during training. + last_epoch : int + The last epoch number completed during training. + best_loss : float + The best validation loss achieved during training. + best_epoch : int + The epoch in which the best validation loss was achieved during + training. + """ + + def __init__( + self, + model: str | nn.Module | None = None, + labels_map: dict[int, str] | None = None, + dataloaders: dict[str, DataLoader] | None = None, + input_size: int = (224, 224), + is_inception: bool = False, + load_path: str | None = None, + force_device: bool = False, + huggingface: bool = False, + **kwargs, + ): + super().__init__() + + # check if loading an pre-existing object + if load_path: + if model: + print("[WARNING] Ignoring ``model`` as ``load_path`` is specified.") + if labels_map: + print( + "[WARNING] Ignoring ``labels_map`` as ``load_path`` is specified." + ) + + # load object + self.load(load_path=load_path, force_device=force_device) + + # add any extra dataloaders + if dataloaders: + for set_name, dataloader in dataloaders.items(): + self.dataloaders[set_name] = dataloader + + else: + if model is None or labels_map is None: + raise ValueError( + "[ERROR] ``model`` and ``labels_map`` must be defined." 
+ ) + + self.labels_map = labels_map + + # set up model (Lightning manages device placement) + print("[INFO] Initializing model.") + if isinstance(model, nn.Module): + self.model = model + self.input_size = input_size + self.is_inception = is_inception + elif isinstance(model, str): + if huggingface: + try: + from transformers import ( + AutoImageProcessor, + AutoModelForImageClassification, + ) + except ImportError: + raise ImportError( + "Hugging Face models require the 'transformers' library: 'pip install transformers'." + ) + print(f"[INFO] Initializing Hugging Face model: {model}") + num_labels = len(self.labels_map) + self.model = AutoModelForImageClassification.from_pretrained( + model, num_labels=num_labels, ignore_mismatched_sizes=True + ) + hf_processor = AutoImageProcessor.from_pretrained(model) + size = getattr(hf_processor, "size", {}) + if "height" in size and "width" in size: + self.input_size = (size["height"], size["width"]) + elif "shortest_edge" in size: + self.input_size = (size["shortest_edge"], size["shortest_edge"]) + else: + self.input_size = input_size + self.is_inception = False + else: + self._initialize_model(model, **kwargs) + + self.optimizer = None + self.scheduler = None + self.loss_fn = None + + self.metrics = {} + self.last_epoch = 0 + self.best_loss = np.inf + self.best_epoch = 0 + + self.pred_conf = [] + self.pred_label_indices = [] + self.pred_label = [] + self.gt_label_indices = [] + + # add dataloaders and labels_map + self.dataloaders = dataloaders if dataloaders else {} + + for set_name, dataloader in self.dataloaders.items(): + print(f'[INFO] Loaded "{set_name}" with {len(dataloader.dataset)} items.') + + def _initialize_model( + self, + model_name: str, + weights: str | None = "DEFAULT", + last_layer_num_classes: str | int | None = "default", + ) -> tuple[Any, int, bool]: + """ + Initializes a PyTorch model with the option to change the number of + classes in the last layer (``last_layer_num_classes``). 
+ + Parameters + ---------- + model_name : str + Name of a PyTorch model. See https://pytorch.org/vision/0.8/models.html for options. + weights : str, optional + Weights to load into the model. If ``"DEFAULT"``, loads the default weights for the chosen model. + By default, ``"DEFAULT"``. + last_layer_num_classes : str or int, optional + Number of elements in the last layer. If ``"default"``, sets it to + the number of classes. By default, ``"default"``. + + Returns + ------- + model : PyTorch model + The initialized PyTorch model with the changed last layer. + input_size : int + Input size of the model. + is_inception : bool + True if the model is Inception v3. + + Raises + ------ + ValueError + If an invalid model name is passed. + + Notes + ----- + Inception v3 requires the input size to be ``(299, 299)``, whereas all + of the other models expect ``(224, 224)``. + + See https://pytorch.org/vision/0.8/models.html. + """ + + # Initialize these variables which will be set in this if statement. + # Each of these variables is model specific. 
+ model_dw = models.get_model(model_name, weights=weights) + input_size = (224, 224) + is_inception = False + + if last_layer_num_classes in ["default"]: + last_layer_num_classes = len(self.labels_map) + else: + last_layer_num_classes = int(last_layer_num_classes) + + if "resnet" in model_name: + num_ftrs = model_dw.fc.in_features + model_dw.fc = nn.Linear(num_ftrs, last_layer_num_classes) + + elif "alexnet" in model_name: + num_ftrs = model_dw.classifier[6].in_features + model_dw.classifier[6] = nn.Linear(num_ftrs, last_layer_num_classes) + + elif "vgg" in model_name: + # vgg11_bn + num_ftrs = model_dw.classifier[6].in_features + model_dw.classifier[6] = nn.Linear(num_ftrs, last_layer_num_classes) + + elif "squeezenet" in model_name: + model_dw.classifier[1] = nn.Conv2d( + 512, last_layer_num_classes, kernel_size=(1, 1), stride=(1, 1) + ) + model_dw.num_classes = last_layer_num_classes + + elif "densenet" in model_name: + num_ftrs = model_dw.classifier.in_features + model_dw.classifier = nn.Linear(num_ftrs, last_layer_num_classes) + + elif "inception" in model_name: + # Inception v3: + # Be careful, expects (299,299) sized images + has auxiliary output + + # Handle the auxilary net + num_ftrs = model_dw.AuxLogits.fc.in_features + model_dw.AuxLogits.fc = nn.Linear(num_ftrs, last_layer_num_classes) + # Handle the primary net + num_ftrs = model_dw.fc.in_features + model_dw.fc = nn.Linear(num_ftrs, last_layer_num_classes) + is_inception = True + input_size = (299, 299) + + else: + raise NotImplementedError( + "[ERROR] Invalid model name. Try loading your model directly and this as the `model` argument instead." + ) + + # Lightning manages device placement; do not call .to(self.device) here. 
+ self.model = model_dw + self.input_size = input_size + self.is_inception = is_inception + + def generate_layerwise_lrs( + self, + min_lr: float, + max_lr: float, + spacing: str | None = "linspace", + ) -> list[dict]: + """ + Calculates layer-wise learning rates for a given set of model + parameters. + + Parameters + ---------- + min_lr : float + The minimum learning rate to be used. + max_lr : float + The maximum learning rate to be used. + spacing : str, optional + The type of sequence to use for spacing the specified interval + learning rates. Can be either ``"linspace"`` or ``"geomspace"``, + where `"linspace"` uses evenly spaced learning rates over a + specified interval and `"geomspace"` uses learning rates spaced + evenly on a log scale (a geometric progression). By default ``"linspace"``. + + Returns + ------- + list of dicts + A list of dictionaries containing the parameters and learning + rates for each layer. + """ + + if spacing.lower() not in ["linspace", "geomspace"]: + raise NotImplementedError( + '[ERROR] ``spacing`` must be one of "linspace" or "geomspace"' + ) + + if spacing.lower() == "linspace": + lrs = np.linspace(min_lr, max_lr, len(list(self.model.named_parameters()))) + elif spacing.lower() in ["log", "geomspace"]: + lrs = np.geomspace(min_lr, max_lr, len(list(self.model.named_parameters()))) + params2optimize = [ + {"params": params, "lr": lr} + for (_, params), lr in zip(self.model.named_parameters(), lrs) + ] + + return params2optimize + + def initialize_optimizer( + self, + optim_type: str = "adam", + params2optimize: str | Iterable = "default", + optim_param_dict: dict | None = None, + ): + """ + Initializes an optimizer for the model and adds it to the classifier + object. + + Parameters + ---------- + optim_type : str, optional + The type of optimizer to use. Can be set to ``"adam"`` (default), + ``"adamw"``, or ``"sgd"``. + params2optimize : str or iterable, optional + The parameters to optimize. 
If set to ``"default"``, all model + parameters that require gradients will be optimized. + Default is ``"default"``. + optim_param_dict : dict, optional + The parameters to pass to the optimizer constructor as a + dictionary, by default ``{"lr": 1e-3}``. + + Notes + ----- + Note that the first argument of an optimizer is parameters to optimize, + e.g. ``params2optimize = model_ft.parameters()``: + + - ``model_ft.parameters()``: all parameters are being optimized + - ``model_ft.fc.parameters()``: only parameters of final layer are being optimized + + Here, we use: + + .. code-block:: python + + filter(lambda p: p.requires_grad, self.model.parameters()) + """ + if optim_param_dict is None: + optim_param_dict = {"lr": 0.001} + if params2optimize == "default": + params2optimize = filter(lambda p: p.requires_grad, self.model.parameters()) + + if optim_type.lower() in ["adam"]: + optimizer = optim.Adam(params2optimize, **optim_param_dict) + elif optim_type.lower() in ["adamw"]: + optimizer = optim.AdamW(params2optimize, **optim_param_dict) + elif optim_type.lower() in ["sgd"]: + optimizer = optim.SGD(params2optimize, **optim_param_dict) + else: + raise NotImplementedError( + '[ERROR] At present, only Adam ("adam"), AdamW ("adamw") and SGD ("sgd") are options for ``optim_type``.' + ) + + self.add_optimizer(optimizer) + + def add_optimizer(self, optimizer: torch.optim.Optimizer) -> None: + """ + Add an optimizer to the classifier object. + + Parameters + ---------- + optimizer : torch.optim.Optimizer + The optimizer to add to the classifier object. + + Returns + ------- + None + """ + self.optimizer = optimizer + + def initialize_scheduler( + self, + scheduler_type: str = "steplr", + scheduler_param_dict: dict | None = None, + ): + """ + Initializes a learning rate scheduler for the optimizer and adds it to + the classifier object. + Only `StepLR` is implemented - otherwise use `torch.optim.lr_scheduler` directly and the `add_scheduler` method. 
+ + Parameters + ---------- + scheduler_type : str, optional + The type of learning rate scheduler to use. Default is ``"steplr"``. + scheduler_param_dict : dict, optional + The parameters to pass to the scheduler constructor, by default + ``{"step_size": 10, "gamma": 0.1}``. + + Raises + ------ + ValueError + If the specified ``scheduler_type`` is not implemented. + """ + if self.optimizer is None: + raise ValueError( + "[ERROR] Optimizer is not yet defined. \n\n\ +Use ``initialize_optimizer`` or ``add_optimizer`` to define one." # noqa + ) + + if scheduler_type.lower() == "steplr": + if scheduler_param_dict is None: + scheduler_param_dict = {"step_size": 10, "gamma": 0.1} + scheduler = optim.lr_scheduler.StepLR( + self.optimizer, **scheduler_param_dict + ) + else: + raise NotImplementedError( + '[ERROR] At present, only StepLR ("steplr") is implemented. \n\n\ +Use ``torch.optim.lr_scheduler`` directly and then the ``add_scheduler`` method for other schedulers.' # noqa + ) + + self.add_scheduler(scheduler) + + def add_scheduler(self, scheduler: torch.optim.lr_scheduler._LRScheduler) -> None: + """ + Add a scheduler to the classifier object. + Note that during training, `scheduler.step()` is called after each epoch - i.e. do not use schedulers that should be stepped after each batch! + + Parameters + ---------- + scheduler : torch.optim.lr_scheduler._LRScheduler + The scheduler to add to the classifier object. + + Raises + ------ + ValueError + If no optimizer has been set. Use ``initialize_optimizer`` or + ``add_optimizer`` to set an optimizer first. + + Returns + ------- + None + """ + if self.optimizer is None: + raise ValueError( + "[ERROR] Optimizer is needed first. Use `initialize_optimizer` or `add_optimizer`" # noqa + ) + + self.scheduler = scheduler + + def add_loss_fn( + self, loss_fn: str | nn.modules.loss._Loss = "cross entropy" + ) -> None: + """ + Add a loss function to the classifier object. 
+ + Parameters + ---------- + loss_fn : str or torch.nn.modules.loss._Loss + The loss function to add to the classifier object. + Accepted string values are "cross entropy" or "ce" (cross-entropy), "bce" (binary cross-entropy) and "mse" (mean squared error). + + Returns + ------- + None + The function only modifies the ``loss_fn`` attribute of the + classifier and does not return anything. + """ + if isinstance(loss_fn, str): + if loss_fn in ["cross entropy", "ce", "cross_entropy", "cross-entropy"]: + loss_fn = nn.CrossEntropyLoss() + elif loss_fn in [ + "bce", + "binary_cross_entropy", + "binary cross entropy", + "binary cross-entropy", + ]: + loss_fn = nn.BCELoss() + elif loss_fn in [ + "mse", + "mean_square_error", + "mean_squared_error", + "mean squared error", + ]: + loss_fn = nn.MSELoss() + else: + raise NotImplementedError( + '[ERROR] At present, if passing ``loss_fn`` as a string, the loss function can only be "cross entropy" or "ce" (cross-entropy), "bce" (binary cross-entropy) or "mse" (mean squared error).' + ) + + print(f'[INFO] Using "{loss_fn}" as loss function.') + + elif not isinstance(loss_fn, nn.modules.loss._Loss): + raise ValueError( + '[ERROR] Please pass ``loss_fn`` as a string ("cross entropy", "bce" or "mse") or torch.nn loss function (see https://pytorch.org/docs/stable/nn.html).' + ) + + self.loss_fn = loss_fn + + def configure_optimizers(self): + """ + Lightning hook — returns the stored optimizer and (optionally) scheduler. + + Call ``initialize_optimizer`` / ``add_optimizer`` and optionally + ``initialize_scheduler`` / ``add_scheduler`` before ``trainer.fit()``. + """ + if self.optimizer is None: + raise ValueError( + "[ERROR] An optimizer should be defined for training the model." 
+ ) + if self.scheduler is None: + return self.optimizer + return { + "optimizer": self.optimizer, + "lr_scheduler": { + "scheduler": self.scheduler, + "interval": "epoch", + }, + } + + def model_summary( + self, + input_size: tuple | list | None = None, + trainable_col: bool = False, + **kwargs, + ) -> None: + """ + Print a summary of the model. + + Parameters + ---------- + input_size : tuple or list, optional + The size of the input data. + If None, input size is taken from "train" dataloader (``self.dataloaders["train"]``). + trainable_col : bool, optional + If ``True``, adds a column showing which parameters are trainable. + Defaults to ``False``. + **kwargs : Dict + Keyword arguments to pass to ``torchinfo.summary()`` (see https://github.com/TylerYep/torchinfo). + + Notes + ----- + Other ways to check params: + + .. code-block:: python + + sum(p.numel() for p in myclassifier.model.parameters()) + + .. code-block:: python + + sum(p.numel() for p in myclassifier.model.parameters() + if p.requires_grad) + + And: + + .. code-block:: python + + for name, param in self.model.named_parameters(): + n = name.split(".")[0].split("_")[0] + print(name, param.requires_grad) + """ + if not input_size: + if "train" in self.dataloaders.keys(): + batch_size = self.dataloaders["train"].batch_size + channels = len(self.dataloaders["train"].dataset.image_mode) + input_size = (batch_size, channels, *self.input_size) + else: + raise ValueError("[ERROR] Please pass an input size.") + + if trainable_col: + col_names = ["num_params", "output_size", "trainable"] + else: + col_names = ["input_size", "output_size", "num_params"] + + model_summary = summary( + self.model, input_size=input_size, col_names=col_names, **kwargs + ) + print(model_summary) + + def freeze_layers(self, layers_to_freeze: list[str]) -> None: + """ + Freezes the specified layers in the neural network by setting + ``requires_grad`` attribute to False for their parameters. 
+ + Parameters + ---------- + layers_to_freeze : list of str + List of names of the layers to freeze. + If a layer name ends with an asterisk (``"*"``), then all parameters whose name contains the layer name (excluding the asterisk) are frozen. Otherwise, only the parameters with an exact match to the layer name are frozen. + + Returns + ------- + None + The function only modifies the ``requires_grad`` attribute of the + specified parameters and does not return anything. + + Notes + ----- + e.g. ["layer1*", "layer2*"] will freeze all parameters whose name contains "layer1" and "layer2" (excluding the asterisk). + e.g. ["layer1", "layer2"] will freeze all parameters with an exact match to "layer1" and "layer2". + """ + if not isinstance(layers_to_freeze, list): + raise ValueError( + '[ERROR] ``layers_to_freeze`` must be a list of strings. E.g. ["layer1*", "layer2*"].' + ) + + for layer in layers_to_freeze: + for name, param in self.model.named_parameters(): + if (layer.endswith("*")) and ( + layer.strip("*") in name + ): # if using asterix wildcard + param.requires_grad = False + elif (not layer.endswith("*")) and ( + layer == name + ): # if using exact match + param.requires_grad = False + + def unfreeze_layers(self, layers_to_unfreeze: list[str]): + """ + Unfreezes the specified layers in the neural network by setting + ``requires_grad`` attribute to True for their parameters. + + Parameters + ---------- + layers_to_unfreeze : list of str + List of names of the layers to unfreeze. + If a layer name ends with an asterisk (``"*"``), then all parameters whose name contains the layer name (excluding the asterisk) are unfrozen. Otherwise, only the parameters with an exact match to the layer name are unfrozen. + + Returns + ------- + None + The function only modifies the ``requires_grad`` attribute of the + specified parameters and does not return anything. + + Notes + ----- + e.g. 
["layer1*", "layer2*"] will unfreeze all parameters whose name contains "layer1" and "layer2" (excluding the asterisk). + e.g. ["layer1", "layer2"] will unfreeze all parameters with an exact match to "layer1" and "layer2". + """ + + if not isinstance(layers_to_unfreeze, list): + raise ValueError( + '[ERROR] ``layers_to_unfreeze`` must be a list of strings. E.g. ["layer1*", "layer2*"].' + ) + + for layer in layers_to_unfreeze: + for name, param in self.model.named_parameters(): + if (layer.endswith("*")) and (layer.strip("*") in name): + param.requires_grad = True + elif (not layer.endswith("*")) and (layer == name): + param.requires_grad = True + + def only_keep_layers(self, only_keep_layers_list: list[str]) -> None: + """ + Only keep the specified layers (``only_keep_layers_list``) for + gradient computation during the backpropagation. + + Parameters + ---------- + only_keep_layers_list : list + List of layer names to keep. All other layers will have their + gradient computation turned off. + + Returns + ------- + None + The function only modifies the ``requires_grad`` attribute of the + specified parameters and does not return anything. + """ + if not isinstance(only_keep_layers_list, list): + raise ValueError( + '[ERROR] ``only_keep_layers_list`` must be a list of strings. E.g. ["layer1", "layer2"].' + ) + + for name, param in self.model.named_parameters(): + if name in only_keep_layers_list: + param.requires_grad = True + else: + param.requires_grad = False + + def inference( + self, + set_name: str = "infer", + verbose: bool = False, + print_info_batch_freq: int = 5, + ): + """ + Run inference on a specified dataset (``set_name``). + + Populates ``self.pred_conf``, ``self.pred_label_indices`` and + ``self.pred_label``. For distributed or GPU inference, prefer + ``trainer.predict(model, dataloaders=...)``. + + Parameters + ---------- + set_name : str, optional + The name of the dataset to run inference on, by default + ``"infer"``. 
+ verbose : bool, optional + Whether to print verbose outputs, by default False. + print_info_batch_freq : int, optional + The frequency of printouts, by default ``5``. + + Returns + ------- + None + """ + if set_name not in self.dataloaders.keys(): + raise KeyError( + f'[ERROR] "{set_name}" dataloader cannot be found in dataloaders.\n\ + Valid options are: {list(self.dataloaders.keys())}' # noqa + ) + + self.eval() + self.pred_conf = [] + self.pred_label_indices = [] + + total_input_counts = len(self.dataloaders[set_name].dataset) + phase_batch_size = self.dataloaders[set_name].batch_size + + with torch.no_grad(): + for batch_idx, (inputs, _, _) in enumerate(self.dataloaders[set_name]): + inputs = tuple(inp.to(self.device) for inp in inputs) + + outputs = self.model(*inputs) + if not isinstance(outputs, torch.Tensor): + outputs = self._get_logits(outputs) + + _, pred_label_indices = torch.max(outputs, dim=1) + self.pred_conf.extend( + torch.nn.functional.softmax(outputs, dim=1).cpu().tolist() + ) + self.pred_label_indices.extend(pred_label_indices.cpu().tolist()) + + if print_info_batch_freq and batch_idx % print_info_batch_freq == 0: + current_input_counts = min( + total_input_counts, + (batch_idx + 1) * phase_batch_size, + ) + progress = current_input_counts / total_input_counts * 100.0 + progress_msg = f"{current_input_counts}/{total_input_counts} ({progress:5.1f}% )" # noqa + print(f"[INFO] {progress_msg}") + + self.pred_label = [ + self.labels_map.get(i, None) for i in self.pred_label_indices + ] + + def train_component_summary(self) -> None: + """ + Print a summary of the optimizer, loss function, and trainable model + components. 
+ + Returns: + -------- + None + """ + divider = 20 * "=" + print(divider) + print("* Optimizer:") + print(str(self.optimizer)) + print(divider) + print("* Loss function:") + print(str(self.loss_fn)) + print(divider) + print("* Model:") + self.model_summary(trainable_col=True) + + # Lightning training hooks + + def on_train_start(self) -> None: + self.pred_conf = [] + self.pred_label_indices = [] + self.gt_label_indices = [] + + def on_train_epoch_start(self) -> None: + self._train_running_pred_conf = [] + self._train_running_pred_label_indices = [] + self._train_running_gt_label_indices = [] + + def training_step(self, batch, batch_idx: int) -> torch.Tensor: + if self.loss_fn is None: + raise ValueError( + "[ERROR] A loss function should be defined for training the model." + ) + + inputs, _, gt_label_indices = batch + inputs = tuple(inputs) + + if self.is_inception: + outputs, aux_outputs = self.model(*inputs) + if not isinstance(outputs, torch.Tensor): + outputs = self._get_logits(outputs) + if not isinstance(aux_outputs, torch.Tensor): + aux_outputs = self._get_logits(aux_outputs) + loss1 = self.loss_fn(outputs, gt_label_indices) + loss2 = self.loss_fn(aux_outputs, gt_label_indices) + # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958 + loss = loss1 + 0.4 * loss2 # calculate loss + else: + outputs = self.model(*inputs) + if not isinstance(outputs, torch.Tensor): + outputs = self._get_logits(outputs) + loss = self.loss_fn(outputs, gt_label_indices) # calculate loss + + _, pred_label_indices = torch.max(outputs, dim=1) + + self._train_running_pred_conf.extend( + torch.nn.functional.softmax(outputs, dim=1).cpu().tolist() + ) + self._train_running_pred_label_indices.extend(pred_label_indices.cpu().tolist()) + self._train_running_gt_label_indices.extend(gt_label_indices.cpu().tolist()) + + self.log( + "train_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + return loss + + def 
on_train_epoch_end(self) -> None: + if not self._train_running_pred_label_indices: + return + + epoch = self.current_epoch + self.last_epoch = epoch + + # Epoch loss is already aggregated (and synced across ranks) by self.log() + cb = self.trainer.callback_metrics + epoch_loss = float( + cb.get("train_loss_epoch", cb.get("train_loss", float("nan"))) + ) + self._add_metrics("train", "loss", epoch_loss) + + # Per-class metrics: computed per-rank (correct for single GPU/CPU). + # In multi-GPU each rank only sees its own shard — metrics are approximate. + if self.global_rank == 0: + self.calculate_add_metrics( + self._train_running_gt_label_indices, + self._train_running_pred_label_indices, + self._train_running_pred_conf, + "train", + epoch, + ) + + self.pred_conf.extend(self._train_running_pred_conf) + self.pred_label_indices.extend(self._train_running_pred_label_indices) + self.gt_label_indices.extend(self._train_running_gt_label_indices) + + # Lightning validation hooks + + def on_validation_epoch_start(self) -> None: + self._val_running_pred_conf = [] + self._val_running_pred_label_indices = [] + self._val_running_gt_label_indices = [] + + def validation_step(self, batch, batch_idx: int) -> torch.Tensor: + if self.loss_fn is None: + raise ValueError( + "[ERROR] A loss function should be defined for training the model." 
+ ) + + inputs, _, gt_label_indices = batch + inputs = tuple(inputs) + + outputs = self.model(*inputs) + if not isinstance(outputs, torch.Tensor): + outputs = self._get_logits(outputs) + loss = self.loss_fn(outputs, gt_label_indices) + + _, pred_label_indices = torch.max(outputs, dim=1) + + self._val_running_pred_conf.extend( + torch.nn.functional.softmax(outputs, dim=1).cpu().tolist() + ) + self._val_running_pred_label_indices.extend(pred_label_indices.cpu().tolist()) + self._val_running_gt_label_indices.extend(gt_label_indices.cpu().tolist()) + + self.log( + "val_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + return loss + + def on_validation_epoch_end(self) -> None: + if not self._val_running_pred_label_indices: + return + + epoch = self.current_epoch + + cb = self.trainer.callback_metrics + epoch_loss = float(cb.get("val_loss", float("nan"))) + self._add_metrics("val", "loss", epoch_loss) + + if self.global_rank == 0: + self.calculate_add_metrics( + self._val_running_gt_label_indices, + self._val_running_pred_label_indices, + self._val_running_pred_conf, + "val", + epoch, + ) + + if epoch_loss < self.best_loss: + self.best_loss = epoch_loss + self.best_epoch = epoch + + # Lightning predict hooks + + def on_predict_start(self) -> None: + self.pred_conf = [] + self.pred_label_indices = [] + + def predict_step(self, batch, batch_idx: int) -> dict: + inputs, _, _ = batch + inputs = tuple(inputs) + + outputs = self.model(*inputs) + if not isinstance(outputs, torch.Tensor): + outputs = self._get_logits(outputs) + + pred_conf = torch.nn.functional.softmax(outputs, dim=1) + _, pred_label_indices = torch.max(outputs, dim=1) + + # Accumulate on the object (works for CPU / single GPU). + # For multi-GPU, trainer.predict() gathers the returned dicts across ranks. 
+ self.pred_conf.extend(pred_conf.cpu().tolist()) + self.pred_label_indices.extend(pred_label_indices.cpu().tolist()) + + return { + "pred_conf": pred_conf.cpu(), + "pred_label_indices": pred_label_indices.cpu(), + } + + def on_predict_end(self) -> None: + self.pred_label = [ + self.labels_map.get(i, None) for i in self.pred_label_indices + ] + + def calculate_add_metrics( + self, + y_true, + y_pred, + y_score, + phase: str, + epoch: int | None = -1, + tboard_writer=None, + ) -> None: + """ + Calculate and add metrics to the classifier's metrics dictionary. + + Parameters + ---------- + y_true : 1d array-like of shape (n_samples,) + True binary labels or multiclass labels. Can be considered ground + truth or (correct) target values. + + y_pred : 1d array-like of shape (n_samples,) + Predicted binary labels or multiclass labels. The estimated + targets as returned by a classifier. + + y_score : array-like of shape (n_samples, n_classes) + Predicted probabilities for each class. + + phase : str + Name of the current phase, typically ``"train"`` or ``"val"``. See + ``train`` function. + + epoch : int, optional + Current epoch number. Default is ``-1``. + + tboard_writer : object, optional + TensorBoard SummaryWriter object to write the metrics. Default is + ``None``. + + Returns + ------- + None + + Notes + ----- + This method uses both the + :func:`sklearn.metrics.precision_recall_fscore_support` and + :func:`sklearn.metrics.roc_auc_score` functions from ``scikit-learn`` + to calculate the metrics for each average type (``"micro"``, + ``"macro"`` and ``"weighted"``). The results are then added to the + ``metrics`` dictionary. It also writes the metrics to the TensorBoard + SummaryWriter, if ``tboard_writer`` is not None. 
+ """ + # convert y_score to a numpy array: + if not isinstance(y_score, np.ndarray): + y_score = np.array(y_score) + + for average in [None, "micro", "macro", "weighted"]: + labels = list(range(y_score.shape[1])) if average is None else None + precision, recall, fscore, support = precision_recall_fscore_support( + y_true, y_pred, average=average, labels=labels + ) + + if average is None: + for i in range( + y_score.shape[1] + ): # y_score.shape[1] represents the number of classes + self._add_metrics(phase, f"precision_{i}", precision[i]) + self._add_metrics(phase, f"recall_{i}", recall[i]) + self._add_metrics(phase, f"fscore_{i}", fscore[i]) + self._add_metrics(phase, f"support_{i}", support[i]) + + if tboard_writer is not None: + tboard_writer.add_scalar( + f"Precision/{phase}/binary_{i}", + precision[i], + epoch, + ) + tboard_writer.add_scalar( + f"Recall/{phase}/binary_{i}", + recall[i], + epoch, + ) + tboard_writer.add_scalar( + f"Fscore/{phase}/binary_{i}", + fscore[i], + epoch, + ) + + else: # for micro, macro, weighted + self._add_metrics(phase, f"precision_{average}", precision) + self._add_metrics(phase, f"recall_{average}", recall) + self._add_metrics(phase, f"fscore_{average}", fscore) + self._add_metrics(phase, f"support_{average}", support) + + if tboard_writer is not None: + tboard_writer.add_scalar( + f"Precision/{phase}/{average}", + precision, + epoch, + ) + tboard_writer.add_scalar( + f"Recall/{phase}/{average}", + recall, + epoch, + ) + tboard_writer.add_scalar( + f"Fscore/{phase}/{average}", + fscore, + epoch, + ) + + # --- compute ROC AUC + if y_score.shape[1] == 2: + # ---- binary case + # From scikit-learn: + # The probability estimates correspond to the probability + # of the class with the greater label, i.e. 
+ # estimator.classes_[1] and thus + # estimator.predict_proba(X, y)[:, 1] + roc_auc = roc_auc_score(y_true, y_score[:, 1], average=average) + elif y_score.shape[1] > 2: + # ---- multiclass + # In the multiclass case, it corresponds to an array of shape + # (n_samples, n_classes) + # ovr = One-vs-rest (OvR) multiclass strategy + try: + roc_auc = roc_auc_score( + y_true, y_score, average=average, multi_class="ovr" + ) + except: + continue + else: + continue + self._add_metrics(phase, f"rocauc_{average}", roc_auc) + + def _add_metrics( + self, phase: str, metric: str, value: int | (float | (complex | np.number)) + ) -> None: + """ + Adds a metric value to a dictionary of metrics tracked during training. + + Parameters + ---------- + phase: str + The phase of the training (e.g., "train", "val") to which the + metric value corresponds. + metric : str + The name of the metric to add to the dictionary of metrics. + value : numeric + The metric value to add to the corresponding list of metric values. + + Returns + ------- + None + + Notes + ----- + If the key ``k`` does not exist in the dictionary of metrics, a new + key-value pair is created with ``k`` as the key and a new list + containing the value ``v`` as the value. If the key ``k`` already + exists in the dictionary of metrics, the value `v` is appended to the + list associated with the key ``k``. + """ + if phase not in self.metrics.keys(): + self.metrics[phase] = {} + if metric not in self.metrics[phase].keys(): + self.metrics[phase][metric] = [] + + self.metrics[phase][metric].append(value) + + def list_metrics(self, phases: str | list[str] = "all") -> None: + """Prints the available metrics for the specified phases. 
+ + Parameters + ---------- + phases : str | list[str], optional + The phases to find metrics for, by default "all" + """ + if isinstance(phases, str): + if phases == "all": + phases = [*self.metrics.keys()] + else: + phases = [phases] + + if not isinstance(phases, list): + raise ValueError( + '[ERROR] ``phases`` must be a string or a list of strings. E.g. ["train", "val"].' + ) + + metrics = set( + metric for phase in phases for metric in self.metrics.get(phase, {}).keys() + ) + + print("Phases:", *phases, sep="\n - ") + print("\nAvailable metrics:", *metrics, sep="\n - ") + + def plot_metric( + self, + metrics: str | list[str], + phases: str | list[str] = "all", + colors: list[str] | None = None, + figsize: tuple[int, int] = (10, 5), + plt_yrange: tuple[float, float] | None = None, + plt_xrange: tuple[float, float] | None = None, + ): + """ + Plot the metrics of the classifier object. + + Parameters + ---------- + metrics : str or list of str + A string of list of strings containing metric names to be plotted on the y-axis. + phases : str or list of str, optional + The phases for which the metric is to be plotted. Defaults to ``"all"``. + colors : list of str, optional + Colors to be used for the lines of each metric. Length must be at least the length of the number of phases being plotted (``phases``). If None, will use the default matplotlib colors. Defaults to ``None``. + figsize : tuple of int, optional + The size of the figure in inches. Defaults to ``(10, 5)``. + plt_yrange : tuple of float, optional + The range of values for the y-axis. Defaults to ``None``. + plt_xrange : tuple of float, optional + The range of values for the x-axis. Defaults to ``None``. 
+ """ + plt.figure(figsize=figsize) + + if isinstance(metrics, str): + metrics = [metrics] + if not isinstance(metrics, list): + raise TypeError("`metrics` must be a string or a list of strings.") + + # make list of colors iterable + if colors: + colors = iter(colors) + + for metric_name in metrics: + if isinstance(phases, str): + if phases == "all": + phases = [*self.metrics.keys()] + else: + phases = [phases] + if not isinstance(phases, list): + raise TypeError("`phases` must be a string or a list of strings.") + + for phase in phases: + if metric_name not in self.metrics[phase].keys(): + raise KeyError( + f"Metric {metric_name} not found in {phase} metrics." + ) + + # get color + if colors: + color = next(colors) + else: + color = plt.gca()._get_lines.get_next_color() + + plt.plot( + range( + 1, len(self.metrics[phase][metric_name]) + 1 + ), # i.e. epochs, starting from 1 + self.metrics[phase][metric_name], + label=f"{phase}_{metric_name}", + color=color, + linestyle="-", + marker="o", + linewidth=2, + ) + + # set labels and ticks + plt.xlabel("Epochs", size=24) + plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True)) + plt.ylabel(" | ".join(metrics), size=24) + plt.xticks(size=18) + plt.yticks(size=18) + + # legend + plt.legend( + fontsize=18, + bbox_to_anchor=(0, 1.02, 1, 0.2), + ncol=2, + borderaxespad=0, + loc="lower center", + ) + + # x/y range + if plt_xrange is not None: + plt.xlim(plt_xrange[0], plt_xrange[1]) + if plt_yrange is not None: + plt.ylim(plt_yrange[0], plt_yrange[1]) + + plt.grid() + plt.show() + + def print_batch_info(self, set_name: str | None = "train") -> None: + """ + Print information about a dataset's batches, samples, and batch-size. + + Parameters + ---------- + set_name : str, optional + Name of the dataset to display batch information for (default is + ``"train"``). 
+ + Returns + ------- + None + """ + if set_name not in self.dataloaders.keys(): + raise ValueError( + f"[ERROR] ``set_name`` must be one of {list(self.dataloaders.keys())}." + ) + + batch_size = self.dataloaders[set_name].batch_size + num_samples = len(self.dataloaders[set_name].dataset) + num_batches = int(np.ceil(num_samples / batch_size)) + + print( + f"[INFO] dataset: {set_name}\n\ + - items: {num_samples}\n\ + - batch size: {batch_size}\n\ + - batches: {num_batches}" + ) + + def show_inference_sample_results( + self, + label: str, + num_samples: int | None = 6, + set_name: str | None = "test", + min_conf: None | float | None = None, + max_conf: None | float | None = None, + figsize: tuple[int, int] | None = (15, 15), + ) -> None: + """ + Shows a sample of the results of the inference with current model. + + Parameters + ---------- + label : str, optional + The label for which to display results. + num_samples : int, optional + The number of sample results to display. Defaults to ``6``. + set_name : str, optional + The name of the dataset split to use for inference. Defaults to + ``"test"``. + min_conf : float, optional + The minimum confidence score for a sample result to be displayed. + Samples with lower confidence scores will be skipped. Defaults to + ``None``. + max_conf : float, optional + The maximum confidence score for a sample result to be displayed. + Samples with higher confidence scores will be skipped. Defaults to + ``None``. + figsize : tuple[int, int], optional + Figure size (width, height) in inches, displaying the sample + results. Defaults to ``(15, 15)``. + + Returns + ------- + None + """ + + # eval mode, keep track of the current mode + was_training = self.model.training + self.model.eval() + + if set_name not in self.dataloaders.keys(): + raise ValueError( + f"[ERROR] ``set_name`` must be one of {list(self.dataloaders.keys())}." 
+ ) + + counter = 0 + plt.figure(figsize=figsize) + + with torch.no_grad(): + for inputs, _, label_indices in iter(self.dataloaders[set_name]): + inputs = tuple(input.to(self.device) for input in inputs) + label_indices = label_indices.to(self.device) + + outputs = self.model(*inputs) + + if not isinstance(outputs, torch.Tensor): + outputs = self._get_logits(outputs) + + pred_conf = torch.nn.functional.softmax(outputs, dim=1) * 100.0 + _, preds = torch.max(outputs, 1) + + # reverse the labels_map dict + label_index_dict = {v: k for k, v in self.labels_map.items()} + + # go through images in batch + for i, pred in enumerate(preds): + predicted_index = int(pred) + if predicted_index != label_index_dict[label]: + continue + if (min_conf is not None) and ( + pred_conf[i][predicted_index] < min_conf + ): + continue + if (max_conf is not None) and ( + pred_conf[i][predicted_index] > max_conf + ): + continue + + counter += 1 + + conf_score = pred_conf[i][predicted_index] + ax = plt.subplot(int(num_samples / 2.0), 3, counter) + ax.axis("off") + ax.set_title(f"{label} | {conf_score:.3f}") + + inp = inputs[0].cpu().data[i].numpy().transpose((1, 2, 0)) + inp = np.clip(inp, 0, 1) + plt.imshow(inp) + + if counter == num_samples: + self.model.train(mode=was_training) + plt.show() + return + + self.model.train(mode=was_training) + plt.show() + + def save( + self, + save_path: str | None = "default.obj", + force: bool | None = False, + ) -> None: + """ + Save the object to a file. + + Parameters + ---------- + save_path : str, optional + The path to the file to write. + If the file already exists and ``force`` is not ``True``, a ``FileExistsError`` is raised. + Defaults to ``"default.obj"``. + force : bool, optional + Whether to overwrite the file if it already exists. Defaults to + ``False``. + + Raises + ------ + FileExistsError + If the file already exists and ``force`` is not ``True``. + + Notes + ----- + The object is saved in two parts. 
First, a serialized copy of the + object's dictionary is written to the specified file using the + ``joblib.dump`` function. The object's ``model`` attribute is excluded + from this dictionary and saved separately using the ``torch.save`` + function, with a filename derived from the original ``save_path``. + """ + if os.path.isfile(save_path): + if force: + os.remove(save_path) + else: + raise FileExistsError(f"[INFO] File already exists: {save_path}") + + # parent/base-names + par_name = os.path.dirname(os.path.abspath(save_path)) + base_name = os.path.basename(os.path.abspath(save_path)) + + # Extract model, write it separately using torch.save + # Exclude Lightning/PyTorch Module internals (stored in _modules etc.) + # and optimizer/scheduler (not portable across runs). + _skip_prefixes = ("_",) + obj2write = { + k: copy.deepcopy(v) + for k, v in self.__dict__.items() + if not any(k.startswith(p) for p in _skip_prefixes) + and k not in ("optimizer", "scheduler") + } + + os.makedirs(par_name, exist_ok=True) + with open(save_path, "wb") as myfile: + joblib.dump(obj2write, myfile) + + torch.save(self.model, os.path.join(par_name, f"model_{base_name}")) + torch.save( + self.model.state_dict(), + os.path.join(par_name, f"model_state_dict_{base_name}"), + ) + + def save_predictions( + self, + set_name: str, + save_path: str | None = None, + delimiter: str = ",", + ): + if set_name not in self.dataloaders.keys(): + raise ValueError( + f"[ERROR] ``set_name`` must be one of {list(self.dataloaders.keys())}." 
+ ) + + patch_df = self.dataloaders[set_name].dataset.patch_df + patch_df["predicted_label"] = self.pred_label + patch_df["pred"] = self.pred_label_indices + patch_df["conf"] = np.array(self.pred_conf).max(axis=1) + + if save_path is None: + save_path = f"{set_name}_predictions_patch_df.csv" + patch_df.to_csv(save_path, sep=delimiter) + print(f"[INFO] Saved predictions to {save_path}.") + + def load_dataset( + self, + dataset: PatchDataset, + set_name: str, + batch_size: int | None = 16, + sampler: Sampler | None | None = None, + shuffle: bool | None = False, + num_workers: int | None = 0, + **kwargs, + ) -> None: + """Creates a DataLoader from a PatchDataset and adds it to the ``dataloaders`` dictionary. + + Parameters + ---------- + dataset : PatchDataset + The dataset to add + set_name : str + The name to use for the dataset + batch_size : Optional[int], optional + The batch size to use when creating the DataLoader, by default 16 + sampler : Optional[Union[Sampler, None]], optional + The sampler to use when creating the DataLoader, by default None + shuffle : Optional[bool], optional + Whether to shuffle the PatchDataset, by default False + num_workers : Optional[int], optional + The number of worker threads to use for loading data, by default 0. + """ + if sampler and shuffle: + print("[INFO] ``sampler`` is defined so train dataset will be unshuffled.") + + dataloader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + shuffle=shuffle, + num_workers=num_workers, + **kwargs, + ) + + self.dataloaders[set_name] = dataloader + + def load( + self, + load_path: str, + force_device: bool | None = False, + ) -> None: + """ + This function loads the state of a class instance from a saved file + using the joblib library. It also loads a PyTorch model from a + separate file and maps it to the device used to load the class + instance. + + Parameters + ---------- + load_path : str + Path to the saved file to load. 
+ force_device : bool or str, optional + Whether to force the use of a specific device, or the name of the + device to use. If set to ``True``, the default device is used. + Defaults to ``False``. + + Raises + ------ + FileNotFoundError + If the specified file does not exist. + + Returns + ------- + None + """ + + load_path = os.path.abspath(load_path) + mydevice = self.device + + if not os.path.isfile(load_path): + raise FileNotFoundError(f'[ERROR] "{load_path}" cannot be found.') + + print(f'[INFO] Loading "{load_path}".') + + with open(load_path, "rb") as myfile: + # objPickle = pickle.load(myfile) + objPickle = joblib.load(myfile) + + # Skip read-only properties inherited from nn.Module/LightningModule + # (e.g. 'device', 'dtype') — Lightning manages these itself. + _readonly = {"device", "dtype"} + for k, v in objPickle.items(): + if k not in _readonly: + setattr(self, k, v) + + if force_device: + if not isinstance(force_device, str): + force_device = str(force_device) + os.environ["CUDA_VISIBLE_DEVICES"] = force_device + + par_name = os.path.dirname(load_path) + base_name = os.path.basename(load_path) + path2model = os.path.join(par_name, f"model_{base_name}") + self.model = torch.load(path2model, map_location=mydevice, weights_only=False) + + try: + self.device = mydevice + self.model = self.model.to(mydevice) + except: + pass + + @staticmethod + def _get_logits(out): + try: + out = out.logits + except AttributeError as err: + raise AttributeError(str(err)) + return out diff --git a/setup.py b/setup.py index f1219d11..da7c65bd 100644 --- a/setup.py +++ b/setup.py @@ -67,6 +67,9 @@ "lxml", ], extras_require={ + "lightning": [ + "lightning>=2.0", + ], "dev": [ "pytest<9.0.0", "pytest-cov>=4.1.0,<6.0.0", diff --git a/tests/test_classify/test_classifier.py b/tests/test_classify/test_classifier.py index c601dc8e..23eb2f6c 100644 --- a/tests/test_classify/test_classifier.py +++ b/tests/test_classify/test_classifier.py @@ -86,6 +86,14 @@ def 
test_init_models_string_errors(inputs): ) +def test_init_inception_input_size(inputs): + """Regression: inception_v3 must set input_size to (299, 299).""" + annots, _ = inputs + classifier = ClassifierContainer("inception_v3", labels_map=annots.labels_map) + assert classifier.input_size == (299, 299) + assert classifier.is_inception is True + + # test loading model (e.g. resnet18) using torch load @@ -254,6 +262,15 @@ def test_initialize_optimizer(load_classifier): assert isinstance(classifier.optimizer, torch.optim.Adam) +def test_generate_layerwise_lrs_uses_lr_key(load_classifier): + """Regression: param groups must use 'lr' not 'learning rate'.""" + classifier = load_classifier + params2optimize = classifier.generate_layerwise_lrs(min_lr=1e-4, max_lr=1e-3) + for group in params2optimize: + assert "lr" in group, "param group missing 'lr' key" + assert "learning rate" not in group, "param group has wrong key 'learning rate'" + + def test_initialize_scheduler(load_classifier): classifier = load_classifier classifier.initialize_optimizer() @@ -286,10 +303,21 @@ def test_save(load_classifier, tmp_path): assert os.path.isfile(f"{tmp_path}/model_out.obj") +def test_save_load_roundtrip(inputs, tmp_path): + """Save a fresh classifier and load it back; check model and labels_map survive.""" + annots, _ = inputs + classifier = ClassifierContainer("resnet18", labels_map=annots.labels_map) + classifier.save(save_path=f"{tmp_path}/rt.obj") + loaded = ClassifierContainer(None, None, load_path=f"{tmp_path}/rt.obj") + assert isinstance(loaded.model, models.ResNet) + assert loaded.labels_map == annots.labels_map + + def test_load_dataset(load_classifier, sample_dir): classifier = load_classifier dataset = PatchDataset(f"{sample_dir}/test_annots_append.csv", "test") classifier.load_dataset(dataset, "pytest_set", batch_size=8, shuffle=True) + assert "pytest_set" in classifier.dataloaders # errors diff --git a/tests/test_classify/test_lightning_classifier.py 
b/tests/test_classify/test_lightning_classifier.py new file mode 100644 index 00000000..44aaf789 --- /dev/null +++ b/tests/test_classify/test_lightning_classifier.py @@ -0,0 +1,402 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +import torch +from lightning.pytorch import Trainer +from torchvision import models + +from mapreader import AnnotationsLoader +from mapreader.classify.datasets import PatchDataset +from mapreader.classify.lightning_classifier import LightningClassifierContainer + + +@pytest.fixture +def sample_dir(): + return Path(__file__).resolve().parent.parent / "sample_files" + + +@pytest.fixture +@pytest.mark.dependency(depends=["load_annots_csv", "dataloaders"], scope="session") +def inputs(sample_dir): + annots = AnnotationsLoader() + annots.load( + f"{sample_dir}/test_annots.csv", + remove_broken=False, + ignore_broken=True, + ) + dataloaders = annots.create_dataloaders(batch_size=8) + return annots, dataloaders + + +@pytest.fixture +def infer_inputs(sample_dir): + infer_dict = { + "image_id": ["cropped_74488689.png"], + "image_path": [f"{sample_dir}/cropped_74488689.png"], + } + infer_df = pd.DataFrame.from_dict(infer_dict, orient="columns") + infer = PatchDataset(infer_df, transform="val") + return infer + + +@pytest.fixture +def load_classifier(sample_dir): + classifier = LightningClassifierContainer( + None, None, None, load_path=f"{sample_dir}/test.pkl" + ) + return classifier + + +@pytest.fixture +def ready_classifier(sample_dir): + """Classifier with loss/optimizer/scheduler set up, ready for training.""" + classifier = LightningClassifierContainer( + None, None, None, load_path=f"{sample_dir}/test.pkl" + ) + classifier.add_loss_fn("cross entropy") + classifier.initialize_optimizer("adam") + classifier.initialize_scheduler() + return classifier + + +# test loading model using model name as string + + +@pytest.mark.dependency(name="lc_models_by_string", 
scope="session")
# NOTE(review): the line above is the tail of a decorator that begins before this
# chunk — presumably @pytest.mark.dependency(name="lc_models_by_string", ...),
# given the depends list on test_inference below. TODO confirm against full file.
def test_init_models_string(inputs):
    """Initialize LightningClassifierContainer from each supported torchvision
    model-name string and check the resulting model type and dataloaders."""
    annots, dataloaders = inputs
    for model2test in [
        ["resnet18", models.ResNet],
        ["alexnet", models.AlexNet],
        ["vgg11", models.VGG],
        ["squeezenet1_0", models.SqueezeNet],
        ["densenet121", models.DenseNet],
        ["inception_v3", models.Inception3],
    ]:
        model, model_type = model2test
        # with dataloaders: all three standard splits should be registered
        classifier = LightningClassifierContainer(
            model, labels_map=annots.labels_map, dataloaders=dataloaders
        )
        assert isinstance(classifier.model, model_type)
        assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"])

        # without dataloaders: the container starts with an empty dict
        classifier = LightningClassifierContainer(model, labels_map=annots.labels_map)
        assert isinstance(classifier.model, model_type)
        assert classifier.dataloaders == {}


def test_init_models_string_errors(inputs):
    """An unsupported model-name string raises NotImplementedError."""
    annots, dataloaders = inputs
    with pytest.raises(NotImplementedError, match="Invalid model name"):
        LightningClassifierContainer(
            "resnext101_32x8d", labels_map=annots.labels_map, dataloaders=dataloaders
        )


def test_init_inception_input_size(inputs):
    """Regression: inception_v3 must set input_size to (299, 299), not 299."""
    annots, _ = inputs
    classifier = LightningClassifierContainer(
        "inception_v3", labels_map=annots.labels_map
    )
    assert classifier.input_size == (299, 299)
    assert classifier.is_inception is True


# test loading model (e.g. resnet18) using torch load


def test_init_resnet18_torch(inputs):
    """Initialize the container from an already-constructed nn.Module
    (a torchvision resnet18 with its final layer resized to the label count)."""
    annots, dataloaders = inputs
    my_model = models.resnet18(weights="DEFAULT")
    num_input_features = my_model.fc.in_features
    # replace the classification head to match the number of labels
    my_model.fc = torch.nn.Linear(num_input_features, len(annots.labels_map))
    classifier = LightningClassifierContainer(
        my_model, labels_map=annots.labels_map, dataloaders=dataloaders
    )
    assert isinstance(classifier.model, models.ResNet)
    assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"])

    classifier = LightningClassifierContainer(my_model, labels_map=annots.labels_map)
    assert isinstance(classifier.model, models.ResNet)
    assert classifier.dataloaders == {}


# test loading object from pickle file


def test_load_no_dataloaders(inputs, sample_dir):
    """Load a pickled container via load_path; dataloaders, labels_map and
    model should all be restored from the pickle."""
    annots, dataloaders = inputs
    classifier = LightningClassifierContainer(
        None, None, None, load_path=f"{sample_dir}/test.pkl"
    )
    assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"])
    assert classifier.labels_map == annots.labels_map
    assert isinstance(classifier.model, models.ResNet)

    # without explicitly passing dataloaders as None
    classifier = LightningClassifierContainer(
        None, None, load_path=f"{sample_dir}/test.pkl"
    )
    assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"])
    assert classifier.labels_map == annots.labels_map
    assert isinstance(classifier.model, models.ResNet)


def test_load_w_dataloaders(inputs, sample_dir):
    """Loading from a pickle while also passing dataloaders should merge
    the passed ones with those restored from the pickle."""
    annots, dataloaders = inputs
    dataloaders["new_train"] = dataloaders.pop("train")
    dataloaders["new_val"] = dataloaders.pop("val")
    dataloaders["new_test"] = dataloaders.pop("test")

    classifier = LightningClassifierContainer(
        None, None, dataloaders=dataloaders, load_path=f"{sample_dir}/test.pkl"
    )
    # both the pickled splits and the renamed, explicitly-passed splits survive
    assert all(
        k in classifier.dataloaders.keys()
        for k in ["train", "test", "val", "new_train", "new_test", "new_val"]
    )
    assert classifier.labels_map == annots.labels_map
    assert isinstance(classifier.model, models.ResNet)


def test_init_load(inputs, load_classifier):
    """Sanity-check the shared load_classifier fixture itself."""
    annots, dataloaders = inputs
    classifier = load_classifier
    assert all(k in classifier.dataloaders.keys() for k in ["train", "test", "val"])
    assert classifier.labels_map == annots.labels_map
    assert isinstance(classifier.model, models.ResNet)


def test_add_loss_fn(load_classifier):
    """add_loss_fn accepts both a string name and a ready-made loss module."""
    classifier = load_classifier
    classifier.add_loss_fn("bce")
    assert isinstance(classifier.loss_fn, torch.nn.BCELoss)
    loss_fn = torch.nn.L1Loss()
    classifier.add_loss_fn(loss_fn)
    assert isinstance(classifier.loss_fn, torch.nn.L1Loss)


def test_initialize_optimizer(load_classifier):
    """initialize_optimizer accepts a name, optionally with layerwise LR groups."""
    classifier = load_classifier
    classifier.initialize_optimizer("sgd")
    assert isinstance(classifier.optimizer, torch.optim.SGD)

    params2optimize = classifier.generate_layerwise_lrs(
        min_lr=1e-4, max_lr=1e-3, spacing="geomspace"
    )
    classifier.initialize_optimizer("adam", params2optimize)
    assert isinstance(classifier.optimizer, torch.optim.Adam)


def test_generate_layerwise_lrs_uses_lr_key(load_classifier):
    """Regression: param groups must use 'lr' not 'learning rate'."""
    classifier = load_classifier
    params2optimize = classifier.generate_layerwise_lrs(min_lr=1e-4, max_lr=1e-3)
    for group in params2optimize:
        assert "lr" in group, "param group missing 'lr' key"
        assert "learning rate" not in group, "param group has wrong key 'learning rate'"


def test_initialize_scheduler(load_classifier):
    """Scheduler params passed via scheduler_param_dict reach the StepLR."""
    classifier = load_classifier
    classifier.initialize_optimizer()
    classifier.initialize_scheduler(
        scheduler_param_dict={"step_size": 5, "gamma": 0.02}
    )
    assert isinstance(classifier.scheduler, torch.optim.lr_scheduler.StepLR)
    assert classifier.scheduler.step_size == 5
    assert classifier.scheduler.gamma == 0.02


def test_calculate_add_metrics(load_classifier):
    """calculate_add_metrics records one value per metric/averaging combination
    under the given phase name."""
    classifier = load_classifier
    y_true = np.ones(10)
    np.random.seed(0)  # fixed seed so the fake predictions are reproducible
    y_pred = np.random.randint(0, 2, 10)
    y_score = np.random.random_sample((10, 1))
    classifier.calculate_add_metrics(y_true, y_pred, y_score, phase="pytest")
    assert "pytest" in classifier.metrics.keys()
    for metric in ["precision", "recall", "fscore", "support"]:
        for suffix in ["0", "micro", "macro", "weighted"]:
            assert f"{metric}_{suffix}" in classifier.metrics["pytest"].keys()
            assert len(classifier.metrics["pytest"][f"{metric}_{suffix}"]) == 1


def test_save(load_classifier, tmp_path):
    """save() writes both the container pickle and a separate model file
    (the 'model_' prefixed sibling)."""
    classifier = load_classifier
    classifier.save(save_path=f"{tmp_path}/out.obj")
    assert os.path.isfile(f"{tmp_path}/out.obj")
    assert os.path.isfile(f"{tmp_path}/model_out.obj")


def test_save_load_roundtrip(inputs, tmp_path):
    """Save a fresh classifier and load it back; check model and labels_map survive."""
    annots, _ = inputs
    classifier = LightningClassifierContainer("resnet18", labels_map=annots.labels_map)
    classifier.save(save_path=f"{tmp_path}/rt.obj")

    loaded = LightningClassifierContainer(None, None, load_path=f"{tmp_path}/rt.obj")
    assert isinstance(loaded.model, models.ResNet)
    assert loaded.labels_map == annots.labels_map


def test_load_dataset(load_classifier, sample_dir):
    """load_dataset wraps a PatchDataset in a dataloader under the given name."""
    classifier = load_classifier
    dataset = PatchDataset(f"{sample_dir}/test_annots_append.csv", "test")
    classifier.load_dataset(dataset, "pytest_set", batch_size=8, shuffle=True)
    assert "pytest_set" in classifier.dataloaders


# errors


def test_init_errors(sample_dir):
    """Constructing with a model but no labels_map raises ValueError."""
    # NOTE(review): sample_dir is requested but unused here — confirm whether
    # the fixture is needed for setup side effects or can be dropped.
    with pytest.raises(
        ValueError, match="``model`` and ``labels_map`` must be defined"
    ):
        LightningClassifierContainer("resnet18", None, None)


def test_loss_fn_errors(load_classifier):
    """Unknown loss-name strings and non-loss objects are rejected."""
    classifier = load_classifier
    with pytest.raises(NotImplementedError, match="loss function can only be"):
        classifier.add_loss_fn("a fake loss_fn")
    with pytest.raises(ValueError, match="Please pass"):
        classifier.add_loss_fn(0.01)


def test_optimizer_errors(load_classifier):
    """Unknown optimizer names and LR spacings are rejected."""
    classifier = load_classifier
    with pytest.raises(NotImplementedError, match="At present, only"):
        classifier.initialize_optimizer("a fake optimizer")
    with pytest.raises(NotImplementedError, match="must be one of"):
        classifier.generate_layerwise_lrs(1e-4, 1e-3, "a fake spacing")


def test_scheduler_errors(load_classifier):
    """A scheduler needs an optimizer first, and only StepLR is supported."""
    classifier = load_classifier
    with pytest.raises(ValueError, match="not yet defined"):
        classifier.initialize_scheduler()
    classifier.initialize_optimizer()
    with pytest.raises(NotImplementedError, match="only StepLR"):
        classifier.initialize_scheduler("a fake scheduler type")


# test configure_optimizers (Lightning hook)


def test_configure_optimizers_no_scheduler(load_classifier):
    """Without a scheduler, configure_optimizers returns the bare optimizer."""
    classifier = load_classifier
    classifier.initialize_optimizer("adam")
    result = classifier.configure_optimizers()
    assert isinstance(result, torch.optim.Adam)


def test_configure_optimizers_with_scheduler(load_classifier):
    """With a scheduler, configure_optimizers returns the Lightning dict form."""
    classifier = load_classifier
    classifier.initialize_optimizer("adam")
    classifier.initialize_scheduler()
    result = classifier.configure_optimizers()
    assert isinstance(result, dict)
    assert "optimizer" in result
    assert "lr_scheduler" in result
    assert isinstance(result["optimizer"], torch.optim.Adam)


def test_configure_optimizers_no_optimizer(load_classifier):
    """configure_optimizers raises if no optimizer has been set up."""
    classifier = load_classifier
    classifier.optimizer = None
    with pytest.raises(ValueError, match="optimizer should be defined"):
        classifier.configure_optimizers()


# test inference


@pytest.mark.dependency(depends=["lc_models_by_string"], scope="session")
def test_inference(inputs, infer_inputs):
    """Smoke-test the container's own inference() path on a loaded dataset."""
    annots, dataloaders = inputs
    classifier = LightningClassifierContainer(
        "resnet18", labels_map=annots.labels_map, dataloaders=dataloaders
    )
    classifier.add_loss_fn()
    classifier.initialize_optimizer()
    classifier.initialize_scheduler()
    classifier.load_dataset(infer_inputs, set_name="infer")
    classifier.inference("infer")


# test train
def test_training_step(inputs, sample_dir):
    """Smoke-test: Trainer.fit() with fast_dev_run=True runs one batch without error."""
    from torch.utils.data import DataLoader

    annots, _ = inputs

    # Build a labelled dataset from the images known to exist in sample_files.
    # The inputs fixture dataloaders have absolute paths that are only valid locally.
    # One image reused under both labels is enough for a single fast_dev_run batch.
    train_df = pd.DataFrame(
        {
            "image_id": ["cropped_74488689.png", "cropped_74488689.png"],
            "image_path": [
                f"{sample_dir}/cropped_74488689.png",
                f"{sample_dir}/cropped_74488689.png",
            ],
            "label": ["no", "railspace"],
            "label_index": [0, 1],
        }
    )
    train_dataset = PatchDataset(
        train_df, transform="train", label_col="label", label_index_col="label_index"
    )
    train_loader = DataLoader(train_dataset, batch_size=2)

    classifier = LightningClassifierContainer("resnet18", labels_map=annots.labels_map)
    classifier.add_loss_fn("cross entropy")
    classifier.initialize_optimizer("adam")
    classifier.initialize_scheduler()

    # CPU-only, logger/summary disabled: keeps the CI run fast and output-free;
    # fast_dev_run limits fit() to a single train + val batch.
    trainer = Trainer(
        max_epochs=1,
        fast_dev_run=True,
        accelerator="cpu",
        enable_progress_bar=False,
        enable_model_summary=False,
        logger=False,
    )
    trainer.fit(
        classifier,
        train_dataloaders=train_loader,
        val_dataloaders=train_loader,
    )


def test_predict_step(inputs, infer_inputs):
    """Smoke-test: Trainer.predict() populates pred_label."""
    annots, dataloaders = inputs
    classifier = LightningClassifierContainer("resnet18", labels_map=annots.labels_map)
    classifier.add_loss_fn("cross entropy")
    classifier.initialize_optimizer("adam")
    classifier.initialize_scheduler()

    from torch.utils.data import DataLoader

    infer_loader = DataLoader(infer_inputs, batch_size=1)

    trainer = Trainer(
        accelerator="cpu",
        enable_progress_bar=False,
        enable_model_summary=False,
        logger=False,
    )
    trainer.predict(classifier, dataloaders=infer_loader)
    # After predict, predictions should have been collected
    # (either decoded labels or raw label indices, depending on the container).
    assert len(classifier.pred_label) > 0 or len(classifier.pred_label_indices) > 0