Skip to content
This repository was archived by the owner on Oct 19, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions ahcore/callbacks/log_images_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from matplotlib.colors import to_rgb

from ahcore.metrics.metrics import _compute_dice
from ahcore.utils.callbacks import AhCoreLogger
from ahcore.utils.types import GenericNumberArray, ScannerVendors


def get_sample_wise_dice_components(
predictions: torch.Tensor,
target: torch.Tensor,
roi: torch.Tensor | None,
num_classes: int,
) -> List[List[Tuple[torch.Tensor, torch.Tensor]]]:
# Apply softmax along the class dimension
soft_predictions = F.softmax(predictions, dim=1)

# Get the predicted class per pixel
predictions = soft_predictions.argmax(dim=1)

# Get the target class per pixel
_target = target.argmax(dim=1)

# Initialize list to store dice components for each sample
dice_components = []

# Loop over the batch (B samples)
for batch_idx in range(predictions.size(0)): # Loop through the batch dimension (B)
batch_dice_components = []

for class_idx in range(num_classes): # Loop through the classes
# Get predictions and target for the current class and sample
curr_predictions = (predictions[batch_idx] == class_idx).int()
curr_target = (_target[batch_idx] == class_idx).int()

if roi is not None:
# Apply ROI if it's provided
curr_roi = roi[batch_idx].squeeze(0) # Adjust ROI for current sample
intersection = torch.sum((curr_predictions * curr_target) * curr_roi, dim=(0, 1))
cardinality = torch.sum(curr_predictions * curr_roi, dim=(0, 1)) + torch.sum(
curr_target * curr_roi, dim=(0, 1)
)
else:
# No ROI, just compute intersection and cardinality
intersection = torch.sum(curr_predictions * curr_target, dim=(0, 1))
cardinality = torch.sum(curr_predictions, dim=(0, 1)) + torch.sum(curr_target, dim=(0, 1))

batch_dice_components.append((intersection, cardinality))

dice_components.append(batch_dice_components)

return dice_components


def get_sample_wise_dice(
    outputs: dict[str, Any], batch: dict[str, Any], roi: torch.Tensor | None, num_classes: int
) -> List[dict[int, float]]:
    """Compute per-sample dice scores for every foreground class.

    Parameters
    ----------
    outputs : dict[str, Any]
        Step outputs; must contain model logits under ``"prediction"``.
    batch : dict[str, Any]
        Batch dict; must contain one-hot targets under ``"target"``.
    roi : torch.Tensor | None
        Optional region-of-interest mask forwarded to the component computation.
    num_classes : int
        Total number of classes; class 0 is treated as background and skipped.

    Returns
    -------
    List[dict[int, float]]
        One dict per sample mapping foreground class index -> dice score.
    """
    dices = []
    dice_components = get_sample_wise_dice_components(outputs["prediction"], batch["target"], roi, num_classes)
    for samplewise_dice_components in dice_components:
        # Initialize every foreground class to 0.0 so each sample reports the
        # same keys. (Previously hard-coded to {1, 2, 3}, which was wrong for
        # num_classes != 4.)
        classwise_dices_for_a_sample: dict[int, float] = {c: 0.0 for c in range(1, num_classes)}
        for class_idx in range(1, num_classes):  # skip background (class 0)
            intersection, cardinality = samplewise_dice_components[class_idx]
            dice_score = _compute_dice(intersection, cardinality)
            classwise_dices_for_a_sample[class_idx] = dice_score.item()
        dices.append(classwise_dices_for_a_sample)
    return dices


def _extract_scanner_name(path: str) -> str:
    """Map a slide file path to its scanner vendor name via the file extension."""
    # rpartition yields the text after the last "."; with no dot present the
    # whole path is returned, matching path.split(".")[-1].
    _, _, extension = path.rpartition(".")
    return ScannerVendors.get_vendor_name(extension)


