diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py
new file mode 100644
index 0000000..3ae48cd
--- /dev/null
+++ b/ahcore/callbacks/log_images_callback.py
@@ -0,0 +1,205 @@
+from typing import Any, List, Optional, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+from matplotlib.colors import to_rgb
+
+from ahcore.metrics.metrics import _compute_dice
+from ahcore.utils.callbacks import AhCoreLogger
+from ahcore.utils.types import GenericNumberArray, ScannerVendors
+
+
+def get_sample_wise_dice_components(
+    predictions: torch.Tensor,
+    target: torch.Tensor,
+    roi: torch.Tensor | None,
+    num_classes: int,
+) -> List[List[Tuple[torch.Tensor, torch.Tensor]]]:
+    """Compute per-sample, per-class Dice components.
+
+    Parameters
+    ----------
+    predictions : logits of shape (B, C, H, W); softmax + argmax is applied here.
+    target : one-hot encoded target of shape (B, C, H, W).
+    roi : optional mask of shape (B, 1, H, W) restricting the computation.
+    num_classes : number of classes C.
+
+    Returns
+    -------
+    A list of length B; each entry is a list of length C holding the
+    (intersection, cardinality) tensors for that class.
+    """
+    # Apply softmax along the class dimension, then argmax to get the
+    # predicted class per pixel.
+    soft_predictions = F.softmax(predictions, dim=1)
+    predictions = soft_predictions.argmax(dim=1)
+
+    # Get the target class per pixel.
+    _target = target.argmax(dim=1)
+
+    # Initialize list to store dice components for each sample
+    dice_components = []
+
+    for batch_idx in range(predictions.size(0)):  # Loop through the batch dimension (B)
+        batch_dice_components = []
+
+        for class_idx in range(num_classes):  # Loop through the classes
+            # Binary masks for the current class and sample.
+            curr_predictions = (predictions[batch_idx] == class_idx).int()
+            curr_target = (_target[batch_idx] == class_idx).int()
+
+            if roi is not None:
+                # Restrict both masks to the region of interest.
+                curr_roi = roi[batch_idx].squeeze(0)  # Adjust ROI for current sample
+                intersection = torch.sum((curr_predictions * curr_target) * curr_roi, dim=(0, 1))
+                cardinality = torch.sum(curr_predictions * curr_roi, dim=(0, 1)) + torch.sum(
+                    curr_target * curr_roi, dim=(0, 1)
+                )
+            else:
+                # No ROI, just compute intersection and cardinality
+                intersection = torch.sum(curr_predictions * curr_target, dim=(0, 1))
+                cardinality = torch.sum(curr_predictions, dim=(0, 1)) + torch.sum(curr_target, dim=(0, 1))
+
+            batch_dice_components.append((intersection, cardinality))
+
+        dice_components.append(batch_dice_components)
+
+    return dice_components
+
+
+def get_sample_wise_dice(
+    outputs: dict[str, Any], batch: dict[str, Any], roi: torch.Tensor | None, num_classes: int
+) -> List[dict[int, float]]:
+    """Return, per sample, a dict mapping each foreground class index to its
+    Dice score. The background class (index 0) is skipped."""
+    dices = []
+    dice_components = get_sample_wise_dice_components(outputs["prediction"], batch["target"], roi, num_classes)
+    for samplewise_dice_components in dice_components:
+        # Fix: build the dict from num_classes instead of hard-coding classes 1-3.
+        classwise_dices_for_a_sample: dict[int, float] = {c: 0.0 for c in range(1, num_classes)}
+        for class_idx in range(1, num_classes):
+            intersection, cardinality = samplewise_dice_components[class_idx]
+            dice_score = _compute_dice(intersection, cardinality)
+            classwise_dices_for_a_sample[class_idx] = dice_score.item()
+        dices.append(classwise_dices_for_a_sample)
+    return dices
+
+
+def _extract_scanner_name(path: str) -> str:
+    """Map a slide file path to a scanner vendor name via its extension."""
+    # Extract file extension
+    extension = path.split(".")[-1]
+    # Use the ScannerEnum to get the scanner name
+    return ScannerVendors.get_vendor_name(extension)
+
+
+class LogImagesCallback(pl.Callback):
+    """Logs validation images, ground truth and predictions (optionally with
+    per-class Dice scores and per-scanner grouping) every N epochs."""
+
+    def __init__(
+        self,
+        color_map: dict[int, str],
+        num_classes: int,
+        plot_dice: bool = True,
+        plot_scanner_wise: bool = False,
+        plot_every_n_epochs: int = 10,
+    ):
+        super().__init__()
+        # Matplotlib color names -> RGB arrays scaled to [0, 255].
+        self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()}
+        self._num_classes = num_classes
+        self._already_seen_scanner: List[str] = []
+        self._plot_scanner_wise = plot_scanner_wise
+        self._plot_every_n_epochs = plot_every_n_epochs
+        self._plot_dice = plot_dice
+        # Fix: Optional[AhCoreLogger | None] was redundant.
+        self._logger: Optional[AhCoreLogger] = None
+
+    def on_validation_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: dict[str, Any],  # type: ignore
+        batch: dict[str, Any],
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if self._logger is None:
+            self._logger = AhCoreLogger(pl_module.logger)
+        if trainer.current_epoch % self._plot_every_n_epochs == 0:
+            val_images = batch["image"]
+            roi = batch["roi"]
+            # NOTE(review): assumes all samples in the batch come from the
+            # same scanner — only the first path is inspected.
+            path = batch["path"][0]
+
+            if self._plot_scanner_wise:
+                scanner_name = _extract_scanner_name(path)
+            else:
+                scanner_name = None
+
+            if self._plot_dice:
+                dices = get_sample_wise_dice(outputs, batch, roi, self._num_classes)
+            else:
+                dices = None
+
+            val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy()
+            # Min-max normalize for display. Fix: the epsilon guards against a
+            # constant-valued batch (max == min) causing a division by zero.
+            val_images_numpy = (val_images_numpy - val_images_numpy.min()) / (
+                val_images_numpy.max() - val_images_numpy.min() + 1e-8
+            )
+
+            val_predictions = outputs["prediction"]
+            val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy()
+
+            val_targets = batch["target"]
+            val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy()
+
+            # When plotting scanner-wise, only the first batch seen for each
+            # scanner is logged per epoch; otherwise scanner_name is None and
+            # every batch at this epoch is logged.
+            if scanner_name not in self._already_seen_scanner:
+                self._plot_and_log(
+                    val_images_numpy,
+                    val_predictions_numpy,
+                    val_targets_numpy,
+                    trainer.global_step,
+                    batch_idx,
+                    scanner_name,
+                    dices,
+                )
+
+                if self._plot_scanner_wise and scanner_name is not None:
+                    self._already_seen_scanner.append(scanner_name)
+
+    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        # Allow each scanner to be plotted again in the next validation epoch.
+        self._already_seen_scanner = []
+
+    def _plot_and_log(
+        self,
+        images_numpy: GenericNumberArray,
+        predictions_numpy: GenericNumberArray,
+        targets_numpy: GenericNumberArray,
+        step: int,
+        batch_idx: int,
+        scanner_name: Optional[str] = None,
+        dices: Optional[List[dict[int, float]]] = None,
+    ) -> None:
+        """Render a (batch_size x 3) grid of image / ground-truth / prediction
+        panels and send the figure to the experiment logger."""
+        class_wise_dice = None
+        batch_size = images_numpy.shape[0]
+        figure = plt.figure(figsize=(15, batch_size * 5))
+        for i in range(batch_size):
+            if dices is not None:
+                class_wise_dice = dices[i]
+            plt.subplot(batch_size, 3, i * 3 + 1)
+            plt.imshow(images_numpy[i])
+            plt.axis("off")
+            if scanner_name is not None:
+                plt.title(f"Original Image (Scanner: {scanner_name})")
+
+            plt.subplot(batch_size, 3, i * 3 + 2)
+            class_indices_gt = np.argmax(targets_numpy[i], axis=-1)
+            colored_img_gt = apply_color_map(class_indices_gt, self.color_map, self._num_classes)
+            plt.imshow(colored_img_gt)
+            plt.axis("off")
+
+            plt.subplot(batch_size, 3, i * 3 + 3)
+            class_indices_pred = np.argmax(predictions_numpy[i], axis=-1)
+            colored_img_pred = apply_color_map(class_indices_pred, self.color_map, self._num_classes)
+            if dices is not None and class_wise_dice is not None:
+                # Fix: comprehension variable renamed from `i` so it no longer
+                # shadows the outer sample index.
+                dice_values = " ".join([f"{class_wise_dice[c]:.2f}" for c in range(1, self._num_classes)])
+                plt.title(f"Dice: {dice_values}")
+            plt.imshow(colored_img_pred)
+            plt.axis("off")
+        plt.tight_layout()
+        if self._logger:  # This is for mypy
+            self._logger.log_figure(figure, step, batch_idx)
+        plt.close()
+
+
+def apply_color_map(
+    image: GenericNumberArray, color_map: dict[int, Any], num_classes: int
+) -> np.ndarray[Any, np.dtype[np.uint8]]:
+    """Colorize a (H, W) class-index map; the background class 0 stays black."""
+    colored_image = np.zeros((*image.shape, 3), dtype=np.uint8)
+    for class_idx in range(1, num_classes):
+        colored_image[image == class_idx] = color_map[class_idx]
+    return colored_image
diff --git a/ahcore/callbacks/scanner_tile_metrics_callback.py b/ahcore/callbacks/scanner_tile_metrics_callback.py
new file mode 100644
index 0000000..e0a34c9
--- /dev/null
+++ b/ahcore/callbacks/scanner_tile_metrics_callback.py
@@ -0,0 +1,73 @@
+from typing import Any, Optional
+
+import pytorch_lightning as pl
+
+from ahcore.metrics import TileMetric
+from ahcore.utils.callbacks import AhCoreLogger
+from ahcore.utils.types import ScannerVendors
+
+
+class ScannerTileMetricsCallback(pl.Callback):
+    """
+    This callback is used to track several `TileMetric` from ahcore per scanner.
+    The callback works on certain assumptions:
+    - You want to track metrics corresponding to each class in the `index_map`
+    - Each metric is tracked per scanner
+    """
+
+    def __init__(self, metrics: list[TileMetric], index_map: dict[str, int]) -> None:
+        super().__init__()
+        self.metrics = metrics
+        self.index_map = index_map
+        self._metrics_per_scanner = self._initial_metrics_state()
+        self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerVendors}
+        self._logger: Optional[AhCoreLogger] = None
+
+    def _initial_metrics_state(self) -> dict[str, dict[str, float]]:
+        # Zeroed accumulator: scanner -> "<metric>/<class>" -> 0.0. Shared by
+        # __init__ and the per-epoch reset so the two cannot drift apart.
+        return {
+            scanner.scanner_name: {
+                f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics
+            }
+            for scanner in ScannerVendors
+        }
+
+    def on_validation_batch_end(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        outputs: dict[str, Any],  # type: ignore
+        batch: dict[str, Any],
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> None:
+        if not self._logger:
+            self._logger = AhCoreLogger(pl_module.logger)
+
+        # NOTE(review): assumes a batch does not mix scanners — only the
+        # first path's extension is inspected.
+        path = batch["path"][0]
+        file_extension = path.split(".")[-1]
+        scanner_name = ScannerVendors.get_vendor_name(file_extension)
+
+        prediction = outputs["prediction"]
+        target = batch["target"]
+        roi = batch.get("roi", None)
+
+        for metric in self.metrics:
+            batch_metrics = metric(prediction, target, roi)
+            # Fix: iterate keys only; the class index values were unused.
+            for class_name in self.index_map:
+                metric_key = f"{metric.name}/{class_name}"
+                self._metrics_per_scanner[scanner_name][metric_key] += batch_metrics[metric_key].item()
+
+        self._batch_count_per_scanner[scanner_name] += 1
+
+    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        for scanner_name, metrics in self._metrics_per_scanner.items():
+            batch_count = self._batch_count_per_scanner[scanner_name]
+            if batch_count > 0:
+                if self._logger:  # This is for mypy
+                    averaged_metrics = {f"{scanner_name}/{key}": value / batch_count for key, value in metrics.items()}
+                    self._logger.log_metrics(averaged_metrics, step=trainer.global_step)
+
+        # Reset the accumulators for the next validation epoch.
+        self._metrics_per_scanner = self._initial_metrics_state()
+        self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerVendors}
diff --git a/ahcore/utils/callbacks.py b/ahcore/utils/callbacks.py
index 674bde7..71118fe 100644
--- a/ahcore/utils/callbacks.py
+++ b/ahcore/utils/callbacks.py
@@ -14,6 +14,8 @@
 from dlup.annotations import WsiAnnotations
 from dlup.data.transforms import convert_annotations, rename_labels
 from dlup.tiling import Grid, GridOrder, TilingMode
