Skip to content
This repository was archived by the owner on Oct 19, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion ahcore/callbacks/converters/wsi_metric_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@

class ComputeWsiMetricsCallback(ConvertCallbacks):
def __init__(
self, reader_class: Type[FileImageReader], max_concurrent_tasks: int = 1, save_per_image: bool = True
self,
reader_class: Type[FileImageReader],
ignore_index: Optional[int],
max_concurrent_tasks: int = 1,
save_per_image: bool = True,
) -> None:
"""
Callback to compute metrics on whole-slide images. This callback is used to compute metrics on whole-slide
Expand All @@ -36,6 +40,8 @@ def __init__(
----------
reader_class : FileImageReader
The reader class to use to read the images, e.g., H5FileImageReader or ZarrFileImageReader.
ignore_index: Optional[int]
The index to ignore when computing the metrics.
max_concurrent_tasks : int
The maximum number of concurrent processes.
save_per_image : bool
Expand All @@ -48,6 +54,7 @@ def __init__(
self._dump_dir: Path
self._save_per_image: bool = save_per_image
self._filenames: dict[Path, Path] = {}
self._ignore_index: Optional[int] = ignore_index

self._wsi_metrics: WSIMetricFactory
self._class_names: dict[int, str] = {}
Expand Down Expand Up @@ -106,6 +113,7 @@ def process_task(self, filename: Path, cache_filename: Path) -> dict[str, str |
class_names=self._class_names,
data_description=self._data_description,
wsi_metrics=self._wsi_metrics,
ignore_index=self._ignore_index,
)

if self._save_per_image:
Expand Down Expand Up @@ -155,6 +163,7 @@ def compute_metrics_for_case(
class_names: dict[int, str],
data_description: DataDescription,
wsi_metrics: WSIMetricFactory,
ignore_index: Optional[int],
) -> dict[str, Any]:
with image_reader(task_data.cache_filename, stitching_mode=StitchingMode.CROP) as cache_reader:
dataset_of_validation_image = _ValidationDataset(
Expand Down Expand Up @@ -182,6 +191,7 @@ def compute_metrics_for_case(
wsi_metrics_dictionary[metric.name] = {
class_names[class_idx]: metric.wsis[str(task_data.filename)][class_idx][metric.name].item()
for class_idx in range(data_description.num_classes)
if ignore_index is None or class_idx != ignore_index
}

return wsi_metrics_dictionary
49 changes: 42 additions & 7 deletions ahcore/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __call__(


class DiceMetric(TileMetric):
def __init__(self, data_description: DataDescription) -> None:
def __init__(self, data_description: DataDescription, ignore_index: int) -> None:
r"""
Metric computing dice over classes. The classes are derived from the index_map that's defined in the
data_description.
Expand All @@ -50,6 +50,9 @@ def __init__(self, data_description: DataDescription) -> None:
Parameters
----------
data_description : DataDescription
The data description object.
ignore_index : int
The index to ignore in the computation of the dice score.
"""
super().__init__(data_description=data_description)
self._num_classes = self._data_description.num_classes
Expand All @@ -64,6 +67,7 @@ def __init__(self, data_description: DataDescription) -> None:
_label_to_class = {v: k for k, v in _index_map.items()}
_label_to_class[0] = "background"
self._label_to_class = _label_to_class
self._ignore_index = ignore_index

self.name = "dice"

Expand All @@ -77,7 +81,11 @@ def __call__(
dice_score = _compute_dice(intersection, cardinality)
dices.append(dice_score)

output = {f"{self.name}/{self._label_to_class[idx]}": dices[idx] for idx in range(0, self._num_classes)}
output = {
f"{self.name}/{self._label_to_class[idx]}": dices[idx]
for idx in range(0, self._num_classes)
if idx != self._ignore_index
}
return output

def __repr__(self) -> str:
Expand Down Expand Up @@ -162,9 +170,12 @@ def reset(self) -> None:
class WSIDiceMetric(WSIMetric):
"""WSI Dice metric class, computes the dice score over the whole WSI"""

def __init__(self, data_description: DataDescription, compute_overall_dice: bool = False) -> None:
def __init__(
self, data_description: DataDescription, ignore_index: int, compute_overall_dice: bool = False
) -> None:
super().__init__(data_description=data_description)
self.compute_overall_dice = compute_overall_dice
self._ignore_index = ignore_index
self._num_classes = self._data_description.num_classes
# self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self._device = "cpu"
Expand Down Expand Up @@ -200,11 +211,15 @@ def process_batch(
self._num_classes,
)
for class_idx, (intersection, cardinality) in enumerate(dice_components):
if class_idx == self._ignore_index:
continue
self.wsis[wsi_name][class_idx]["intersection"] += intersection
self.wsis[wsi_name][class_idx]["cardinality"] += cardinality

def get_wsi_score(self, wsi_name: str) -> None:
for class_idx in self.wsis[wsi_name]:
if class_idx == self._ignore_index:
continue
intersection = self.wsis[wsi_name][class_idx]["intersection"]
cardinality = self.wsis[wsi_name][class_idx]["cardinality"]
self.wsis[wsi_name][class_idx]["wsi_dice"] = _compute_dice(intersection, cardinality)
Expand All @@ -225,18 +240,24 @@ def _get_overall_dice(self) -> dict[int, float]:
"overall_dice": 0.0,
}
for class_idx in range(self._num_classes)
if class_idx != self._ignore_index
}
for wsi_name in self.wsis:
for class_idx in range(self._num_classes):
if class_idx == self._ignore_index:
continue
overall_dices[class_idx]["total_intersection"] += self.wsis[wsi_name][class_idx]["intersection"]
overall_dices[class_idx]["total_cardinality"] += self.wsis[wsi_name][class_idx]["cardinality"]
for class_idx in overall_dices.keys():
if class_idx == self._ignore_index:
continue
intersection = overall_dices[class_idx]["total_intersection"]
cardinality = overall_dices[class_idx]["total_cardinality"]
overall_dices[class_idx]["overall_dice"] = (2 * intersection + 0.01) / (cardinality + 0.01)
return {
class_idx: torch.tensor(overall_dices[class_idx]["overall_dice"]).item()
for class_idx in overall_dices.keys()
if class_idx != self._ignore_index
}

def _get_dice_averaged_over_total_wsis(self) -> dict[int, float]:
Expand All @@ -248,16 +269,26 @@ def _get_dice_averaged_over_total_wsis(self) -> dict[int, float]:
dict
Dictionary with the dice scores averaged over all the WSIs per class
"""
dices: dict[int, list[float]] = {class_idx: [] for class_idx in range(self._num_classes)}
dices: dict[int, list[float]] = {
class_idx: [] for class_idx in range(self._num_classes) if class_idx != self._ignore_index
}
for wsi_name in self.wsis:
self.get_wsi_score(wsi_name)
for class_idx in range(self._num_classes):
if class_idx == self._ignore_index:
continue
dices[class_idx].append(self.wsis[wsi_name][class_idx]["dice"].item())
return {class_idx: sum(dices[class_idx]) / len(dices[class_idx]) for class_idx in dices.keys()}
return {
class_idx: sum(dices[class_idx]) / len(dices[class_idx])
for class_idx in dices.keys()
if class_idx != self._ignore_index
}