class LogImagesCallback(pl.Callback):
    """Log validation images, ground-truth masks and predictions as figures.

    Every ``plot_every_n_epochs`` epochs, validation batches are rendered as a
    (batch_size x 3) matplotlib figure (image | ground truth | prediction) and
    sent to the experiment logger via :class:`AhCoreLogger`. When
    ``plot_scanner_wise`` is enabled, only the first batch per scanner vendor
    is logged each epoch; otherwise every batch of the due epoch is logged.
    """

    def __init__(
        self,
        color_map: dict[int, str],
        num_classes: int,
        plot_dice: bool = True,
        plot_scanner_wise: bool = False,
        plot_every_n_epochs: int = 10,
    ):
        """
        Parameters
        ----------
        color_map : dict[int, str]
            Maps class index to a matplotlib color name for mask rendering.
        num_classes : int
            Number of classes; class 0 is treated as background.
        plot_dice : bool
            Annotate each prediction with per-class dice scores.
        plot_scanner_wise : bool
            Log one figure per scanner vendor instead of every batch.
        plot_every_n_epochs : int
            How often (in epochs) figures are logged.
        """
        super().__init__()
        # Pre-convert color names to 0-255 RGB arrays once, up front.
        self.color_map = {k: np.array(to_rgb(v)) * 255 for k, v in color_map.items()}
        self._num_classes = num_classes
        self._already_seen_scanner: List[str] = []
        self._plot_scanner_wise = plot_scanner_wise
        self._plot_every_n_epochs = plot_every_n_epochs
        self._plot_dice = plot_dice
        # Was `Optional[AhCoreLogger | None]`, which is redundant.
        self._logger: Optional[AhCoreLogger] = None

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: dict[str, Any],  # type: ignore
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Render and log the current validation batch when the epoch is due."""
        if self._logger is None:
            self._logger = AhCoreLogger(pl_module.logger)
        if trainer.current_epoch % self._plot_every_n_epochs != 0:
            return

        val_images = batch["image"]
        # `.get` instead of `[...]` so batches without an ROI do not crash
        # (consistent with ScannerTileMetricsCallback); the dice computation
        # handles `roi is None`.
        roi = batch.get("roi", None)
        # The first path in the batch determines the scanner for the whole batch.
        path = batch["path"][0]

        scanner_name = _extract_scanner_name(path) if self._plot_scanner_wise else None
        dices = get_sample_wise_dice(outputs, batch, roi, self._num_classes) if self._plot_dice else None

        val_images_numpy = val_images.permute(0, 2, 3, 1).detach().cpu().numpy()
        # Min-max normalize to [0, 1] for display; guard against a constant
        # batch, which previously caused a division by zero.
        image_range = val_images_numpy.max() - val_images_numpy.min()
        if image_range > 0:
            val_images_numpy = (val_images_numpy - val_images_numpy.min()) / image_range

        val_predictions_numpy = outputs["prediction"].permute(0, 2, 3, 1).detach().cpu().numpy()
        val_targets_numpy = batch["target"].permute(0, 2, 3, 1).detach().cpu().numpy()

        # With plot_scanner_wise disabled, scanner_name is None and is never
        # recorded, so every batch of a due epoch is plotted.
        if scanner_name not in self._already_seen_scanner:
            self._plot_and_log(
                val_images_numpy,
                val_predictions_numpy,
                val_targets_numpy,
                trainer.global_step,
                batch_idx,
                scanner_name,
                dices,
            )

            if self._plot_scanner_wise and scanner_name is not None:
                self._already_seen_scanner.append(scanner_name)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Reset the per-epoch scanner bookkeeping."""
        self._already_seen_scanner = []

    def _plot_and_log(
        self,
        images_numpy: GenericNumberArray,
        predictions_numpy: GenericNumberArray,
        targets_numpy: GenericNumberArray,
        step: int,
        batch_idx: int,
        scanner_name: Optional[str] = None,
        dices: Optional[List[dict[int, float]]] = None,
    ) -> None:
        """Build the (batch_size x 3) figure and hand it to the logger."""
        batch_size = images_numpy.shape[0]
        figure = plt.figure(figsize=(15, batch_size * 5))
        for sample_idx in range(batch_size):
            class_wise_dice = dices[sample_idx] if dices is not None else None

            # Column 1: the original image.
            plt.subplot(batch_size, 3, sample_idx * 3 + 1)
            plt.imshow(images_numpy[sample_idx])
            plt.axis("off")
            if scanner_name is not None:
                plt.title(f"Original Image (Scanner: {scanner_name})")

            # Column 2: ground-truth mask.
            plt.subplot(batch_size, 3, sample_idx * 3 + 2)
            class_indices_gt = np.argmax(targets_numpy[sample_idx], axis=-1)
            plt.imshow(apply_color_map(class_indices_gt, self.color_map, self._num_classes))
            plt.axis("off")

            # Column 3: prediction mask, optionally annotated with dice scores.
            plt.subplot(batch_size, 3, sample_idx * 3 + 3)
            class_indices_pred = np.argmax(predictions_numpy[sample_idx], axis=-1)
            if class_wise_dice is not None:
                # Foreground classes only (class 0 is background).
                dice_values = " ".join(f"{class_wise_dice[c]:.2f}" for c in range(1, self._num_classes))
                plt.title(f"Dice: {dice_values}")
            plt.imshow(apply_color_map(class_indices_pred, self.color_map, self._num_classes))
            plt.axis("off")
        plt.tight_layout()
        if self._logger:  # This is for mypy
            self._logger.log_figure(figure, step, batch_idx)
        # Close this specific figure (bare plt.close() relied on it being
        # "current"); prevents figure leaks across batches.
        plt.close(figure)


