From 8245f6515237930a14dcbd0638bf889411fc5fcc Mon Sep 17 00:00:00 2001
From: Ajey Pai K
Date: Tue, 23 Jul 2024 19:06:46 +0200
Subject: [PATCH] Implement NSD

---
 ahcore/metrics/metrics.py                    | 151 +++++++++++++++++-
 config/metrics/segmentation.yaml             |   3 +
 .../test_metrics/test_surface_dice_metric.py | 116 ++++++++++++++
 3 files changed, 268 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_metrics/test_surface_dice_metric.py

diff --git a/ahcore/metrics/metrics.py b/ahcore/metrics/metrics.py
index 1ccb48b..407c9c5 100644
--- a/ahcore/metrics/metrics.py
+++ b/ahcore/metrics/metrics.py
@@ -10,6 +10,7 @@
 import torch
 import torch.nn.functional as F  # noqa
+from monai.metrics.surface_dice import compute_surface_dice
 
 from ahcore.exceptions import ConfigurationError
 from ahcore.utils.data import DataDescription
@@ -159,10 +160,155 @@ def reset(self) -> None:
         pass
 
 
+class WSiSurfaceDiceMetric(WSIMetric):
+    def __init__(
+        self,
+        data_description: DataDescription,
+        class_thresholds: list[float],
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(data_description=data_description)
+        self._class_thresholds = class_thresholds
+        self._device = "cpu"
+        # Invert the index map so that labels map to class names
+        _index_map = {}
+        if self._data_description.index_map is None:
+            raise ConfigurationError("`index_map` is required to set up the WSI surface dice metric.")
+        else:
+            _index_map = self._data_description.index_map
+
+        _label_to_class: dict[int, str] = {v: k for k, v in _index_map.items()}
+        _label_to_class[0] = "background"
+        self._label_to_class = _label_to_class
+        self._num_classes = self._data_description.num_classes
+
+    @property
+    def name(self) -> str:
+        return "wsi_surface_dice"
+
+    def process_batch(
+        self,
+        predictions: torch.Tensor,
+        target: torch.Tensor,
+        roi: torch.Tensor | None,
+        wsi_name: str,
+    ) -> None:
+        if wsi_name not in self.wsis:
+            self._initialize_wsi_dict(wsi_name)
+        surface_dice_components = self._get_surface_dice(predictions, target, roi)
+        for class_idx in range(self._num_classes):
+            # Sum over the batch dimension so batches larger than one are not silently dropped
+            self.wsis[wsi_name][class_idx]["surface_dice"] += surface_dice_components[:, class_idx].sum()
+        self.wsis[wsi_name]["total_tiles"] += predictions.shape[0]
+
+    def get_wsi_score(self, wsi_name: str) -> None:
+        for class_idx in range(self._num_classes):
+            self.wsis[wsi_name][class_idx]["wsi_surface_dice"] = (
+                self.wsis[wsi_name][class_idx]["surface_dice"] / self.wsis[wsi_name]["total_tiles"]
+            )
+
+    def one_hot_encode(self, predictions: torch.Tensor) -> torch.Tensor:
+        # Create a tensor of zeros with shape (batch_size, num_classes, height, width)
+        one_hot = torch.zeros(
+            predictions.size(0), self._num_classes, predictions.size(1), predictions.size(2), device=predictions.device
+        )
+
+        # Scatter 1s in the appropriate locations
+        one_hot = one_hot.scatter_(1, predictions.unsqueeze(1), 1.0)
+
+        return one_hot
+
+    def _get_surface_dice(
+        self,
+        predictions: torch.Tensor,
+        target: torch.Tensor,
+        roi: torch.Tensor | None,
+    ) -> torch.Tensor:
+        # One-hot encode the predictions; the target is assumed to be one-hot already
+        arg_max_predictions = predictions.argmax(dim=1)
+        one_hot_predictions = self.one_hot_encode(arg_max_predictions)
+
+        if roi is not None:
+            # roi has shape (batch_size, 1, height, width) and broadcasts over the class dimension
+            one_hot_predictions = one_hot_predictions * roi
+            target = target * roi
+
+        surface_dice = compute_surface_dice(
+            one_hot_predictions,
+            target,
+            class_thresholds=self._class_thresholds,
+            distance_metric="chessboard",
+            include_background=True,
+        )
+
+        # A class absent from both prediction and target yields NaN; count it as a perfect score
+        surface_dice[surface_dice.isnan()] = 1.0
+
+        return surface_dice
+
+    def _get_surface_dice_averaged_over_total_wsis(self) -> dict[int, float]:
+        surface_dices: dict[int, list[float]] = {class_idx: [] for class_idx in range(self._num_classes)}
+        for wsi_name in self.wsis:
+            self.get_wsi_score(wsi_name)
+            for class_idx in range(self._num_classes):
+                surface_dices[class_idx].append(self.wsis[wsi_name][class_idx]["wsi_surface_dice"])
+        return {
+            class_idx: sum(surface_dices[class_idx]) / len(surface_dices[class_idx])
+            for class_idx in surface_dices.keys()
+        }
+
+    def _initialize_wsi_dict(self, wsi_name: str) -> None:
+        self.wsis[wsi_name] = {class_idx: {"surface_dice": torch.zeros(1)} for class_idx in range(self._num_classes)}
+        self.wsis[wsi_name]["total_tiles"] = 0
+
+    def get_average_score(
+        self, precomputed_output: list[list[dict[str, dict[str, float]]]] | None = None
+    ) -> dict[Any, Any]:
+        surface_dices = self._get_surface_dice_averaged_over_total_wsis()
+        avg_dict = {
+            f"{self.name}/surface_dice/{self._label_to_class[idx]}": value for idx, value in surface_dices.items()
+        }
+        return avg_dict
+
+    @staticmethod
+    def static_average_wsi_surface_dice(
+        precomputed_output: list[list[dict[str, dict[str, float]]]]
+    ) -> dict[str, float]:
+        """Static method to compute the average WSI surface dice score over a list of WSI surface dice scores,
+        useful for multiprocessing."""
+        # Initialize defaultdicts to handle the sum and count of surface dice scores for each class
+        class_sum: dict[str, float] = defaultdict(float)
+        class_count: dict[str, int] = defaultdict(int)
+
+        # Flatten the list and extract the 'wsi_surface_dice' dictionaries
+        wsi_surface_dices: list[dict[str, float]] = [
+            wsi_metric.get("wsi_surface_dice", {}) for sublist in precomputed_output for wsi_metric in sublist
+        ]
+        # If the list is empty, the precomputed output did not contain any WSI surface dice scores
+        if not wsi_surface_dices:
+            return {}
+
+        # Update sum and count for each class in a single pass
+        for wsi_surface_dice in wsi_surface_dices:
+            for class_name, surface_dice_score in wsi_surface_dice.items():
+                class_sum[class_name] += surface_dice_score
+                class_count[class_name] += 1
+
+        # Compute average surface dice scores in a dictionary comprehension with consistent naming
+        avg_surface_dice_scores = {
+            f"wsi_surface_dice/{class_name}": class_sum[class_name] / class_count[class_name]
+            for class_name in class_sum.keys()
+        }
+        return avg_surface_dice_scores
+
+    def reset(self) -> None:
+        self.wsis = {}
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}(num_classes={self._num_classes})"
+
+
 class WSIDiceMetric(WSIMetric):
     """WSI Dice metric class, computes the dice score over the whole WSI"""
 
-    def __init__(self, data_description: DataDescription, compute_overall_dice: bool = False) -> None:
+    def __init__(self, data_description: DataDescription, compute_overall_dice: bool = False, **kwargs: Any) -> None:
         super().__init__(data_description=data_description)
         self.compute_overall_dice = compute_overall_dice
         self._num_classes = self._data_description.num_classes
@@ -331,7 +477,8 @@ def metrics(self) -> list[WSIMetric]:
     @classmethod
     def for_segmentation(cls, *args: Any, **kwargs: Any) -> WSIMetricFactory:
         dices = WSIDiceMetric(*args, **kwargs)
-        return cls([dices])
+        surface_dices = WSiSurfaceDiceMetric(*args, **kwargs)
+        return cls([dices, surface_dices])
 
     @classmethod
     def for_wsi_classification(cls, *args: Any, **kwargs: Any) -> WSIMetricFactory:
diff --git a/config/metrics/segmentation.yaml b/config/metrics/segmentation.yaml
index cac63e8..f688b22 100644
--- a/config/metrics/segmentation.yaml
+++ b/config/metrics/segmentation.yaml
@@ -3,4 +3,7 @@
 tile_level:
 wsi_level:
   _target_: ahcore.metrics.WSIMetricFactory.for_segmentation
+  # This is for the Dice score based on pixel counting
   compute_overall_dice: True
+  # This is for the normalized surface dice (boundary overlap)
+  class_thresholds: [0, 3, 3, 3]
diff --git a/tests/test_metrics/test_surface_dice_metric.py b/tests/test_metrics/test_surface_dice_metric.py
new file mode 100644
index 0000000..297a4d5
--- /dev/null
+++ b/tests/test_metrics/test_surface_dice_metric.py
@@ -0,0 +1,116 @@
+from pathlib import Path
+
+import pytest
+import torch
+
+from ahcore.metrics.metrics import WSiSurfaceDiceMetric
+from ahcore.utils.data import DataDescription, GridDescription
+
+
+@pytest.fixture
+def data_description() -> DataDescription:
+    num_classes = 4
+    index_map = {"class1": 1, "class2": 2, "class3": 3}
+    data_dir = Path("data_dir")
+    manifest_database_uri = "manifest_database_uri"
+    manifest_name = "manifest_name"
+    split_version = "split_version"
+    annotations_dir = Path("annotations_dir")
+    training_grid = GridDescription(mpp=1.0, tile_size=(256, 256), tile_overlap=(0, 0), output_tile_size=(256, 256))
+    inference_grid = GridDescription(mpp=1.0, tile_size=(256, 256), tile_overlap=(0, 0), output_tile_size=(256, 256))
+    return DataDescription(
+        num_classes=num_classes,
+        index_map=index_map,
+        data_dir=data_dir,
+        manifest_database_uri=manifest_database_uri,
+        manifest_name=manifest_name,
+        split_version=split_version,
+        annotations_dir=annotations_dir,
+        training_grid=training_grid,
+        inference_grid=inference_grid,
+    )
+
+
+@pytest.fixture
+def metric(data_description: DataDescription) -> WSiSurfaceDiceMetric:
+    class_thresholds = [0.0, 1.0, 1.0, 1.0]
+    return WSiSurfaceDiceMetric(data_description=data_description, class_thresholds=class_thresholds)
+
+
+def get_batch() -> tuple[torch.Tensor, torch.Tensor, None, str]:
+    predictions = torch.randn(1, 4, 256, 256).float()
+    target = torch.zeros(1, 4, 256, 256).float()  # Mock target tensor
+    target[0, 1, 128, 128] = 1.0  # Simulate class 1 presence
+    target[0, 2, 64, 64] = 1.0  # Simulate class 2 presence
+    target[0, 3, 10, 10] = 1.0  # Simulate class 3 presence
+    roi = None
+    wsi_name = "test_wsi"
+    return predictions, target, roi, wsi_name
+
+
+def test_process_batch(metric: WSiSurfaceDiceMetric) -> None:
+    predictions, target, roi, wsi_name = get_batch()
+    metric.process_batch(predictions, target, roi, wsi_name)
+    metric.get_wsi_score(wsi_name)
+
+    assert wsi_name in metric.wsis
+    assert all("surface_dice" in metric.wsis[wsi_name][i] for i in range(metric._num_classes))
+
+
+def test_get_average_score(metric: WSiSurfaceDiceMetric) -> None:
+    predictions, target, roi, wsi_name = get_batch()
+    metric.process_batch(predictions, target, roi, wsi_name)
+    metric.get_wsi_score(wsi_name)
+
+    average_scores = metric.get_average_score()
+
+    assert isinstance(average_scores, dict)
+    assert all(
+        f"{metric.name}/surface_dice/{metric._label_to_class[idx]}" in average_scores
+        for idx in range(metric._num_classes)
+    )
+
+
+def test_static_average_wsi_surface_dice(metric: WSiSurfaceDiceMetric) -> None:
+    precomputed_output = [
+        [{"wsi_surface_dice": {"class1": 0.8, "class2": 0.7, "class3": 0.9}}],
+        [{"wsi_surface_dice": {"class1": 0.85, "class2": 0.75, "class3": 0.95}}],
+    ]
+
+    average_scores = WSiSurfaceDiceMetric.static_average_wsi_surface_dice(precomputed_output)
+
+    assert isinstance(average_scores, dict)
+    assert "wsi_surface_dice/class1" in average_scores
+    assert "wsi_surface_dice/class2" in average_scores
+    assert "wsi_surface_dice/class3" in average_scores
+
+
+def test_reset(metric: WSiSurfaceDiceMetric) -> None:
+    predictions, target, roi, wsi_name = get_batch()
+    metric.process_batch(predictions, target, roi, wsi_name)
+    metric.reset()
+
+    assert metric.wsis == {}
+
+
+def test_surface_dice_edge_cases(metric: WSiSurfaceDiceMetric) -> None:
+    # Note: the background class is not asserted in this test.
+    predictions = torch.zeros(1, 4, 256, 256).float()
+    target = torch.zeros(1, 4, 256, 256).float()
+    # Set targets and predictions such that the boundaries are 1 pixel apart.
+    # The surface dice should be 1.0 for class 1, as a 1-pixel shift falls within the tolerance.
+    predictions[0, 1, 0:128, 0:128] = 1.0
+    target[0, 1, 1:129, 1:129] = 1.0
+    # Test case where there is no overlap between the target and the prediction boundaries.
+    predictions[0, 2, 129:256, 129:256] = 1.0
+    target[0, 2, 0:64, 0:64] = 1.0
+
+    metric.process_batch(predictions, target, None, "wsi_1")
+    metric.get_wsi_score("wsi_1")
+
+    scores = metric.get_average_score()
+
+    assert scores[f"{metric.name}/surface_dice/class1"] == 1.0
+    assert scores[f"{metric.name}/surface_dice/class2"] == 0.0
+    # MONAI returns NaN for class 3 since there are neither targets nor predictions, but the metric maps NaN to 1.0.
+    assert scores[f"{metric.name}/surface_dice/class3"] == 1.0
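
Reviewer note, not part of the patch: the snippet below is a minimal, standalone sketch of the two conventions the metric relies on, namely the per-class pixel tolerance under the chessboard distance and the NaN-to-1.0 mapping for classes absent from both prediction and target. It assumes only that torch and monai are installed; the shapes, class count, and thresholds are illustrative and do not come from the patch.

import torch
from monai.metrics.surface_dice import compute_surface_dice

# One-hot tensors of shape (batch, classes, height, width); three classes are illustrative.
pred = torch.zeros(1, 3, 64, 64)
target = torch.zeros(1, 3, 64, 64)

# Class 1: the same square, shifted by one pixel diagonally in the target.
pred[0, 1, 0:32, 0:32] = 1.0
target[0, 1, 1:33, 1:33] = 1.0
# Class 2 is absent from both tensors; class 0 (background) is the complement,
# which keeps the inputs valid one-hot encodings.
pred[0, 0] = 1.0 - pred[0, 1]
target[0, 0] = 1.0 - target[0, 1]

nsd = compute_surface_dice(
    pred,
    target,
    class_thresholds=[0.0, 1.0, 1.0],  # one tolerance, in pixels, per class
    distance_metric="chessboard",
    include_background=True,
)
# Class 1 scores 1.0: every boundary pixel lies within the 1-pixel tolerance.
# Class 2 is NaN because it is absent everywhere; the metric above maps NaN to 1.0.
nsd[nsd.isnan()] = 1.0
print(nsd)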