+from matplotlib.figure import Figure
+from pytorch_lightning.loggers import Logger
 from shapely.geometry import MultiPoint, Point
 from torch.utils.data import Dataset
@@ -21,7 +23,7 @@
 from ahcore.transforms.pre_transforms import one_hot_encoding
 from ahcore.utils.data import DataDescription
 from ahcore.utils.io import get_logger
-from ahcore.utils.types import DlupDatasetSample
+from ahcore.utils.types import DlupDatasetSample, LoggerEnum
 
 logger = get_logger(__name__)
@@ -241,3 +243,36 @@
     if counter is not None:
         return dump_dir / "outputs" / model_name / f"{counter}" / f"{hex_dig}.cache"
     return dump_dir / "outputs" / model_name / f"{hex_dig}.cache"
+
+
+class AhCoreLogger:
+    """Thin wrapper that dispatches figure/metric logging to either an
+    MLflow or a TensorBoard ``pytorch_lightning`` logger."""
+
+    def __init__(self, pl_logger: Logger | Any) -> None:
+        self.logger = pl_logger
+
+    def get_logger_type(self) -> LoggerEnum:
+        # Duck-type the underlying experiment object instead of importing
+        # the concrete logger classes.
+        if hasattr(self.logger.experiment, "log_figure"):
+            return LoggerEnum.MLFLOW
+        elif hasattr(self.logger.experiment, "add_figure"):
+            return LoggerEnum.TENSORBOARD
+        else:
+            return LoggerEnum.UNKNOWN
+
+    def log_figure(self, figure: Figure, step: int, batch_idx: int) -> None:
+        if self.get_logger_type() == LoggerEnum.MLFLOW:
+            artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png"
+            self.logger.experiment.log_figure(self.logger.run_id, figure, artifact_file=artifact_file_name)
+        elif self.get_logger_type() == LoggerEnum.TENSORBOARD:
+            self.logger.experiment.add_figure(f"validation_step_{step}_batch_{batch_idx}", figure, global_step=step)
+        else:
+            raise NotImplementedError(f"Logging method for logger {type(self.logger).__name__} not implemented.")
+
+    def log_metrics(self, metrics: dict[str, Any], step: int) -> None:
+        logger_type = self.get_logger_type()
+        if logger_type == LoggerEnum.MLFLOW:
+            for metric_name, value in metrics.items():
+                self.logger.experiment.log_metric(self.logger.run_id, key=metric_name, value=value, step=step)
+        elif logger_type == LoggerEnum.TENSORBOARD:
+            for metric_name, value in metrics.items():
+                self.logger.experiment.add_scalar(metric_name, value, global_step=step)
+        else:
+            raise NotImplementedError(f"Logging method for logger {type(self.logger).__name__} not implemented.")
diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py
index ba0b137..3da2bc7 100644
--- a/ahcore/utils/types.py
+++ b/ahcore/utils/types.py
@@ -86,3 +86,28 @@
 class ViTEmbedMode(str, Enum):
     CONCAT_MEAN = "embed_concat_mean"
     CONCAT = "embed_concat"
     # Extend as necessary