def apply_color_map(
    image: GenericNumberArray, color_map: dict[int, Any], num_classes: int
) -> np.ndarray[Any, np.dtype[np.uint8]]:
    """Convert a 2D class-index map into an RGB uint8 image.

    Class 0 (background) is left black; each foreground class is painted with
    its mapped RGB color. Assigning float RGB values into the uint8 array
    casts them implicitly.
    """
    rgb_image = np.zeros((*image.shape, 3), dtype=np.uint8)
    for class_idx in range(1, num_classes):
        rgb_image[image == class_idx] = color_map[class_idx]
    return rgb_image
73 changes: 73 additions & 0 deletions ahcore/callbacks/scanner_tile_metrics_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Any, Optional

import pytorch_lightning as pl

from ahcore.metrics import TileMetric
from ahcore.utils.callbacks import AhCoreLogger
from ahcore.utils.types import ScannerVendors


class ScannerTileMetricsCallback(pl.Callback):
    """
    This callback is used to track several `TileMetric` from ahcore per scanner.
    The callback works on certain assumptions:
    - You want to track metrics corresponding to each class in the `index_map`
    - Each metric is tracked per scanner

    Per-batch metric values are summed per scanner and averaged over the number
    of batches seen for that scanner at the end of every validation epoch.
    """

    def __init__(self, metrics: list[TileMetric], index_map: dict[str, int]) -> None:
        super().__init__()
        self.metrics = metrics
        self.index_map = index_map
        self._metrics_per_scanner = self._fresh_metric_totals()
        self._batch_count_per_scanner = self._fresh_batch_counts()
        self._logger: Optional[AhCoreLogger] = None

    def _fresh_metric_totals(self) -> dict[str, dict[str, float]]:
        """Zeroed accumulator: scanner name -> "<metric>/<class>" -> running sum."""
        return {
            scanner.scanner_name: {
                f"{metric.name}/{class_name}": 0.0 for class_name in self.index_map.keys() for metric in self.metrics
            }
            for scanner in ScannerVendors
        }

    def _fresh_batch_counts(self) -> dict[str, int]:
        """Zeroed per-scanner batch counters."""
        return {scanner.scanner_name: 0 for scanner in ScannerVendors}

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: dict[str, Any],  # type: ignore
        batch: dict[str, Any],
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Accumulate each metric under the scanner that produced this batch."""
        if not self._logger:
            self._logger = AhCoreLogger(pl_module.logger)

        # The file extension of the slide path identifies the scanner vendor.
        path = batch["path"][0]
        file_extension = path.split(".")[-1]
        scanner_name = ScannerVendors.get_vendor_name(file_extension)

        prediction = outputs["prediction"]
        target = batch["target"]
        roi = batch.get("roi", None)

        for metric in self.metrics:
            batch_metrics = metric(prediction, target, roi)
            # Only the class names are needed to build the keys; the indices
            # in index_map are unused here.
            for class_name in self.index_map:
                metric_key = f"{metric.name}/{class_name}"
                self._metrics_per_scanner[scanner_name][metric_key] += batch_metrics[metric_key].item()

        self._batch_count_per_scanner[scanner_name] += 1

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Log per-scanner batch averages, then reset the accumulators."""
        for scanner_name, metrics in self._metrics_per_scanner.items():
            batch_count = self._batch_count_per_scanner[scanner_name]
            if batch_count > 0:
                if self._logger:  # This is for mypy
                    averaged_metrics = {f"{scanner_name}/{key}": value / batch_count for key, value in metrics.items()}
                    self._logger.log_metrics(averaged_metrics, step=trainer.global_step)

        # Reset via the shared helpers (previously duplicated the __init__
        # dict comprehensions verbatim).
        self._metrics_per_scanner = self._fresh_metric_totals()
        self._batch_count_per_scanner = self._fresh_batch_counts()
