151 changes: 149 additions & 2 deletions ahcore/metrics/metrics.py
@@ -10,6 +10,7 @@

import torch
import torch.nn.functional as F # noqa
from monai.metrics.surface_dice import compute_surface_dice

from ahcore.exceptions import ConfigurationError
from ahcore.utils.data import DataDescription
@@ -159,10 +160,155 @@ def reset(self) -> None:
pass


class WSiSurfaceDiceMetric(WSIMetric):
def __init__(
self,
data_description: DataDescription,
class_thresholds: List[float],
**kwargs: Any,
) -> None:
super().__init__(data_description=data_description)
self._class_thresholds = class_thresholds
self._device = "cpu"
        # Invert the index map so that integer labels map back to class names
        if self._data_description.index_map is None:
            raise ConfigurationError("`index_map` is required to set up the WSI surface dice metric.")
        _index_map = self._data_description.index_map

_label_to_class: dict[int, str] = {v: k for k, v in _index_map.items()}
_label_to_class[0] = "background"
self._label_to_class = _label_to_class
self._num_classes = self._data_description.num_classes

@property
def name(self) -> str:
return "wsi_surface_dice"

def process_batch(
self,
predictions: torch.Tensor,
target: torch.Tensor,
roi: torch.Tensor | None,
wsi_name: str,
) -> None:
if wsi_name not in self.wsis:
self._initialize_wsi_dict(wsi_name)
surface_dice_components = self._get_surface_dice(predictions, target, roi)
        for class_idx in range(self._num_classes):
self.wsis[wsi_name][class_idx]["surface_dice"] += surface_dice_components[0, class_idx]
self.wsis[wsi_name]["total_tiles"] += 1

def get_wsi_score(self, wsi_name: str) -> None:
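        # The per-WSI score is the mean of per-tile surface dice values, not a surface dice over the stitched WSI mask.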
for class_idx in range(self._num_classes):
self.wsis[wsi_name][class_idx]["wsi_surface_dice"] = (
self.wsis[wsi_name][class_idx]["surface_dice"] / self.wsis[wsi_name]["total_tiles"]
)

def one_hot_encode(self, predictions: torch.Tensor) -> torch.Tensor:
        # Create a zero tensor of shape (batch_size, num_classes, height, width) on the predictions' device
        one_hot = torch.zeros(
            predictions.size(0),
            self._num_classes,
            predictions.size(1),
            predictions.size(2),
            device=predictions.device,
        )

        # scatter_ writes 1.0 at (b, predictions[b, h, w], h, w), i.e. one-hot along the class dimension
one_hot = one_hot.scatter_(1, predictions.unsqueeze(1), 1.0)

return one_hot

def _get_surface_dice(
self,
predictions: torch.Tensor,
target: torch.Tensor,
roi: torch.Tensor | None,
) -> torch.Tensor:
        # One-hot encode the predictions; the target is expected to be one-hot already
arg_max_predictions = predictions.argmax(dim=1)
one_hot_predictions = self.one_hot_encode(arg_max_predictions)

        if roi is not None:
            # Broadcast the (batch, 1, height, width) ROI mask over the class dimension
            one_hot_predictions = one_hot_predictions * roi
            target = target * roi

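        # Normalized surface dice: the fraction of boundary pixels whose nearest boundary pixel in the
        # other mask lies within the per-class tolerance (chessboard / Chebyshev distance).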
surface_dice = compute_surface_dice(
one_hot_predictions,
target,
class_thresholds=self._class_thresholds,
distance_metric="chessboard",
include_background=True,
)

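        # MONAI returns NaN for classes absent from both prediction and target; count those as a perfect match.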
        surface_dice[surface_dice.isnan()] = 1.0

return surface_dice

def _get_surface_dice_averaged_over_total_wsis(self) -> dict[int, float]:
surface_dices: dict[int, list[float]] = {class_idx: [] for class_idx in range(self._num_classes)}
for wsi_name in self.wsis:
self.get_wsi_score(wsi_name)
            for class_idx in range(self._num_classes):
                surface_dices[class_idx].append(self.wsis[wsi_name][class_idx]["wsi_surface_dice"].item())
return {
class_idx: sum(surface_dices[class_idx]) / len(surface_dices[class_idx])
for class_idx in surface_dices.keys()
}

def _initialize_wsi_dict(self, wsi_name: str) -> None:
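        # Each WSI entry mixes per-class accumulators (int keys) with a "total_tiles" counter (str key).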
self.wsis[wsi_name] = {class_idx: {"surface_dice": torch.zeros(1)} for class_idx in range(self._num_classes)}
self.wsis[wsi_name]["total_tiles"] = 0

def get_average_score(
self, precomputed_output: list[list[dict[str, dict[str, float]]]] | None = None
) -> dict[Any, Any]:
surface_dices = self._get_surface_dice_averaged_over_total_wsis()
avg_dict = {
f"{self.name}/surface_dice/{self._label_to_class[idx]}": value for idx, value in surface_dices.items()
}
return avg_dict

@staticmethod
def static_average_wsi_surface_dice(
precomputed_output: list[list[dict[str, dict[str, float]]]]
) -> dict[str, float]:
"""Static method to compute the average WSI surface dice score over a list of WSI surface dice scores,
useful for multiprocessing."""
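        # Expected structure: precomputed_output[process][wsi] == {"wsi_surface_dice": {class_name: score}}.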
# Initialize defaultdicts to handle the sum and count of dice scores for each class
class_sum: dict[str, float] = defaultdict(float)
class_count: dict[str, int] = defaultdict(int)

# Flatten the list and extract 'wsi_dice' dictionaries
wsi_surface_dices: list[dict[str, float]] = [
wsi_metric.get("wsi_surface_dice", {}) for sublist in precomputed_output for wsi_metric in sublist
]
# Check if the list is empty -- then the precomputed output did not contain any wsi dice scores
if not wsi_surface_dices:
return {}

# Update sum and count for each class in a single pass
for wsi_surface_dice in wsi_surface_dices:
for class_name, surface_dice_score in wsi_surface_dice.items():
class_sum[class_name] += surface_dice_score
class_count[class_name] += 1

# Compute average dice scores in a dictionary comprehension with consistent naming
        avg_surface_dice_scores = {
            f"wsi_surface_dice/{class_name}": class_sum[class_name] / class_count[class_name]
            for class_name in class_sum.keys()
        }
return avg_surface_dice_scores

def reset(self) -> None:
self.wsis = {}

def __repr__(self) -> str:
return f"{type(self).__name__}(num_classes={self._num_classes})"


class WSIDiceMetric(WSIMetric):
"""WSI Dice metric class, computes the dice score over the whole WSI"""

def __init__(self, data_description: DataDescription, compute_overall_dice: bool = False, **kwargs: Any) -> None:
super().__init__(data_description=data_description)
self.compute_overall_dice = compute_overall_dice
self._num_classes = self._data_description.num_classes
@@ -331,7 +477,8 @@ def metrics(self) -> list[WSIMetric]:
@classmethod
def for_segmentation(cls, *args: Any, **kwargs: Any) -> WSIMetricFactory:
dices = WSIDiceMetric(*args, **kwargs)
surface_dices = WSiSurfaceDiceMetric(*args, **kwargs)
return cls([dices, surface_dices])

@classmethod
def for_wsi_classification(cls, *args: Any, **kwargs: Any) -> WSIMetricFactory:
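For context, a minimal, self-contained sketch of the MONAI call this metric wraps; the shapes and values below are illustrative assumptions, not part of this PR:

import torch
from monai.metrics.surface_dice import compute_surface_dice

# Two one-hot masks of shape (batch, classes, height, width); the class-1 squares are shifted by one pixel.
predictions = torch.zeros(1, 2, 64, 64)
target = torch.zeros(1, 2, 64, 64)
predictions[0, 1, 10:30, 10:30] = 1.0
target[0, 1, 11:31, 11:31] = 1.0
predictions[0, 0] = 1.0 - predictions[0, 1]  # background is the complement of the foreground
target[0, 0] = 1.0 - target[0, 1]

# With a one-pixel chessboard tolerance every class-1 boundary point has a counterpart within reach,
# so the class-1 entry of the returned (batch, classes) tensor is 1.0.
nsd = compute_surface_dice(
    predictions,
    target,
    class_thresholds=[0.0, 1.0],
    distance_metric="chessboard",
    include_background=True,
)
print(nsd.shape)  # torch.Size([1, 2])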
3 changes: 3 additions & 0 deletions config/metrics/segmentation.yaml
@@ -3,4 +3,7 @@ tile_level:

wsi_level:
_target_: ahcore.metrics.WSIMetricFactory.for_segmentation
# This is for the Dice score based on pixel counting
compute_overall_dice: True
# This is for the normalized surface dice (boundary overlap)
class_thresholds: [0, 3, 3, 3]
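Both keys are forwarded as keyword arguments to WSIMetricFactory.for_segmentation, which passes them to every metric it builds; the **kwargs: Any added to the constructors lets each metric keep the keys it understands and ignore the rest. A minimal sketch of how Hydra resolves this config (data_description stands in for an object constructed elsewhere):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "ahcore.metrics.WSIMetricFactory.for_segmentation",
        "compute_overall_dice": True,
        "class_thresholds": [0, 3, 3, 3],
    }
)
# Hydra calls for_segmentation(compute_overall_dice=True, class_thresholds=[0, 3, 3, 3], ...);
# WSIDiceMetric consumes compute_overall_dice, WSiSurfaceDiceMetric consumes class_thresholds.
factory = instantiate(cfg, data_description=data_description)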
116 changes: 116 additions & 0 deletions tests/test_metrics/test_surface_dice_metric.py
@@ -0,0 +1,116 @@
from pathlib import Path

import pytest
import torch

from ahcore.metrics.metrics import WSiSurfaceDiceMetric
from ahcore.utils.data import DataDescription, GridDescription


@pytest.fixture
def data_description() -> DataDescription:
num_classes = 4
index_map = {"class1": 1, "class2": 2, "class3": 3}
data_dir = Path("data_dir")
manifest_database_uri = "manifest_database_uri"
manifest_name = "manifest_name"
split_version = "split_version"
annotations_dir = Path("annotations_dir")
training_grid = GridDescription(mpp=1.0, tile_size=(256, 256), tile_overlap=(0, 0), output_tile_size=(256, 256))
inference_grid = GridDescription(mpp=1.0, tile_size=(256, 256), tile_overlap=(0, 0), output_tile_size=(256, 256))
return DataDescription(
num_classes=num_classes,
index_map=index_map,
data_dir=data_dir,
manifest_database_uri=manifest_database_uri,
manifest_name=manifest_name,
split_version=split_version,
annotations_dir=annotations_dir,
training_grid=training_grid,
inference_grid=inference_grid,
)


@pytest.fixture
def metric(data_description: DataDescription) -> WSiSurfaceDiceMetric:
class_thresholds = [0.0, 1.0, 1.0, 1.0]
return WSiSurfaceDiceMetric(data_description=data_description, class_thresholds=class_thresholds)


def get_batch() -> tuple[torch.Tensor, torch.Tensor, None, str]:
predictions = torch.randn(1, 4, 256, 256).float()
target = torch.zeros(1, 4, 256, 256).float() # Mock target tensor
target[0, 1, 128, 128] = 1.0 # Simulate class 1 presence
target[0, 2, 64, 64] = 1.0 # Simulate class 2 presence
target[0, 3, 10, 10] = 1.0 # Simulate class 3 presence
roi = None
wsi_name = "test_wsi"
return predictions, target, roi, wsi_name


def test_process_batch(metric: WSiSurfaceDiceMetric) -> None:
predictions, target, roi, wsi_name = get_batch()
metric.process_batch(predictions, target, roi, wsi_name)
metric.get_wsi_score(wsi_name)

assert wsi_name in metric.wsis
assert all("surface_dice" in metric.wsis[wsi_name][i] for i in range(metric._num_classes))


def test_get_average_score(metric: WSiSurfaceDiceMetric) -> None:
predictions, target, roi, wsi_name = get_batch()
metric.process_batch(predictions, target, roi, wsi_name)
metric.get_wsi_score(wsi_name)

average_scores = metric.get_average_score()

assert isinstance(average_scores, dict)
assert all(
f"{metric.name}/surface_dice/{metric._label_to_class[idx]}" in average_scores
for idx in range(metric._num_classes)
)


def test_static_average_wsi_surface_dice(metric: WSiSurfaceDiceMetric) -> None:
precomputed_output = [
[{"wsi_surface_dice": {"class1": 0.8, "class2": 0.7, "class3": 0.9}}],
[{"wsi_surface_dice": {"class1": 0.85, "class2": 0.75, "class3": 0.95}}],
]

average_scores = WSiSurfaceDiceMetric.static_average_wsi_surface_dice(precomputed_output)

assert isinstance(average_scores, dict)
assert "wsi_surface_dice/class1" in average_scores
assert "wsi_surface_dice/class2" in average_scores
assert "wsi_surface_dice/class3" in average_scores


def test_reset(metric: WSiSurfaceDiceMetric) -> None:
predictions, target, roi, wsi_name = get_batch()
metric.process_batch(predictions, target, roi, wsi_name)
metric.reset()

assert metric.wsis == {}


def test_surface_dice_edge_cases(metric: WSiSurfaceDiceMetric) -> None:
    # Note: the assertions below do not check the background class.
predictions = torch.zeros(1, 4, 256, 256).float()
target = torch.zeros(1, 4, 256, 256).float()
# Set targets and predictions in such a way that the boundaries are 1 pixel apart.
# The surface dice should be 1.0 for class 1, as a shift by 1 pixel falls under the tolerance limit.
predictions[0, 1, 0:128, 0:128] = 1.0
target[0, 1, 1:129, 1:129] = 1.0
# Test case where there is no overlap between the target and the prediction boundaries.
predictions[0, 2, 129:256, 129:256] = 1.0
target[0, 2, 0:64, 0:64] = 1.0

metric.process_batch(predictions, target, None, "wsi_1")
metric.get_wsi_score("wsi_1")

scores = metric.get_average_score()

assert scores[f"{metric.name}/surface_dice/class1"] == 1.0
assert scores[f"{metric.name}/surface_dice/class2"] == 0.0
    # MONAI returns NaN for class 3 since there are neither targets nor predictions, but the metric maps that to 1.0.
assert scores[f"{metric.name}/surface_dice/class3"] == 1.0