+
+
+class ScannerVendors(Enum):
+    """Maps slide file extensions to scanner vendor names."""
+
+    SVS = ("svs", "Aperio")
+    MRXS = ("mrxs", "3DHistech")
+    DEFAULT = ("default", "Unknown Scanner")
+
+    def __init__(self, extension: str, scanner_name: str) -> None:
+        self.extension = extension
+        self.scanner_name = scanner_name
+
+    @classmethod
+    def get_vendor_name(cls, file_extension: str) -> str:
+        # Fix: compare case-insensitively so e.g. ".SVS" still matches.
+        file_extension = file_extension.lower()
+        for scanner in cls:
+            if scanner.extension == file_extension:
+                return scanner.scanner_name
+        # Return a default value if extension is not found
+        return cls.DEFAULT.scanner_name
+
+
+class LoggerEnum(Enum):
+    TENSORBOARD = "tensorboard"
+    MLFLOW = "mlflow"
+    UNKNOWN = "unknown"
+    # Extend as necessary
diff --git a/config/callbacks/log_images_callback.yaml b/config/callbacks/log_images_callback.yaml
new file mode 100644
index 0000000..19c9d57
--- /dev/null
+++ b/config/callbacks/log_images_callback.yaml
@@ -0,0 +1,7 @@
+log_images_callback:
+  _target_: ahcore.callbacks.log_images_callback.LogImagesCallback
+  color_map: ${data_description.color_map}
+  num_classes: ${data_description.num_classes}
+  plot_scanner_wise: True
+  plot_every_n_epochs: 10
+  plot_dice: True
diff --git a/config/callbacks/scanner_tile_metrics.yaml b/config/callbacks/scanner_tile_metrics.yaml
new file mode 100644
index 0000000..49cf7cb
--- /dev/null
+++ b/config/callbacks/scanner_tile_metrics.yaml
@@ -0,0 +1,6 @@
+scanner_tile_metrics_callback:
+  _target_: ahcore.callbacks.scanner_tile_metrics_callback.ScannerTileMetricsCallback
+  metrics:
+    - _target_: ahcore.metrics.DiceMetric
+      data_description: ${data_description}
+  index_map: ${data_description.index_map}