37 changes: 36 additions & 1 deletion ahcore/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
from dlup.annotations import WsiAnnotations
from dlup.data.transforms import convert_annotations, rename_labels
from dlup.tiling import Grid, GridOrder, TilingMode
from matplotlib.figure import Figure
from pytorch_lightning.loggers import Logger
from shapely.geometry import MultiPoint, Point
from torch.utils.data import Dataset

from ahcore.readers import FileImageReader
from ahcore.transforms.pre_transforms import one_hot_encoding
from ahcore.utils.data import DataDescription
from ahcore.utils.io import get_logger
from ahcore.utils.types import DlupDatasetSample
from ahcore.utils.types import DlupDatasetSample, LoggerEnum

logger = get_logger(__name__)

Expand Down Expand Up @@ -241,3 +243,36 @@ def get_output_filename(dump_dir: Path, input_path: Path, model_name: str, count
if counter is not None:
return dump_dir / "outputs" / model_name / f"{counter}" / f"{hex_dig}.cache"
return dump_dir / "outputs" / model_name / f"{hex_dig}.cache"


class AhCoreLogger:
    """Thin wrapper dispatching figure/metric logging to MLflow or TensorBoard.

    The backend is detected by duck typing on the underlying ``experiment``
    object rather than by class: MLflow experiments expose ``log_figure``,
    TensorBoard writers expose ``add_figure``.
    """

    def __init__(self, pl_logger: Logger | Any) -> None:
        self.logger = pl_logger

    def get_logger_type(self) -> LoggerEnum:
        """Return which backend the wrapped logger's experiment implements."""
        experiment = self.logger.experiment
        if hasattr(experiment, "log_figure"):
            return LoggerEnum.MLFLOW
        if hasattr(experiment, "add_figure"):
            return LoggerEnum.TENSORBOARD
        return LoggerEnum.UNKNOWN

    def log_figure(self, figure: Figure, step: int, batch_idx: int) -> None:
        """Log a matplotlib figure for ``step``/``batch_idx``.

        Raises NotImplementedError when the backend is unknown.
        """
        # Detect the backend once (previously called twice per figure),
        # consistent with log_metrics below.
        logger_type = self.get_logger_type()
        if logger_type == LoggerEnum.MLFLOW:
            artifact_file_name = f"validation_global_step{step:03d}_batch{batch_idx:03d}.png"
            self.logger.experiment.log_figure(self.logger.run_id, figure, artifact_file=artifact_file_name)
        elif logger_type == LoggerEnum.TENSORBOARD:
            self.logger.experiment.add_figure(f"validation_step_{step}_batch_{batch_idx}", figure, global_step=step)
        else:
            raise NotImplementedError(f"Logging method for logger {type(self.logger).__name__} not implemented.")

    def log_metrics(self, metrics: dict[str, Any], step: int) -> None:
        """Log a dict of scalar metrics under their keys at ``step``.

        Raises NotImplementedError when the backend is unknown.
        """
        logger_type = self.get_logger_type()
        if logger_type == LoggerEnum.MLFLOW:
            for metric_name, value in metrics.items():
                self.logger.experiment.log_metric(self.logger.run_id, key=metric_name, value=value, step=step)
        elif logger_type == LoggerEnum.TENSORBOARD:
            for metric_name, value in metrics.items():
                self.logger.experiment.add_scalar(metric_name, value, global_step=step)
        else:
            raise NotImplementedError(f"Logging method for logger {type(self.logger).__name__} not implemented.")
25 changes: 25 additions & 0 deletions ahcore/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,28 @@ class ViTEmbedMode(str, Enum):
CONCAT_MEAN = "embed_concat_mean"
CONCAT = "embed_concat"
# Extend as necessary


class ScannerVendors(Enum):
    """Map slide file extensions to scanner vendor names.

    Each member carries ``(extension, scanner_name)``; ``DEFAULT`` is the
    fallback for unrecognized extensions.
    """

    SVS = ("svs", "Aperio")
    MRXS = ("mrxs", "3DHistech")
    DEFAULT = ("default", "Unknown Scanner")

    def __init__(self, extension: str, scanner_name: str) -> None:
        # Stored per member so lookups can compare extensions and return names.
        self.extension = extension
        self.scanner_name = scanner_name

    @classmethod
    def get_vendor_name(cls, file_extension: str) -> str:
        """Return the vendor name for ``file_extension`` (case-insensitive).

        Lowercases the input so uppercase extensions (e.g. "SVS") resolve
        correctly; previously they fell through to the default. Returns
        ``DEFAULT.scanner_name`` for unknown extensions.
        """
        normalized = file_extension.lower()
        for scanner in cls:
            if scanner.extension == normalized:
                return scanner.scanner_name
        # Return a default value if extension is not found
        return cls.DEFAULT.scanner_name


class LoggerEnum(Enum):
    # Experiment-logger backends recognized by AhCoreLogger's duck-typed
    # dispatch; UNKNOWN is the fallback when no known API is detected.
    TENSORBOARD = "tensorboard"
    MLFLOW = "mlflow"
    UNKNOWN = "unknown"
    # Extend as necessary
7 changes: 7 additions & 0 deletions config/callbacks/log_images_callback.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Hydra config for LogImagesCallback: renders validation image / ground-truth /
# prediction figures and sends them to the experiment logger.
log_images_callback:
  _target_: ahcore.callbacks.log_images_callback.LogImagesCallback
  color_map: ${data_description.color_map}  # class index -> matplotlib color name
  num_classes: ${data_description.num_classes}
  plot_scanner_wise: True  # log one figure per scanner vendor per epoch
  plot_every_n_epochs: 10  # plotting frequency in validation epochs
  plot_dice: True  # annotate predictions with per-class dice scores
6 changes: 6 additions & 0 deletions config/callbacks/scanner_tile_metrics.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Hydra config for ScannerTileMetricsCallback: accumulates tile metrics per
# scanner vendor and logs per-scanner averages each validation epoch.
scanner_tile_metrics_callback:
  _target_: ahcore.callbacks.scanner_tile_metrics_callback.ScannerTileMetricsCallback
  metrics:
    - _target_: ahcore.metrics.DiceMetric
      data_description: ${data_description}
  index_map: ${data_description.index_map}  # class name -> class index