def _initialize_wsi_dict(self, wsi_name: str) -> None:
self.wsis[wsi_name] = {
class_idx: {"intersection": 0, "cardinality": 0, "dice": None} for class_idx in range(self._num_classes)
class_idx: {"intersection": 0, "cardinality": 0, "dice": None}
for class_idx in range(self._num_classes)
if class_idx != self._ignore_index
}

def get_average_score(
Expand All @@ -275,7 +306,11 @@ def get_average_score(
dices = self._get_overall_dice()
else:
dices = self._get_dice_averaged_over_total_wsis()
avg_dict = {f"{self.name}/{self._label_to_class[idx]}": value for idx, value in dices.items()}
avg_dict = {
f"{self.name}/{self._label_to_class[idx]}": value
for idx, value in dices.items()
if idx != self._ignore_index
}
return avg_dict

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions config/callbacks/write_file_callback.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ write_file_callback:
reader_class:
_target_: ahcore.readers.ZarrFileImageReader
_partial_: true
max_concurrent_tasks: 3
ignore_index: null
max_concurrent_queues: 3
2 changes: 2 additions & 0 deletions config/metrics/segmentation.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
tile_level:
_target_: ahcore.metrics.MetricFactory.for_segmentation
ignore_index: null

wsi_level:
_target_: ahcore.metrics.WSIMetricFactory.for_segmentation
compute_overall_dice: True
ignore_index: null