From f58117daf94ae4a26dd917b6c02f417073030392 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 13 Sep 2024 10:45:54 +0200 Subject: [PATCH 01/24] implemented LogImagesCallback --- ahcore/callbacks/log_images_callback.py | 75 +++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 ahcore/callbacks/log_images_callback.py diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py new file mode 100644 index 0000000..b35471b --- /dev/null +++ b/ahcore/callbacks/log_images_callback.py @@ -0,0 +1,75 @@ +import pytorch_lightning as pl +import matplotlib.pyplot as plt +import numpy as np +import io +from PIL import Image +from matplotlib.colors import to_rgb + + +class LogImagesCallback(pl.Callback): + def __init__(self, color_map): + super().__init__() + self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()} + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: + val_images = batch['image'] + val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() + val_images_numpy = (val_images_numpy - val_images_numpy.min()) / (val_images_numpy.max() - val_images_numpy.min()) + + val_predictions = outputs['prediction'] + val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() + + val_targets = batch['target'] + val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() + + self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, batch_idx) + + def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx) -> None: + batch_size = images_numpy.shape[0] + plt.figure(figsize=(15, batch_size * 5)) # Adjust the figure size to fit the grid + + for i in range(batch_size): + # Plot the original image + plt.subplot(batch_size, 3, i * 3 + 1) + plt.imshow(images_numpy[i]) + plt.title("Original Image") + plt.axis("off") + + # Plot the ground truth mask + plt.subplot(batch_size, 3, i * 3 + 2) + class_indices_gt = np.argmax(targets_numpy[i], axis=-1) + colored_img_gt = apply_color_map(class_indices_gt, self.color_map) # Apply color map to ground truth + plt.imshow(colored_img_gt) + plt.title("Ground Truth Mask") + plt.axis("off") + + # Plot the prediction mask + plt.subplot(batch_size, 3, i * 3 + 3) + class_indices_pred = np.argmax(predictions_numpy[i], axis=-1) + colored_img_pred = apply_color_map(class_indices_pred, self.color_map) # Apply color map to prediction + plt.imshow(colored_img_pred) + plt.title("Prediction Mask") + plt.axis("off") + + # Save the figure to a buffer + buf = io.BytesIO() + plt.savefig(buf, format="png") + plt.close() + buf.seek(0) + image_grid = Image.open(buf) + + # Log the image using the logger's experiment interface and mlflow log_image + artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" + pl_module.logger.experiment.log_image( + pl_module.logger.run_id, + image_grid, + artifact_file=artifact_file_name + ) + buf.close() + + +def apply_color_map(image, color_map): + colored_image = np.zeros((*image.shape, 3), dtype=np.uint8) + for i in range(1, 4): # Assuming classes are 1, 2, 3 + colored_image[image == i] = color_map[i] + return colored_image From 9066617d595ef6edfa4528eaa4e3421c03940182 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 13 Sep 2024 13:46:19 +0200 Subject: [PATCH 02/24] Remove unnecessary buffer --- ahcore/callbacks/log_images_callback.py | 27 ++++++++----------------- 1 file changed, 8 insertions(+), 19 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index b35471b..b432046 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -1,8 +1,6 @@ import pytorch_lightning as pl import matplotlib.pyplot as plt import numpy as np -import io -from PIL import Image from matplotlib.colors import to_rgb @@ -14,7 +12,8 @@ def __init__(self, color_map): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: val_images = batch['image'] val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() - val_images_numpy = (val_images_numpy - val_images_numpy.min()) / (val_images_numpy.max() - val_images_numpy.min()) + val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( + val_images_numpy.max() - val_images_numpy.min()) val_predictions = outputs['prediction'] val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() @@ -22,17 +21,17 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, val_targets = batch['target'] val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() - self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, batch_idx) + self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, + batch_idx) def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx) -> None: batch_size = images_numpy.shape[0] - plt.figure(figsize=(15, batch_size * 5)) # Adjust the figure size to fit the grid + figure = plt.figure(figsize=(15, batch_size * 5)) # Adjust the figure size to fit the grid for i in range(batch_size): # Plot the original image plt.subplot(batch_size, 3, i * 3 + 1) plt.imshow(images_numpy[i]) - plt.title("Original Image") plt.axis("off") # Plot the ground truth mask @@ -40,7 +39,6 @@ def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_modul class_indices_gt = np.argmax(targets_numpy[i], axis=-1) colored_img_gt = apply_color_map(class_indices_gt, self.color_map) # Apply color map to ground truth plt.imshow(colored_img_gt) - plt.title("Ground Truth Mask") plt.axis("off") # Plot the prediction mask @@ -48,24 +46,15 @@ def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_modul class_indices_pred = np.argmax(predictions_numpy[i], axis=-1) colored_img_pred = apply_color_map(class_indices_pred, self.color_map) # Apply color map to prediction plt.imshow(colored_img_pred) - plt.title("Prediction Mask") plt.axis("off") - # Save the figure to a buffer - buf = io.BytesIO() - plt.savefig(buf, format="png") - plt.close() - buf.seek(0) - image_grid = Image.open(buf) - - # Log the image using the logger's experiment interface and mlflow log_image artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" - pl_module.logger.experiment.log_image( + pl_module.logger.experiment.log_figure( pl_module.logger.run_id, - image_grid, + figure, artifact_file=artifact_file_name ) - buf.close() + plt.close() def apply_color_map(image, color_map): From a0af38752387c665704a3a170e1e3f496545598b Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Tue, 17 Sep 2024 14:07:45 +0200 Subject: [PATCH 03/24] Only log the first batch. --- ahcore/callbacks/log_images_callback.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index b432046..61bcf6a 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -10,19 +10,20 @@ def __init__(self, color_map): self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()} def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: - val_images = batch['image'] - val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() - val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( - val_images_numpy.max() - val_images_numpy.min()) + if batch_idx == 0: + val_images = batch['image'] + val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() + val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( + val_images_numpy.max() - val_images_numpy.min()) - val_predictions = outputs['prediction'] - val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() + val_predictions = outputs['prediction'] + val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() - val_targets = batch['target'] - val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() + val_targets = batch['target'] + val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() - self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, - batch_idx) + self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, + batch_idx) def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx) -> None: batch_size = images_numpy.shape[0] From df1f8a364d0245d36eff0061e529d800675ad8e0 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 18 Sep 2024 11:36:41 +0200 Subject: [PATCH 04/24] Temporarily add the scanner name to logged images --- ahcore/callbacks/log_images_callback.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index 61bcf6a..92c93f3 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import to_rgb +from dlup import SlideImage class LogImagesCallback(pl.Callback): @@ -10,8 +11,15 @@ def __init__(self, color_map): self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()} def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: + already_seen_wsi = [] if batch_idx == 0: val_images = batch['image'] + if batch["path"][0] not in already_seen_wsi: + slide = SlideImage.from_file_path(batch["path"][0]) + already_seen_wsi.append(batch["path"][0]) + scanner_name = slide.properties.get('mirax.NONHIERLAYER_1_SECTION.SCANNER_HARDWARE_VERSION', None) + if scanner_name is None: + scanner_name = "Aperio" val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( val_images_numpy.max() - val_images_numpy.min()) @@ -23,17 +31,17 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, - batch_idx) + batch_idx, scanner_name) - def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx) -> None: + def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx, scanner_name) -> None: batch_size = images_numpy.shape[0] figure = plt.figure(figsize=(15, batch_size * 5)) # Adjust the figure size to fit the grid - for i in range(batch_size): # Plot the original image plt.subplot(batch_size, 3, i * 3 + 1) plt.imshow(images_numpy[i]) plt.axis("off") + plt.title(f"Original Image (Scanner: {scanner_name})") # Plot the ground truth mask plt.subplot(batch_size, 3, i * 3 + 2) @@ -48,6 +56,7 @@ def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_modul colored_img_pred = apply_color_map(class_indices_pred, self.color_map) # Apply color map to prediction plt.imshow(colored_img_pred) plt.axis("off") + plt.tight_layout() artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" pl_module.logger.experiment.log_figure( From 420dcf1ae8df70a97788de1d3dc6480c083a416f Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 18 Sep 2024 12:37:30 +0200 Subject: [PATCH 05/24] empty list at validation epoch end. --- ahcore/callbacks/log_images_callback.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index 92c93f3..d18bd9d 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -2,24 +2,24 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import to_rgb -from dlup import SlideImage class LogImagesCallback(pl.Callback): def __init__(self, color_map): super().__init__() self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()} + self._already_seen_scanner = [] def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: - already_seen_wsi = [] - if batch_idx == 0: - val_images = batch['image'] - if batch["path"][0] not in already_seen_wsi: - slide = SlideImage.from_file_path(batch["path"][0]) - already_seen_wsi.append(batch["path"][0]) - scanner_name = slide.properties.get('mirax.NONHIERLAYER_1_SECTION.SCANNER_HARDWARE_VERSION', None) - if scanner_name is None: - scanner_name = "Aperio" + scanner_name = None + val_images = batch['image'] + path = batch['path'][0] + if path.split('.')[-1] == 'svs': + scanner_name = "Aperio" + elif path.split('.')[-1] == 'mrxs': + scanner_name = "P1000" + + if scanner_name not in self._already_seen_scanner: val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( val_images_numpy.max() - val_images_numpy.min()) @@ -32,6 +32,10 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, batch_idx, scanner_name) + self._already_seen_scanner.append(scanner_name) + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self._already_seen_scanner = [] def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx, scanner_name) -> None: batch_size = images_numpy.shape[0] From 7b415e81c9c63932822d4e4079be7c941e43852f Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 18 Sep 2024 14:07:00 +0200 Subject: [PATCH 06/24] Implement callback to track dice for every scanner type --- ahcore/callbacks/track_metrics_per_scanner.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 ahcore/callbacks/track_metrics_per_scanner.py diff --git a/ahcore/callbacks/track_metrics_per_scanner.py b/ahcore/callbacks/track_metrics_per_scanner.py new file mode 100644 index 0000000..d6197a7 --- /dev/null +++ b/ahcore/callbacks/track_metrics_per_scanner.py @@ -0,0 +1,65 @@ +import pytorch_lightning as pl +from ahcore.metrics import TileMetric + + +class TrackMetricsPerScanner(pl.Callback): + def __init__(self, metrics: TileMetric): + super().__init__() + self.metrics = metrics[0] + + # Initialize accumulators for Aperio and P1000 metrics + self._aperio_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} + self._p1000_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} + + # Track the number of batches for each scanner + self._aperio_batch_count = 0 + self._p1000_batch_count = 0 + + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: + # Determine the scanner based on the file extension + scanner_name = None + path = batch['path'][0] + if path.split('.')[-1] == 'svs': + scanner_name = "Aperio" + elif path.split('.')[-1] == 'mrxs': + scanner_name = "P1000" + + prediction = outputs['prediction'] + target = batch['target'] + roi = batch.get('roi', None) + + # Get the metrics for the current batch + batch_metrics = self.metrics(prediction, target, roi) + + # Accumulate metrics based on the scanner + if scanner_name == "Aperio": + self._aperio_metrics['dice/background'] += batch_metrics['dice/background'].item() + self._aperio_metrics['dice/stroma'] += batch_metrics['dice/stroma'].item() + self._aperio_metrics['dice/tumor'] += batch_metrics['dice/tumor'].item() + self._aperio_metrics['dice/ignore'] += batch_metrics['dice/ignore'].item() + self._aperio_batch_count += 1 + + elif scanner_name == "P1000": + self._p1000_metrics['dice/background'] += batch_metrics['dice/background'].item() + self._p1000_metrics['dice/stroma'] += batch_metrics['dice/stroma'].item() + self._p1000_metrics['dice/tumor'] += batch_metrics['dice/tumor'].item() + self._p1000_metrics['dice/ignore'] += batch_metrics['dice/ignore'].item() + self._p1000_batch_count += 1 + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + # Compute the average metrics for Aperio + if self._aperio_batch_count > 0: + averaged_aperio_metrics = {f"Aperio/{key}": value / self._aperio_batch_count for key, value in + self._aperio_metrics.items()} + trainer.logger.log_metrics(averaged_aperio_metrics) + + # Compute the average metrics for P1000 + if self._p1000_batch_count > 0: + averaged_p1000_metrics = {f"P1000/{key}": value / self._p1000_batch_count for key, value in + self._p1000_metrics.items()} + trainer.logger.log_metrics(averaged_p1000_metrics) + + self._aperio_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} + self._p1000_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} + self._aperio_batch_count = 0 + self._p1000_batch_count = 0 From df1b4c66756837712170f2dccc2572c261dcfa11 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 18 Sep 2024 14:22:56 +0200 Subject: [PATCH 07/24] Log metrics with global step. --- ahcore/callbacks/track_metrics_per_scanner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ahcore/callbacks/track_metrics_per_scanner.py b/ahcore/callbacks/track_metrics_per_scanner.py index d6197a7..f868da9 100644 --- a/ahcore/callbacks/track_metrics_per_scanner.py +++ b/ahcore/callbacks/track_metrics_per_scanner.py @@ -51,13 +51,13 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin if self._aperio_batch_count > 0: averaged_aperio_metrics = {f"Aperio/{key}": value / self._aperio_batch_count for key, value in self._aperio_metrics.items()} - trainer.logger.log_metrics(averaged_aperio_metrics) + trainer.logger.log_metrics(averaged_aperio_metrics, step=trainer.global_step) # Compute the average metrics for P1000 if self._p1000_batch_count > 0: averaged_p1000_metrics = {f"P1000/{key}": value / self._p1000_batch_count for key, value in self._p1000_metrics.items()} - trainer.logger.log_metrics(averaged_p1000_metrics) + trainer.logger.log_metrics(averaged_p1000_metrics, step=trainer.global_step) self._aperio_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} self._p1000_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} From 450345ac7533cd6596e8c536f861a9e4b065d75d Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Thu, 19 Sep 2024 13:45:17 +0200 Subject: [PATCH 08/24] implement weighted sampler for p1000 images --- ahcore/data/dataset.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index 3eaed15..be63673 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -12,7 +12,7 @@ import pytorch_lightning as pl import torch from dlup.data.dataset import Dataset, TiledWsiDataset -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler +from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler, RandomSampler, Sampler, SequentialSampler from ahcore.utils.data import DataDescription, basemodel_to_uuid from ahcore.utils.debug_utils import time_it @@ -302,14 +302,42 @@ def _load_from_cache(self, func: Callable[[], Any], stage: str, *args: Any, **kw return obj + def _construct_weighted_sampler(self, dataset: ConcatDataset) -> WeightedRandomSampler: + """Constructs a weighted sampler based on the .mrxs extension.""" + # Initialize weights list based on dataset length + weights = torch.ones(len(dataset)) # Initialize all weights to 1.0 + + # Loop through datasets and adjust weights for .mrxs datasets + for i, ds in enumerate(dataset.datasets): + # Check if the dataset has a .mrxs extension + if hasattr(ds, 'path') and ds.path.suffix == ".mrxs": + start_idx = dataset.cumulative_sizes[i - 1] if i > 0 else 0 # Get the start index for the current dataset + end_idx = start_idx + len(ds) # Get the end index for the current dataset + weights[start_idx:end_idx] = 10.0 + + sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True) + return sampler + def train_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: if not self._fit_data_iterator: self.setup("fit") assert self._fit_data_iterator - return self._construct_concatenated_dataloader( + dataset = self._construct_concatenated_dataloader( self._fit_data_iterator, batch_size=self._batch_size, stage="fit", + ).dataset + + sampler = self._construct_weighted_sampler(dataset) + + return DataLoader( + dataset, + num_workers=self._num_workers, + sampler=sampler, + batch_size=self._batch_size, + drop_last=True, + persistent_workers=self._persistent_workers, + pin_memory=self._pin_memory, ) def val_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: From 6b381dc8ddc09ae818e2bc6b25a73d6d4afe8973 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 20 Sep 2024 11:20:35 +0200 Subject: [PATCH 09/24] Remove weighted sampler, log images every 10 epochs --- ahcore/callbacks/log_images_callback.py | 37 +++++++++++++------------ ahcore/data/dataset.py | 32 ++------------------- 2 files changed, 21 insertions(+), 48 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index d18bd9d..a6c560a 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -11,28 +11,29 @@ def __init__(self, color_map): self._already_seen_scanner = [] def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: - scanner_name = None - val_images = batch['image'] - path = batch['path'][0] - if path.split('.')[-1] == 'svs': - scanner_name = "Aperio" - elif path.split('.')[-1] == 'mrxs': - scanner_name = "P1000" + if trainer.current_epoch % 10 == 0: + scanner_name = None + val_images = batch['image'] + path = batch['path'][0] + if path.split('.')[-1] == 'svs': + scanner_name = "Aperio" + elif path.split('.')[-1] == 'mrxs': + scanner_name = "P1000" - if scanner_name not in self._already_seen_scanner: - val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() - val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( - val_images_numpy.max() - val_images_numpy.min()) + if scanner_name not in self._already_seen_scanner: + val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() + val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( + val_images_numpy.max() - val_images_numpy.min()) - val_predictions = outputs['prediction'] - val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() + val_predictions = outputs['prediction'] + val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() - val_targets = batch['target'] - val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() + val_targets = batch['target'] + val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() - self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, - batch_idx, scanner_name) - self._already_seen_scanner.append(scanner_name) + self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, + batch_idx, scanner_name) + self._already_seen_scanner.append(scanner_name) def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._already_seen_scanner = [] diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py index be63673..3eaed15 100644 --- a/ahcore/data/dataset.py +++ b/ahcore/data/dataset.py @@ -12,7 +12,7 @@ import pytorch_lightning as pl import torch from dlup.data.dataset import Dataset, TiledWsiDataset -from torch.utils.data import DataLoader, DistributedSampler, WeightedRandomSampler, RandomSampler, Sampler, SequentialSampler +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler from ahcore.utils.data import DataDescription, basemodel_to_uuid from ahcore.utils.debug_utils import time_it @@ -302,42 +302,14 @@ def _load_from_cache(self, func: Callable[[], Any], stage: str, *args: Any, **kw return obj - def _construct_weighted_sampler(self, dataset: ConcatDataset) -> WeightedRandomSampler: - """Constructs a weighted sampler based on the .mrxs extension.""" - # Initialize weights list based on dataset length - weights = torch.ones(len(dataset)) # Initialize all weights to 1.0 - - # Loop through datasets and adjust weights for .mrxs datasets - for i, ds in enumerate(dataset.datasets): - # Check if the dataset has a .mrxs extension - if hasattr(ds, 'path') and ds.path.suffix == ".mrxs": - start_idx = dataset.cumulative_sizes[i - 1] if i > 0 else 0 # Get the start index for the current dataset - end_idx = start_idx + len(ds) # Get the end index for the current dataset - weights[start_idx:end_idx] = 10.0 - - sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True) - return sampler - def train_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: if not self._fit_data_iterator: self.setup("fit") assert self._fit_data_iterator - dataset = self._construct_concatenated_dataloader( + return self._construct_concatenated_dataloader( self._fit_data_iterator, batch_size=self._batch_size, stage="fit", - ).dataset - - sampler = self._construct_weighted_sampler(dataset) - - return DataLoader( - dataset, - num_workers=self._num_workers, - sampler=sampler, - batch_size=self._batch_size, - drop_last=True, - persistent_workers=self._persistent_workers, - pin_memory=self._pin_memory, ) def val_dataloader(self) -> Optional[DataLoader[DlupDatasetSample]]: From 06eed3c8c16e41d80d753a6903d8a374984a7c7e Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Thu, 26 Sep 2024 23:33:33 +0200 Subject: [PATCH 10/24] plot samplewise dice score during logging --- ahcore/callbacks/log_images_callback.py | 152 +++++++++++++++++++----- ahcore/utils/types.py | 18 +++ 2 files changed, 143 insertions(+), 27 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index a6c560a..2f8bf22 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -3,62 +3,160 @@ import numpy as np from matplotlib.colors import to_rgb +from typing import List, Tuple, Any +import torch +import torch.nn.functional as F +from ahcore.metrics.metrics import _compute_dice + +from ahcore.utils.types import ScannerEnum + + +def get_sample_wise_dice_components( + predictions: torch.Tensor, + target: torch.Tensor, + roi: torch.Tensor | None, + num_classes: int, +) -> List[List[Tuple[torch.Tensor, torch.Tensor]]]: + # Apply softmax along the class dimension + soft_predictions = F.softmax(predictions, dim=1) + + # Get the predicted class per pixel + predictions = soft_predictions.argmax(dim=1) + + # Get the target class per pixel + _target = target.argmax(dim=1) + + # Initialize list to store dice components for each sample + dice_components = [] + + # Loop over the batch (B samples) + for batch_idx in range(predictions.size(0)): # Loop through the batch dimension (B) + batch_dice_components = [] + + for class_idx in range(num_classes): # Loop through the classes + # Get predictions and target for the current class and sample + curr_predictions = (predictions[batch_idx] == class_idx).int() + curr_target = (_target[batch_idx] == class_idx).int() + + if roi is not None: + # Apply ROI if it's provided + curr_roi = roi[batch_idx].squeeze(0) # Adjust ROI for current sample + intersection = torch.sum((curr_predictions * curr_target) * curr_roi, dim=(0, 1)) + cardinality = ( + torch.sum(curr_predictions * curr_roi, dim=(0, 1)) + + torch.sum(curr_target * curr_roi, dim=(0, 1)) + ) + else: + # No ROI, just compute intersection and cardinality + intersection = torch.sum(curr_predictions * curr_target, dim=(0, 1)) + cardinality = ( + torch.sum(curr_predictions, dim=(0, 1)) + + torch.sum(curr_target, dim=(0, 1)) + ) + + batch_dice_components.append((intersection, cardinality)) + + dice_components.append(batch_dice_components) + + return dice_components + + +def get_sample_wise_dice(outputs: dict[str, Any], batch: dict[str, Any], roi: torch.Tensor, num_classes: int) -> List[dict[int, float]]: + dices = [] + dice_components = get_sample_wise_dice_components(outputs['prediction'], batch['target'], roi, num_classes) + for samplewise_dice_components in dice_components: + classwise_dices_for_a_sample = {1: 0, 2: 0, 3: 0} + for class_idx in range(num_classes): + if class_idx == 0: + continue + intersection, cardinality = samplewise_dice_components[class_idx] + dice_score = _compute_dice(intersection, cardinality) + classwise_dices_for_a_sample[class_idx] = dice_score.item() + dices.append(classwise_dices_for_a_sample) + return dices + + +def _extract_scanner_name(path) -> str: + # Extract file extension + extension = path.split('.')[-1] + # Use the ScannerEnum to get the scanner name + return ScannerEnum.get_scanner_name(extension) + class LogImagesCallback(pl.Callback): - def __init__(self, color_map): + def __init__(self, color_map: dict[int, str], num_classes: int, + plot_dice: bool = True, + plot_scanner_wise: bool = False, + plot_every_n_epochs: int = 10): super().__init__() self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()} + self._num_classes = num_classes self._already_seen_scanner = [] + self._plot_scanner_wise = plot_scanner_wise + self._plot_every_n_epochs = plot_every_n_epochs + self._plot_dice = plot_dice def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: - if trainer.current_epoch % 10 == 0: - scanner_name = None + + if trainer.current_epoch % self._plot_every_n_epochs == 0: val_images = batch['image'] + roi = batch['roi'] path = batch['path'][0] - if path.split('.')[-1] == 'svs': - scanner_name = "Aperio" - elif path.split('.')[-1] == 'mrxs': - scanner_name = "P1000" - if scanner_name not in self._already_seen_scanner: - val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() - val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( - val_images_numpy.max() - val_images_numpy.min()) + if self._plot_scanner_wise: + scanner_name = _extract_scanner_name(path) + else: + scanner_name = None + + if self._plot_dice: + dices = get_sample_wise_dice(outputs, batch, roi, self._num_classes) + else: + dices = None + + val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() + val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( + val_images_numpy.max() - val_images_numpy.min()) - val_predictions = outputs['prediction'] - val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() + val_predictions = outputs['prediction'] + val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() - val_targets = batch['target'] - val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() + val_targets = batch['target'] + val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() + + if scanner_name not in self._already_seen_scanner: + self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, + trainer.global_step, batch_idx, scanner_name, dices) - self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, trainer.global_step, - batch_idx, scanner_name) + if self._plot_scanner_wise and scanner_name is not None: self._already_seen_scanner.append(scanner_name) def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._already_seen_scanner = [] - def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx, scanner_name) -> None: + def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx, scanner_name, + dices) -> None: batch_size = images_numpy.shape[0] - figure = plt.figure(figsize=(15, batch_size * 5)) # Adjust the figure size to fit the grid + figure = plt.figure(figsize=(15, batch_size * 5)) for i in range(batch_size): - # Plot the original image + if dices is not None: + class_wise_dice = dices[i] plt.subplot(batch_size, 3, i * 3 + 1) plt.imshow(images_numpy[i]) plt.axis("off") - plt.title(f"Original Image (Scanner: {scanner_name})") + if scanner_name is not None: + plt.title(f"Original Image (Scanner: {scanner_name})") - # Plot the ground truth mask plt.subplot(batch_size, 3, i * 3 + 2) class_indices_gt = np.argmax(targets_numpy[i], axis=-1) - colored_img_gt = apply_color_map(class_indices_gt, self.color_map) # Apply color map to ground truth + colored_img_gt = apply_color_map(class_indices_gt, self.color_map, self._num_classes) plt.imshow(colored_img_gt) plt.axis("off") - # Plot the prediction mask plt.subplot(batch_size, 3, i * 3 + 3) class_indices_pred = np.argmax(predictions_numpy[i], axis=-1) - colored_img_pred = apply_color_map(class_indices_pred, self.color_map) # Apply color map to prediction + colored_img_pred = apply_color_map(class_indices_pred, self.color_map, self._num_classes) + if dices is not None: + plt.title(f"Dice: {class_wise_dice[1]:.2f} {class_wise_dice[2]:.2f} {class_wise_dice[3]:.2f}") plt.imshow(colored_img_pred) plt.axis("off") plt.tight_layout() @@ -72,8 +170,8 @@ def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_modul plt.close() -def apply_color_map(image, color_map): +def apply_color_map(image, color_map, num_classes): colored_image = np.zeros((*image.shape, 3), dtype=np.uint8) - for i in range(1, 4): # Assuming classes are 1, 2, 3 + for i in range(1, num_classes-1): colored_image[image == i] = color_map[i] return colored_image diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py index ba0b137..615f585 100644 --- a/ahcore/utils/types.py +++ b/ahcore/utils/types.py @@ -86,3 +86,21 @@ class ViTEmbedMode(str, Enum): CONCAT_MEAN = "embed_concat_mean" CONCAT = "embed_concat" # Extend as necessary + + +class ScannerEnum(Enum): + SVS = ("svs", "Aperio") + MRXS = ("mrxs", "P1000") + DEFAULT = ("default", "Unknown Scanner") + + def __init__(self, extension, scanner_name): + self.extension = extension + self.scanner_name = scanner_name + + @classmethod + def get_scanner_name(cls, file_extension): + for scanner in cls: + if scanner.extension == file_extension: + return scanner.scanner_name + # Return a default value if extension is not found + return cls.DEFAULT.scanner_name From 64269354c6a6c909732d3a1561736b939eaf6824 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Thu, 26 Sep 2024 23:38:04 +0200 Subject: [PATCH 11/24] some cleaning --- ahcore/callbacks/log_images_callback.py | 97 +++++++++++-------- ahcore/callbacks/track_metrics_per_scanner.py | 47 ++++----- 2 files changed, 81 insertions(+), 63 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index 2f8bf22..88ef2ed 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -1,21 +1,21 @@ -import pytorch_lightning as pl +from typing import Any, List, Tuple + import matplotlib.pyplot as plt import numpy as np -from matplotlib.colors import to_rgb - -from typing import List, Tuple, Any +import pytorch_lightning as pl import torch import torch.nn.functional as F -from ahcore.metrics.metrics import _compute_dice +from matplotlib.colors import to_rgb +from ahcore.metrics.metrics import _compute_dice from ahcore.utils.types import ScannerEnum def get_sample_wise_dice_components( - predictions: torch.Tensor, - target: torch.Tensor, - roi: torch.Tensor | None, - num_classes: int, + predictions: torch.Tensor, + target: torch.Tensor, + roi: torch.Tensor | None, + num_classes: int, ) -> List[List[Tuple[torch.Tensor, torch.Tensor]]]: # Apply softmax along the class dimension soft_predictions = F.softmax(predictions, dim=1) @@ -42,17 +42,13 @@ def get_sample_wise_dice_components( # Apply ROI if it's provided curr_roi = roi[batch_idx].squeeze(0) # Adjust ROI for current sample intersection = torch.sum((curr_predictions * curr_target) * curr_roi, dim=(0, 1)) - cardinality = ( - torch.sum(curr_predictions * curr_roi, dim=(0, 1)) + - torch.sum(curr_target * curr_roi, dim=(0, 1)) + cardinality = torch.sum(curr_predictions * curr_roi, dim=(0, 1)) + torch.sum( + curr_target * curr_roi, dim=(0, 1) ) else: # No ROI, just compute intersection and cardinality intersection = torch.sum(curr_predictions * curr_target, dim=(0, 1)) - cardinality = ( - torch.sum(curr_predictions, dim=(0, 1)) + - torch.sum(curr_target, dim=(0, 1)) - ) + cardinality = torch.sum(curr_predictions, dim=(0, 1)) + torch.sum(curr_target, dim=(0, 1)) batch_dice_components.append((intersection, cardinality)) @@ -61,9 +57,11 @@ def get_sample_wise_dice_components( return dice_components -def get_sample_wise_dice(outputs: dict[str, Any], batch: dict[str, Any], roi: torch.Tensor, num_classes: int) -> List[dict[int, float]]: +def get_sample_wise_dice( + outputs: dict[str, Any], batch: dict[str, Any], roi: torch.Tensor, num_classes: int +) -> List[dict[int, float]]: dices = [] - dice_components = get_sample_wise_dice_components(outputs['prediction'], batch['target'], roi, num_classes) + dice_components = get_sample_wise_dice_components(outputs["prediction"], batch["target"], roi, num_classes) for samplewise_dice_components in dice_components: classwise_dices_for_a_sample = {1: 0, 2: 0, 3: 0} for class_idx in range(num_classes): @@ -78,16 +76,20 @@ def get_sample_wise_dice(outputs: dict[str, Any], batch: dict[str, Any], roi: to def _extract_scanner_name(path) -> str: # Extract file extension - extension = path.split('.')[-1] + extension = path.split(".")[-1] # Use the ScannerEnum to get the scanner name return ScannerEnum.get_scanner_name(extension) class LogImagesCallback(pl.Callback): - def __init__(self, color_map: dict[int, str], num_classes: int, - plot_dice: bool = True, - plot_scanner_wise: bool = False, - plot_every_n_epochs: int = 10): + def __init__( + self, + color_map: dict[int, str], + num_classes: int, + plot_dice: bool = True, + plot_scanner_wise: bool = False, + plot_every_n_epochs: int = 10, + ): super().__init__() self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()} self._num_classes = num_classes @@ -97,11 +99,10 @@ def __init__(self, color_map: dict[int, str], num_classes: int, self._plot_dice = plot_dice def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: - if trainer.current_epoch % self._plot_every_n_epochs == 0: - val_images = batch['image'] - roi = batch['roi'] - path = batch['path'][0] + val_images = batch["image"] + roi = batch["roi"] + path = batch["path"][0] if self._plot_scanner_wise: scanner_name = _extract_scanner_name(path) @@ -115,17 +116,26 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy() val_images_numpy = (val_images_numpy - val_images_numpy.min()) / ( - val_images_numpy.max() - val_images_numpy.min()) + val_images_numpy.max() - val_images_numpy.min() + ) - val_predictions = outputs['prediction'] + val_predictions = outputs["prediction"] val_predictions_numpy = val_predictions.permute(0, 2, 3, 1).detach().cpu().numpy() - val_targets = batch['target'] + val_targets = batch["target"] val_targets_numpy = val_targets.permute(0, 2, 3, 1).detach().cpu().numpy() if scanner_name not in self._already_seen_scanner: - self._plot_and_log(val_images_numpy, val_predictions_numpy, val_targets_numpy, pl_module, - trainer.global_step, batch_idx, scanner_name, dices) + self._plot_and_log( + val_images_numpy, + val_predictions_numpy, + val_targets_numpy, + pl_module, + trainer.global_step, + batch_idx, + scanner_name, + dices, + ) if self._plot_scanner_wise and scanner_name is not None: self._already_seen_scanner.append(scanner_name) @@ -133,8 +143,17 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._already_seen_scanner = [] - def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_module, step, batch_idx, scanner_name, - dices) -> None: + def _plot_and_log( + self, + images_numpy: np.ndarray, + predictions_numpy: np.ndarray, + targets_numpy: np.ndarray, + pl_module: "pl.LightningModule", + step: int, + batch_idx: int, + scanner_name: str, + dices: List[dict[int, float]], + ) -> None: batch_size = images_numpy.shape[0] figure = plt.figure(figsize=(15, batch_size * 5)) for i in range(batch_size): @@ -162,16 +181,12 @@ def _plot_and_log(self, images_numpy, predictions_numpy, targets_numpy, pl_modul plt.tight_layout() artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" - pl_module.logger.experiment.log_figure( - pl_module.logger.run_id, - figure, - artifact_file=artifact_file_name - ) + pl_module.logger.experiment.log_figure(pl_module.logger.run_id, figure, artifact_file=artifact_file_name) plt.close() -def apply_color_map(image, color_map, num_classes): +def apply_color_map(image, color_map, num_classes) -> np.ndarray: colored_image = np.zeros((*image.shape, 3), dtype=np.uint8) - for i in range(1, num_classes-1): + for i in range(1, num_classes - 1): colored_image[image == i] = color_map[i] return colored_image diff --git a/ahcore/callbacks/track_metrics_per_scanner.py b/ahcore/callbacks/track_metrics_per_scanner.py index f868da9..2052348 100644 --- a/ahcore/callbacks/track_metrics_per_scanner.py +++ b/ahcore/callbacks/track_metrics_per_scanner.py @@ -1,4 +1,5 @@ import pytorch_lightning as pl + from ahcore.metrics import TileMetric @@ -8,8 +9,8 @@ def __init__(self, metrics: TileMetric): self.metrics = metrics[0] # Initialize accumulators for Aperio and P1000 metrics - self._aperio_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} - self._p1000_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} + self._aperio_metrics = {"dice/background": 0.0, "dice/stroma": 0.0, "dice/tumor": 0.0, "dice/ignore": 0.0} + self._p1000_metrics = {"dice/background": 0.0, "dice/stroma": 0.0, "dice/tumor": 0.0, "dice/ignore": 0.0} # Track the number of batches for each scanner self._aperio_batch_count = 0 @@ -18,48 +19,50 @@ def __init__(self, metrics: TileMetric): def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: # Determine the scanner based on the file extension scanner_name = None - path = batch['path'][0] - if path.split('.')[-1] == 'svs': + path = batch["path"][0] + if path.split(".")[-1] == "svs": scanner_name = "Aperio" - elif path.split('.')[-1] == 'mrxs': + elif path.split(".")[-1] == "mrxs": scanner_name = "P1000" - prediction = outputs['prediction'] - target = batch['target'] - roi = batch.get('roi', None) + prediction = outputs["prediction"] + target = batch["target"] + roi = batch.get("roi", None) # Get the metrics for the current batch batch_metrics = self.metrics(prediction, target, roi) # Accumulate metrics based on the scanner if scanner_name == "Aperio": - self._aperio_metrics['dice/background'] += batch_metrics['dice/background'].item() - self._aperio_metrics['dice/stroma'] += batch_metrics['dice/stroma'].item() - self._aperio_metrics['dice/tumor'] += batch_metrics['dice/tumor'].item() - self._aperio_metrics['dice/ignore'] += batch_metrics['dice/ignore'].item() + self._aperio_metrics["dice/background"] += batch_metrics["dice/background"].item() + self._aperio_metrics["dice/stroma"] += batch_metrics["dice/stroma"].item() + self._aperio_metrics["dice/tumor"] += batch_metrics["dice/tumor"].item() + self._aperio_metrics["dice/ignore"] += batch_metrics["dice/ignore"].item() self._aperio_batch_count += 1 elif scanner_name == "P1000": - self._p1000_metrics['dice/background'] += batch_metrics['dice/background'].item() - self._p1000_metrics['dice/stroma'] += batch_metrics['dice/stroma'].item() - self._p1000_metrics['dice/tumor'] += batch_metrics['dice/tumor'].item() - self._p1000_metrics['dice/ignore'] += batch_metrics['dice/ignore'].item() + self._p1000_metrics["dice/background"] += batch_metrics["dice/background"].item() + self._p1000_metrics["dice/stroma"] += batch_metrics["dice/stroma"].item() + self._p1000_metrics["dice/tumor"] += batch_metrics["dice/tumor"].item() + self._p1000_metrics["dice/ignore"] += batch_metrics["dice/ignore"].item() self._p1000_batch_count += 1 def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: # Compute the average metrics for Aperio if self._aperio_batch_count > 0: - averaged_aperio_metrics = {f"Aperio/{key}": value / self._aperio_batch_count for key, value in - self._aperio_metrics.items()} + averaged_aperio_metrics = { + f"Aperio/{key}": value / self._aperio_batch_count for key, value in self._aperio_metrics.items() + } trainer.logger.log_metrics(averaged_aperio_metrics, step=trainer.global_step) # Compute the average metrics for P1000 if self._p1000_batch_count > 0: - averaged_p1000_metrics = {f"P1000/{key}": value / self._p1000_batch_count for key, value in - self._p1000_metrics.items()} + averaged_p1000_metrics = { + f"P1000/{key}": value / self._p1000_batch_count for key, value in self._p1000_metrics.items() + } trainer.logger.log_metrics(averaged_p1000_metrics, step=trainer.global_step) - self._aperio_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} - self._p1000_metrics = {'dice/background': 0.0, 'dice/stroma': 0.0, 'dice/tumor': 0.0, 'dice/ignore': 0.0} + self._aperio_metrics = {"dice/background": 0.0, "dice/stroma": 0.0, "dice/tumor": 0.0, "dice/ignore": 0.0} + self._p1000_metrics = {"dice/background": 0.0, "dice/stroma": 0.0, "dice/tumor": 0.0, "dice/ignore": 0.0} self._aperio_batch_count = 0 self._p1000_batch_count = 0 From 4b812dacf7f91a701d366c469d823d35452c167d Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 00:04:12 +0200 Subject: [PATCH 12/24] include ignore class while plotting colors. --- ahcore/callbacks/converters/tiff_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ahcore/callbacks/converters/tiff_callback.py b/ahcore/callbacks/converters/tiff_callback.py index 41960ee..75fcb8f 100644 --- a/ahcore/callbacks/converters/tiff_callback.py +++ b/ahcore/callbacks/converters/tiff_callback.py @@ -115,7 +115,7 @@ def _write_tiff( file_reader: Type[FileImageReader], iterator_from_reader: Callable[[FileImageReader, tuple[int, int]], Iterator[npt.NDArray[np.int_]]], ) -> None: - with file_reader(filename, stitching_mode=StitchingMode.CROP) as cache_reader: + with file_reader(filename, stitching_mode=StitchingMode.AVERAGE) as cache_reader: writer = TifffileImageWriter( filename.with_suffix(".tiff"), size=cache_reader.size, From 7bf2e347b3a99123d458f75c621b9a1a6404bb5b Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 00:05:10 +0200 Subject: [PATCH 13/24] include ignore class while plotting colors. --- ahcore/callbacks/log_images_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index 88ef2ed..b030a20 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -187,6 +187,6 @@ def _plot_and_log( def apply_color_map(image, color_map, num_classes) -> np.ndarray: colored_image = np.zeros((*image.shape, 3), dtype=np.uint8) - for i in range(1, num_classes - 1): + for i in range(1, num_classes): colored_image[image == i] = color_map[i] return colored_image From 20aaf3914411d879deb733dedd29f3a9ca44215b Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 00:06:04 +0200 Subject: [PATCH 14/24] Revert "include ignore class while plotting colors." This reverts commit 4b812dacf7f91a701d366c469d823d35452c167d. --- ahcore/callbacks/converters/tiff_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ahcore/callbacks/converters/tiff_callback.py b/ahcore/callbacks/converters/tiff_callback.py index 75fcb8f..41960ee 100644 --- a/ahcore/callbacks/converters/tiff_callback.py +++ b/ahcore/callbacks/converters/tiff_callback.py @@ -115,7 +115,7 @@ def _write_tiff( file_reader: Type[FileImageReader], iterator_from_reader: Callable[[FileImageReader, tuple[int, int]], Iterator[npt.NDArray[np.int_]]], ) -> None: - with file_reader(filename, stitching_mode=StitchingMode.AVERAGE) as cache_reader: + with file_reader(filename, stitching_mode=StitchingMode.CROP) as cache_reader: writer = TifffileImageWriter( filename.with_suffix(".tiff"), size=cache_reader.size, From 10dcd28da1b232e92d92c60e054ae193b1ee29fd Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 00:14:10 +0200 Subject: [PATCH 15/24] Now, logging is possible for mlflow and tensorboard --- ahcore/callbacks/log_images_callback.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index b030a20..628d77c 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -180,8 +180,17 @@ def _plot_and_log( plt.axis("off") plt.tight_layout() - artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" - pl_module.logger.experiment.log_figure(pl_module.logger.run_id, figure, artifact_file=artifact_file_name) + logger = pl_module.logger + + if hasattr(logger.experiment, "log_figure"): # MLFlow logger case + artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" + logger.experiment.log_figure(logger.run_id, figure, artifact_file=artifact_file_name) + elif hasattr(logger.experiment, "add_figure"): # TensorBoard logger case + logger.experiment.add_figure(f"validation_step_{step}_batch_{batch_idx}", figure, global_step=step) + else: + # If another logger is being used, raise a warning or add additional logic + raise NotImplementedError(f"Logging method for logger {type(logger).__name__} not implemented.") + plt.close() From 6f433a7564e62edaca6b9b30f11bec78575a2018 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 11:51:06 +0200 Subject: [PATCH 16/24] Further Abstractions for loggers --- ahcore/callbacks/log_images_callback.py | 19 +--- ahcore/callbacks/track_metrics_per_scanner.py | 95 ++++++++----------- ahcore/utils/callbacks.py | 35 ++++++- ahcore/utils/types.py | 7 ++ 4 files changed, 88 insertions(+), 68 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index 628d77c..9c5e8a1 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -8,6 +8,7 @@ from matplotlib.colors import to_rgb from ahcore.metrics.metrics import _compute_dice +from ahcore.utils.callbacks import AhCoreLogger from ahcore.utils.types import ScannerEnum @@ -97,8 +98,11 @@ def __init__( self._plot_scanner_wise = plot_scanner_wise self._plot_every_n_epochs = plot_every_n_epochs self._plot_dice = plot_dice + self._logger: AhCoreLogger | None = None def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: + if self._logger is None: + self._logger = AhCoreLogger(pl_module.logger) if trainer.current_epoch % self._plot_every_n_epochs == 0: val_images = batch["image"] roi = batch["roi"] @@ -130,7 +134,6 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, val_images_numpy, val_predictions_numpy, val_targets_numpy, - pl_module, trainer.global_step, batch_idx, scanner_name, @@ -148,7 +151,6 @@ def _plot_and_log( images_numpy: np.ndarray, predictions_numpy: np.ndarray, targets_numpy: np.ndarray, - pl_module: "pl.LightningModule", step: int, batch_idx: int, scanner_name: str, @@ -179,18 +181,7 @@ def _plot_and_log( plt.imshow(colored_img_pred) plt.axis("off") plt.tight_layout() - - logger = pl_module.logger - - if hasattr(logger.experiment, "log_figure"): # MLFlow logger case - artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" - logger.experiment.log_figure(logger.run_id, figure, artifact_file=artifact_file_name) - elif hasattr(logger.experiment, "add_figure"): # TensorBoard logger case - logger.experiment.add_figure(f"validation_step_{step}_batch_{batch_idx}", figure, global_step=step) - else: - # If another logger is being used, raise a warning or add additional logic - raise NotImplementedError(f"Logging method for logger {type(logger).__name__} not implemented.") - + self._logger.log_figure(figure, step, batch_idx) plt.close() diff --git a/ahcore/callbacks/track_metrics_per_scanner.py b/ahcore/callbacks/track_metrics_per_scanner.py index 2052348..9e0757b 100644 --- a/ahcore/callbacks/track_metrics_per_scanner.py +++ b/ahcore/callbacks/track_metrics_per_scanner.py @@ -1,68 +1,57 @@ import pytorch_lightning as pl from ahcore.metrics import TileMetric - - -class TrackMetricsPerScanner(pl.Callback): - def __init__(self, metrics: TileMetric): +from ahcore.utils.callbacks import AhCoreLogger +from ahcore.utils.types import ScannerEnum + + +class TrackTileMetricsPerScanner(pl.Callback): + """ + This callback is used to track several `TileMetric` from ahcore per scanner. + The callback works on certain assumptions: + - You want to track metrics corresponding to each class in the `index_map` + - Each metric is tracked per scanner + """ + def __init__(self, metrics: list[TileMetric], index_map: dict[str, int]): super().__init__() - self.metrics = metrics[0] - - # Initialize accumulators for Aperio and P1000 metrics - self._aperio_metrics = {"dice/background": 0.0, "dice/stroma": 0.0, "dice/tumor": 0.0, "dice/ignore": 0.0} - self._p1000_metrics = {"dice/background": 0.0, "dice/stroma": 0.0, "dice/tumor": 0.0, "dice/ignore": 0.0} - - # Track the number of batches for each scanner - self._aperio_batch_count = 0 - self._p1000_batch_count = 0 + self.metrics = metrics + self.index_map = index_map + self._metrics_per_scanner = { + scanner.scanner_name: {f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics} + for scanner in ScannerEnum + } + self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerEnum} + self._logger: AhCoreLogger | None = None def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: - # Determine the scanner based on the file extension - scanner_name = None + if self._logger is None: + self._logger = AhCoreLogger(pl_module.logger) + path = batch["path"][0] - if path.split(".")[-1] == "svs": - scanner_name = "Aperio" - elif path.split(".")[-1] == "mrxs": - scanner_name = "P1000" + file_extension = path.split(".")[-1] + scanner_name = ScannerEnum.get_scanner_name(file_extension) prediction = outputs["prediction"] target = batch["target"] roi = batch.get("roi", None) - # Get the metrics for the current batch - batch_metrics = self.metrics(prediction, target, roi) - - # Accumulate metrics based on the scanner - if scanner_name == "Aperio": - self._aperio_metrics["dice/background"] += batch_metrics["dice/background"].item() - self._aperio_metrics["dice/stroma"] += batch_metrics["dice/stroma"].item() - self._aperio_metrics["dice/tumor"] += batch_metrics["dice/tumor"].item() - self._aperio_metrics["dice/ignore"] += batch_metrics["dice/ignore"].item() - self._aperio_batch_count += 1 + for metric in self.metrics: + batch_metrics = metric(prediction, target, roi) + for class_name, class_index in self.index_map.items(): + metric_key = f"{metric.name}/{class_name}" + self._metrics_per_scanner[scanner_name][metric_key] += batch_metrics[metric_key].item() - elif scanner_name == "P1000": - self._p1000_metrics["dice/background"] += batch_metrics["dice/background"].item() - self._p1000_metrics["dice/stroma"] += batch_metrics["dice/stroma"].item() - self._p1000_metrics["dice/tumor"] += batch_metrics["dice/tumor"].item() - self._p1000_metrics["dice/ignore"] += batch_metrics["dice/ignore"].item() - self._p1000_batch_count += 1 + self._batch_count_per_scanner[scanner_name] += 1 def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - # Compute the average metrics for Aperio - if self._aperio_batch_count > 0: - averaged_aperio_metrics = { - f"Aperio/{key}": value / self._aperio_batch_count for key, value in self._aperio_metrics.items() - } - trainer.logger.log_metrics(averaged_aperio_metrics, step=trainer.global_step) - - # Compute the average metrics for P1000 - if self._p1000_batch_count > 0: - averaged_p1000_metrics = { - f"P1000/{key}": value / self._p1000_batch_count for key, value in self._p1000_metrics.items() - } - trainer.logger.log_metrics(averaged_p1000_metrics, step=trainer.global_step) - - self._aperio_metrics = {"dice/background": 0.0, "dice/stroma": 0.0, "dice/tumor": 0.0, "dice/ignore": 0.0} - self._p1000_metrics = {"dice/background": 0.0, "dice/stroma": 0.0, "dice/tumor": 0.0, "dice/ignore": 0.0} - self._aperio_batch_count = 0 - self._p1000_batch_count = 0 + for scanner_name, metrics in self._metrics_per_scanner.items(): + batch_count = self._batch_count_per_scanner[scanner_name] + if batch_count > 0: + averaged_metrics = {f"{scanner_name}/{key}": value / batch_count for key, value in metrics.items()} + self._logger.log_metrics(averaged_metrics, step=trainer.global_step) + + self._metrics_per_scanner = { + scanner.scanner_name: {f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics} + for scanner in ScannerEnum + } + self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerEnum} diff --git a/ahcore/utils/callbacks.py b/ahcore/utils/callbacks.py index 674bde7..ef78168 100644 --- a/ahcore/utils/callbacks.py +++ b/ahcore/utils/callbacks.py @@ -21,7 +21,7 @@ from ahcore.transforms.pre_transforms import one_hot_encoding from ahcore.utils.data import DataDescription from ahcore.utils.io import get_logger -from ahcore.utils.types import DlupDatasetSample +from ahcore.utils.types import DlupDatasetSample, LoggerEnum logger = get_logger(__name__) @@ -241,3 +241,36 @@ def get_output_filename(dump_dir: Path, input_path: Path, model_name: str, count if counter is not None: return dump_dir / "outputs" / model_name / f"{counter}" / f"{hex_dig}.cache" return dump_dir / "outputs" / model_name / f"{hex_dig}.cache" + + +class AhCoreLogger: + def __init__(self, logger): + self.logger = logger + + def get_logger_type(self): + if hasattr(self.logger.experiment, "log_figure"): + return LoggerEnum.MLFLOW + elif hasattr(self.logger.experiment, "add_figure"): + return LoggerEnum.TENSORBOARD + else: + return LoggerEnum.UNKNOWN + + def log_figure(self, figure, step, batch_idx): + if self.get_logger_type() == LoggerEnum.MLFLOW: + artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" + self.logger.experiment.log_figure(self.logger.run_id, figure, artifact_file=artifact_file_name) + elif self.get_logger_type() == LoggerEnum.TENSORBOARD: + self.logger.experiment.add_figure(f"validation_step_{step}_batch_{batch_idx}", figure, global_step=step) + else: + raise NotImplementedError(f"Logging method for logger {type(self.logger).__name__} not implemented.") + + def log_metrics(self, metrics, step): + logger_type = self.get_logger_type() + if logger_type == LoggerEnum.MLFLOW: + for metric_name, value in metrics.items(): + self.logger.experiment.log_metric(self.logger.run_id, key=metric_name, value=value, step=step) + elif logger_type == LoggerEnum.TENSORBOARD: + for metric_name, value in metrics.items(): + self.logger.experiment.add_scalar(metric_name, value, global_step=step) + else: + raise NotImplementedError(f"Logging method for logger {type(self.logger).__name__} not implemented.") diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py index 615f585..9b2f2bc 100644 --- a/ahcore/utils/types.py +++ b/ahcore/utils/types.py @@ -104,3 +104,10 @@ def get_scanner_name(cls, file_extension): return scanner.scanner_name # Return a default value if extension is not found return cls.DEFAULT.scanner_name + + +class LoggerEnum(Enum): + TENSORBOARD = "tensorboard" + MLFLOW = "mlflow" + UNKNOWN = "unknown" + # Extend as necessary From 9e2c51f24f904933cfbb9ff39878ba51ef592cd2 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 11:55:08 +0200 Subject: [PATCH 17/24] enumerate scanner vendors --- ahcore/callbacks/log_images_callback.py | 4 ++-- ahcore/callbacks/track_metrics_per_scanner.py | 12 ++++++------ ahcore/utils/types.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index 9c5e8a1..f6194a9 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -9,7 +9,7 @@ from ahcore.metrics.metrics import _compute_dice from ahcore.utils.callbacks import AhCoreLogger -from ahcore.utils.types import ScannerEnum +from ahcore.utils.types import ScannerVendors def get_sample_wise_dice_components( @@ -79,7 +79,7 @@ def _extract_scanner_name(path) -> str: # Extract file extension extension = path.split(".")[-1] # Use the ScannerEnum to get the scanner name - return ScannerEnum.get_scanner_name(extension) + return ScannerVendors.get_vendor_name(extension) class LogImagesCallback(pl.Callback): diff --git a/ahcore/callbacks/track_metrics_per_scanner.py b/ahcore/callbacks/track_metrics_per_scanner.py index 9e0757b..adb64c5 100644 --- a/ahcore/callbacks/track_metrics_per_scanner.py +++ b/ahcore/callbacks/track_metrics_per_scanner.py @@ -2,7 +2,7 @@ from ahcore.metrics import TileMetric from ahcore.utils.callbacks import AhCoreLogger -from ahcore.utils.types import ScannerEnum +from ahcore.utils.types import ScannerVendors class TrackTileMetricsPerScanner(pl.Callback): @@ -18,9 +18,9 @@ def __init__(self, metrics: list[TileMetric], index_map: dict[str, int]): self.index_map = index_map self._metrics_per_scanner = { scanner.scanner_name: {f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics} - for scanner in ScannerEnum + for scanner in ScannerVendors } - self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerEnum} + self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerVendors} self._logger: AhCoreLogger | None = None def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: @@ -29,7 +29,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, path = batch["path"][0] file_extension = path.split(".")[-1] - scanner_name = ScannerEnum.get_scanner_name(file_extension) + scanner_name = ScannerVendors.get_vendor_name(file_extension) prediction = outputs["prediction"] target = batch["target"] @@ -52,6 +52,6 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin self._metrics_per_scanner = { scanner.scanner_name: {f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics} - for scanner in ScannerEnum + for scanner in ScannerVendors } - self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerEnum} + self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerVendors} diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py index 9b2f2bc..821509d 100644 --- a/ahcore/utils/types.py +++ b/ahcore/utils/types.py @@ -88,9 +88,9 @@ class ViTEmbedMode(str, Enum): # Extend as necessary -class ScannerEnum(Enum): +class ScannerVendors(Enum): SVS = ("svs", "Aperio") - MRXS = ("mrxs", "P1000") + MRXS = ("mrxs", "3DHistech") DEFAULT = ("default", "Unknown Scanner") def __init__(self, extension, scanner_name): @@ -98,7 +98,7 @@ def __init__(self, extension, scanner_name): self.scanner_name = scanner_name @classmethod - def get_scanner_name(cls, file_extension): + def get_vendor_name(cls, file_extension): for scanner in cls: if scanner.extension == file_extension: return scanner.scanner_name From 44f9df7caaf29d58a058a277acff4c18e345134f Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 12:05:18 +0200 Subject: [PATCH 18/24] Some more generalization --- ahcore/callbacks/log_images_callback.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index f6194a9..264a75d 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -156,6 +156,7 @@ def _plot_and_log( scanner_name: str, dices: List[dict[int, float]], ) -> None: + class_wise_dice = None batch_size = images_numpy.shape[0] figure = plt.figure(figsize=(15, batch_size * 5)) for i in range(batch_size): @@ -176,8 +177,9 @@ def _plot_and_log( plt.subplot(batch_size, 3, i * 3 + 3) class_indices_pred = np.argmax(predictions_numpy[i], axis=-1) colored_img_pred = apply_color_map(class_indices_pred, self.color_map, self._num_classes) - if dices is not None: - plt.title(f"Dice: {class_wise_dice[1]:.2f} {class_wise_dice[2]:.2f} {class_wise_dice[3]:.2f}") + if dices is not None and class_wise_dice is not None: + dice_values = ' '.join([f"{class_wise_dice[i]:.2f}" for i in range(1, self._num_classes)]) + plt.title(f"Dice: {dice_values}") plt.imshow(colored_img_pred) plt.axis("off") plt.tight_layout() From a3d2562826b74376ea0e58daf6d202d5d31e84a3 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 12:06:21 +0200 Subject: [PATCH 19/24] Add config files --- config/callbacks/log_images_callback.yaml | 7 +++++++ config/callbacks/track_metrics_per_scanner.yaml | 6 ++++++ 2 files changed, 13 insertions(+) create mode 100644 config/callbacks/log_images_callback.yaml create mode 100644 config/callbacks/track_metrics_per_scanner.yaml diff --git a/config/callbacks/log_images_callback.yaml b/config/callbacks/log_images_callback.yaml new file mode 100644 index 0000000..19c9d57 --- /dev/null +++ b/config/callbacks/log_images_callback.yaml @@ -0,0 +1,7 @@ +log_images_callback: + _target_: ahcore.callbacks.log_images_callback.LogImagesCallback + color_map: ${data_description.color_map} + num_classes: ${data_description.num_classes} + plot_scanner_wise: True + plot_every_n_epochs: 10 + plot_dice: True \ No newline at end of file diff --git a/config/callbacks/track_metrics_per_scanner.yaml b/config/callbacks/track_metrics_per_scanner.yaml new file mode 100644 index 0000000..07bd05d --- /dev/null +++ b/config/callbacks/track_metrics_per_scanner.yaml @@ -0,0 +1,6 @@ +track_metrics_per_scanner: + _target_: ahcore.callbacks.track_metrics_per_scanner.TrackTileMetricsPerScanner + metrics: + - _target_: ahcore.metrics.DiceMetric + data_description: ${data_description} + index_map: ${data_description.index_map} \ No newline at end of file From 224aed2b4a07af8d3e6d700124abd4c3edf2f146 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Fri, 27 Sep 2024 15:03:28 +0200 Subject: [PATCH 20/24] CI/CD --- ahcore/callbacks/log_images_callback.py | 41 ++++++++++++------- ahcore/callbacks/track_metrics_per_scanner.py | 32 +++++++++++---- ahcore/utils/callbacks.py | 12 +++--- ahcore/utils/types.py | 4 +- 4 files changed, 59 insertions(+), 30 deletions(-) diff --git a/ahcore/callbacks/log_images_callback.py b/ahcore/callbacks/log_images_callback.py index 264a75d..3ae48cd 100644 --- a/ahcore/callbacks/log_images_callback.py +++ b/ahcore/callbacks/log_images_callback.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple +from typing import Any, List, Optional, Tuple import matplotlib.pyplot as plt import numpy as np @@ -9,7 +9,7 @@ from ahcore.metrics.metrics import _compute_dice from ahcore.utils.callbacks import AhCoreLogger -from ahcore.utils.types import ScannerVendors +from ahcore.utils.types import GenericNumberArray, ScannerVendors def get_sample_wise_dice_components( @@ -64,7 +64,7 @@ def get_sample_wise_dice( dices = [] dice_components = get_sample_wise_dice_components(outputs["prediction"], batch["target"], roi, num_classes) for samplewise_dice_components in dice_components: - classwise_dices_for_a_sample = {1: 0, 2: 0, 3: 0} + classwise_dices_for_a_sample: dict[int, float] = {1: 0.0, 2: 0.0, 3: 0.0} for class_idx in range(num_classes): if class_idx == 0: continue @@ -75,7 +75,7 @@ def get_sample_wise_dice( return dices -def _extract_scanner_name(path) -> str: +def _extract_scanner_name(path: str) -> str: # Extract file extension extension = path.split(".")[-1] # Use the ScannerEnum to get the scanner name @@ -94,13 +94,21 @@ def __init__( super().__init__() self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()} self._num_classes = num_classes - self._already_seen_scanner = [] + self._already_seen_scanner: List[str] = [] self._plot_scanner_wise = plot_scanner_wise self._plot_every_n_epochs = plot_every_n_epochs self._plot_dice = plot_dice - self._logger: AhCoreLogger | None = None + self._logger: Optional[AhCoreLogger | None] = None - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: dict[str, Any], # type: ignore + batch: dict[str, Any], + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: if self._logger is None: self._logger = AhCoreLogger(pl_module.logger) if trainer.current_epoch % self._plot_every_n_epochs == 0: @@ -148,13 +156,13 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin def _plot_and_log( self, - images_numpy: np.ndarray, - predictions_numpy: np.ndarray, - targets_numpy: np.ndarray, + images_numpy: GenericNumberArray, + predictions_numpy: GenericNumberArray, + targets_numpy: GenericNumberArray, step: int, batch_idx: int, - scanner_name: str, - dices: List[dict[int, float]], + scanner_name: Optional[str | None] = None, + dices: Optional[List[dict[int, float]] | None] = None, ) -> None: class_wise_dice = None batch_size = images_numpy.shape[0] @@ -178,16 +186,19 @@ def _plot_and_log( class_indices_pred = np.argmax(predictions_numpy[i], axis=-1) colored_img_pred = apply_color_map(class_indices_pred, self.color_map, self._num_classes) if dices is not None and class_wise_dice is not None: - dice_values = ' '.join([f"{class_wise_dice[i]:.2f}" for i in range(1, self._num_classes)]) + dice_values = " ".join([f"{class_wise_dice[i]:.2f}" for i in range(1, self._num_classes)]) plt.title(f"Dice: {dice_values}") plt.imshow(colored_img_pred) plt.axis("off") plt.tight_layout() - self._logger.log_figure(figure, step, batch_idx) + if self._logger: # This is for mypy + self._logger.log_figure(figure, step, batch_idx) plt.close() -def apply_color_map(image, color_map, num_classes) -> np.ndarray: +def apply_color_map( + image: GenericNumberArray, color_map: dict[int, Any], num_classes: int +) -> np.ndarray[Any, np.dtype[np.uint8]]: colored_image = np.zeros((*image.shape, 3), dtype=np.uint8) for i in range(1, num_classes): colored_image[image == i] = color_map[i] diff --git a/ahcore/callbacks/track_metrics_per_scanner.py b/ahcore/callbacks/track_metrics_per_scanner.py index adb64c5..a10b7a1 100644 --- a/ahcore/callbacks/track_metrics_per_scanner.py +++ b/ahcore/callbacks/track_metrics_per_scanner.py @@ -1,3 +1,5 @@ +from typing import Any, Optional + import pytorch_lightning as pl from ahcore.metrics import TileMetric @@ -12,19 +14,30 @@ class TrackTileMetricsPerScanner(pl.Callback): - You want to track metrics corresponding to each class in the `index_map` - Each metric is tracked per scanner """ - def __init__(self, metrics: list[TileMetric], index_map: dict[str, int]): + + def __init__(self, metrics: list[TileMetric], index_map: dict[str, int]) -> None: super().__init__() self.metrics = metrics self.index_map = index_map self._metrics_per_scanner = { - scanner.scanner_name: {f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics} + scanner.scanner_name: { + f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics + } for scanner in ScannerVendors } self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerVendors} - self._logger: AhCoreLogger | None = None + self._logger: Optional[AhCoreLogger] = None - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0) -> None: - if self._logger is None: + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: dict[str, Any], # type: ignore + batch: dict[str, Any], + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + if not self._logger: self._logger = AhCoreLogger(pl_module.logger) path = batch["path"][0] @@ -47,11 +60,14 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin for scanner_name, metrics in self._metrics_per_scanner.items(): batch_count = self._batch_count_per_scanner[scanner_name] if batch_count > 0: - averaged_metrics = {f"{scanner_name}/{key}": value / batch_count for key, value in metrics.items()} - self._logger.log_metrics(averaged_metrics, step=trainer.global_step) + if self._logger: # This is for mypy + averaged_metrics = {f"{scanner_name}/{key}": value / batch_count for key, value in metrics.items()} + self._logger.log_metrics(averaged_metrics, step=trainer.global_step) self._metrics_per_scanner = { - scanner.scanner_name: {f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics} + scanner.scanner_name: { + f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics + } for scanner in ScannerVendors } self._batch_count_per_scanner = {scanner.scanner_name: 0 for scanner in ScannerVendors} diff --git a/ahcore/utils/callbacks.py b/ahcore/utils/callbacks.py index ef78168..71118fe 100644 --- a/ahcore/utils/callbacks.py +++ b/ahcore/utils/callbacks.py @@ -14,6 +14,8 @@ from dlup.annotations import WsiAnnotations from dlup.data.transforms import convert_annotations, rename_labels from dlup.tiling import Grid, GridOrder, TilingMode +from matplotlib.figure import Figure +from pytorch_lightning.loggers import Logger from shapely.geometry import MultiPoint, Point from torch.utils.data import Dataset @@ -244,10 +246,10 @@ def get_output_filename(dump_dir: Path, input_path: Path, model_name: str, count class AhCoreLogger: - def __init__(self, logger): - self.logger = logger + def __init__(self, pl_logger: Logger | Any) -> None: + self.logger = pl_logger - def get_logger_type(self): + def get_logger_type(self) -> LoggerEnum: if hasattr(self.logger.experiment, "log_figure"): return LoggerEnum.MLFLOW elif hasattr(self.logger.experiment, "add_figure"): @@ -255,7 +257,7 @@ def get_logger_type(self): else: return LoggerEnum.UNKNOWN - def log_figure(self, figure, step, batch_idx): + def log_figure(self, figure: Figure, step: int, batch_idx: int) -> None: if self.get_logger_type() == LoggerEnum.MLFLOW: artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png" self.logger.experiment.log_figure(self.logger.run_id, figure, artifact_file=artifact_file_name) @@ -264,7 +266,7 @@ def log_figure(self, figure, step, batch_idx): else: raise NotImplementedError(f"Logging method for logger {type(self.logger).__name__} not implemented.") - def log_metrics(self, metrics, step): + def log_metrics(self, metrics: dict[str, Any], step: int) -> None: logger_type = self.get_logger_type() if logger_type == LoggerEnum.MLFLOW: for metric_name, value in metrics.items(): diff --git a/ahcore/utils/types.py b/ahcore/utils/types.py index 821509d..3da2bc7 100644 --- a/ahcore/utils/types.py +++ b/ahcore/utils/types.py @@ -93,12 +93,12 @@ class ScannerVendors(Enum): MRXS = ("mrxs", "3DHistech") DEFAULT = ("default", "Unknown Scanner") - def __init__(self, extension, scanner_name): + def __init__(self, extension: str, scanner_name: str) -> None: self.extension = extension self.scanner_name = scanner_name @classmethod - def get_vendor_name(cls, file_extension): + def get_vendor_name(cls, file_extension: str) -> str: for scanner in cls: if scanner.extension == file_extension: return scanner.scanner_name From 222c2b72b4354fd1b285537113437db4c7714f5e Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 2 Oct 2024 15:18:41 +0200 Subject: [PATCH 21/24] Rename class --- ahcore/callbacks/track_metrics_per_scanner.py | 2 +- ...track_metrics_per_scanner.yaml => scanner_tile_metrics.yaml} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename config/callbacks/{track_metrics_per_scanner.yaml => scanner_tile_metrics.yaml} (66%) diff --git a/ahcore/callbacks/track_metrics_per_scanner.py b/ahcore/callbacks/track_metrics_per_scanner.py index a10b7a1..e0a34c9 100644 --- a/ahcore/callbacks/track_metrics_per_scanner.py +++ b/ahcore/callbacks/track_metrics_per_scanner.py @@ -7,7 +7,7 @@ from ahcore.utils.types import ScannerVendors -class TrackTileMetricsPerScanner(pl.Callback): +class ScannerTileMetricsCallback(pl.Callback): """ This callback is used to track several `TileMetric` from ahcore per scanner. The callback works on certain assumptions: diff --git a/config/callbacks/track_metrics_per_scanner.yaml b/config/callbacks/scanner_tile_metrics.yaml similarity index 66% rename from config/callbacks/track_metrics_per_scanner.yaml rename to config/callbacks/scanner_tile_metrics.yaml index 07bd05d..b2d0a09 100644 --- a/config/callbacks/track_metrics_per_scanner.yaml +++ b/config/callbacks/scanner_tile_metrics.yaml @@ -1,5 +1,5 @@ track_metrics_per_scanner: - _target_: ahcore.callbacks.track_metrics_per_scanner.TrackTileMetricsPerScanner + _target_: ahcore.callbacks.track_metrics_per_scanner.ScannerTileMetricsCallback metrics: - _target_: ahcore.metrics.DiceMetric data_description: ${data_description} From 0484dd99f58a77bee20cd754d0aa86f9585ee9fe Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 2 Oct 2024 15:23:03 +0200 Subject: [PATCH 22/24] Rename file --- ...ck_metrics_per_scanner.py => scanner_tile_metrics_callback.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename ahcore/callbacks/{track_metrics_per_scanner.py => scanner_tile_metrics_callback.py} (100%) diff --git a/ahcore/callbacks/track_metrics_per_scanner.py b/ahcore/callbacks/scanner_tile_metrics_callback.py similarity index 100% rename from ahcore/callbacks/track_metrics_per_scanner.py rename to ahcore/callbacks/scanner_tile_metrics_callback.py From 3d127731f68fe5788abf1a3bd9af21caf257aeeb Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 2 Oct 2024 15:23:44 +0200 Subject: [PATCH 23/24] refactor config file --- config/callbacks/scanner_tile_metrics.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/callbacks/scanner_tile_metrics.yaml b/config/callbacks/scanner_tile_metrics.yaml index b2d0a09..b1ef915 100644 --- a/config/callbacks/scanner_tile_metrics.yaml +++ b/config/callbacks/scanner_tile_metrics.yaml @@ -1,5 +1,5 @@ track_metrics_per_scanner: - _target_: ahcore.callbacks.track_metrics_per_scanner.ScannerTileMetricsCallback + _target_: ahcore.callbacks.scanner_tile_metrics_callback.ScannerTileMetricsCallback metrics: - _target_: ahcore.metrics.DiceMetric data_description: ${data_description} From 3cc8c5b2ae116f00fec3903b4a422ac89a09c2aa Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 2 Oct 2024 17:59:49 +0200 Subject: [PATCH 24/24] refactor config file --- config/callbacks/scanner_tile_metrics.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/callbacks/scanner_tile_metrics.yaml b/config/callbacks/scanner_tile_metrics.yaml index b1ef915..49cf7cb 100644 --- a/config/callbacks/scanner_tile_metrics.yaml +++ b/config/callbacks/scanner_tile_metrics.yaml @@ -1,4 +1,4 @@ -track_metrics_per_scanner: +scanner_tile_metrics_callback: _target_: ahcore.callbacks.scanner_tile_metrics_callback.ScannerTileMetricsCallback metrics: - _target_: ahcore.metrics.DiceMetric