From b6123683fd5fc5ed09ecdebb3b290aec13b89c55 Mon Sep 17 00:00:00 2001 From: Ajey Pai K Date: Wed, 3 Jul 2024 16:23:31 +0200 Subject: [PATCH] Implement `ignore_index` in metrics --- .../converters/wsi_metric_callback.py | 12 ++++- ahcore/metrics/metrics.py | 49 ++++++++++++++++--- config/callbacks/write_file_callback.yaml | 2 + config/metrics/segmentation.yaml | 2 + 4 files changed, 57 insertions(+), 8 deletions(-) diff --git a/ahcore/callbacks/converters/wsi_metric_callback.py b/ahcore/callbacks/converters/wsi_metric_callback.py index 3001f7c..bb5918a 100644 --- a/ahcore/callbacks/converters/wsi_metric_callback.py +++ b/ahcore/callbacks/converters/wsi_metric_callback.py @@ -26,7 +26,11 @@ class ComputeWsiMetricsCallback(ConvertCallbacks): def __init__( - self, reader_class: Type[FileImageReader], max_concurrent_tasks: int = 1, save_per_image: bool = True + self, + reader_class: Type[FileImageReader], + ignore_index: Optional[int], + max_concurrent_tasks: int = 1, + save_per_image: bool = True, ) -> None: """ Callback to compute metrics on whole-slide images. This callback is used to compute metrics on whole-slide @@ -36,6 +40,8 @@ def __init__( ---------- reader_class : FileImageReader The reader class to use to read the images, e.g., H5FileImageReader or ZarrFileImageReader. + ignore_index : Optional[int] + Class index to leave out of the reported per-class metrics; ``None`` keeps all classes. max_concurrent_tasks : int The maximum number of concurrent processes. 
save_per_image : bool @@ -48,6 +54,7 @@ def __init__( self._dump_dir: Path self._save_per_image: bool = save_per_image self._filenames: dict[Path, Path] = {} + self._ignore_index: Optional[int] = ignore_index self._wsi_metrics: WSIMetricFactory self._class_names: dict[int, str] = {} @@ -106,6 +113,7 @@ def process_task(self, filename: Path, cache_filename: Path) -> dict[str, str | class_names=self._class_names, data_description=self._data_description, wsi_metrics=self._wsi_metrics, + ignore_index=self._ignore_index, ) if self._save_per_image: @@ -155,6 +163,7 @@ def compute_metrics_for_case( class_names: dict[int, str], data_description: DataDescription, wsi_metrics: WSIMetricFactory, + ignore_index: Optional[int], ) -> dict[str, Any]: with image_reader(task_data.cache_filename, stitching_mode=StitchingMode.CROP) as cache_reader: dataset_of_validation_image = _ValidationDataset( @@ -182,6 +191,7 @@ def compute_metrics_for_case( wsi_metrics_dictionary[metric.name] = { class_names[class_idx]: metric.wsis[str(task_data.filename)][class_idx][metric.name].item() for class_idx in range(data_description.num_classes) + if ignore_index is None or class_idx != ignore_index } return wsi_metrics_dictionary diff --git a/ahcore/metrics/metrics.py b/ahcore/metrics/metrics.py index cb77d18..526007d 100644 --- a/ahcore/metrics/metrics.py +++ b/ahcore/metrics/metrics.py @@ -29,7 +29,7 @@ def __call__( class DiceMetric(TileMetric): - def __init__(self, data_description: DataDescription) -> None: + def __init__(self, data_description: DataDescription, ignore_index: int) -> None: r""" Metric computing dice over classes. The classes are derived from the index_map that's defined in the data_description. @@ -50,6 +50,9 @@ def __init__(self, data_description: DataDescription) -> None: Parameters ---------- data_description : DataDescription + The data description object. + ignore_index : int + The index to ignore in the computation of the dice score. 
""" super().__init__(data_description=data_description) self._num_classes = self._data_description.num_classes @@ -64,6 +67,7 @@ def __init__(self, data_description: DataDescription) -> None: _label_to_class = {v: k for k, v in _index_map.items()} _label_to_class[0] = "background" self._label_to_class = _label_to_class + self._ignore_index = ignore_index self.name = "dice" @@ -77,7 +81,11 @@ def __call__( dice_score = _compute_dice(intersection, cardinality) dices.append(dice_score) - output = {f"{self.name}/{self._label_to_class[idx]}": dices[idx] for idx in range(0, self._num_classes)} + output = { + f"{self.name}/{self._label_to_class[idx]}": dices[idx] + for idx in range(0, self._num_classes) + if idx != self._ignore_index + } return output def __repr__(self) -> str: @@ -162,9 +170,12 @@ def reset(self) -> None: class WSIDiceMetric(WSIMetric): """WSI Dice metric class, computes the dice score over the whole WSI""" - def __init__(self, data_description: DataDescription, compute_overall_dice: bool = False) -> None: + def __init__( + self, data_description: DataDescription, ignore_index: int, compute_overall_dice: bool = False + ) -> None: super().__init__(data_description=data_description) self.compute_overall_dice = compute_overall_dice + self._ignore_index = ignore_index self._num_classes = self._data_description.num_classes # self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._device = "cpu" @@ -200,11 +211,15 @@ def process_batch( self._num_classes, ) for class_idx, (intersection, cardinality) in enumerate(dice_components): + if class_idx == self._ignore_index: + continue self.wsis[wsi_name][class_idx]["intersection"] += intersection self.wsis[wsi_name][class_idx]["cardinality"] += cardinality def get_wsi_score(self, wsi_name: str) -> None: for class_idx in self.wsis[wsi_name]: + if class_idx == self._ignore_index: + continue intersection = self.wsis[wsi_name][class_idx]["intersection"] cardinality = 
self.wsis[wsi_name][class_idx]["cardinality"] self.wsis[wsi_name][class_idx]["wsi_dice"] = _compute_dice(intersection, cardinality) @@ -225,18 +240,24 @@ def _get_overall_dice(self) -> dict[int, float]: "overall_dice": 0.0, } for class_idx in range(self._num_classes) + if class_idx != self._ignore_index } for wsi_name in self.wsis: for class_idx in range(self._num_classes): + if class_idx == self._ignore_index: + continue overall_dices[class_idx]["total_intersection"] += self.wsis[wsi_name][class_idx]["intersection"] overall_dices[class_idx]["total_cardinality"] += self.wsis[wsi_name][class_idx]["cardinality"] for class_idx in overall_dices.keys(): + if class_idx == self._ignore_index: + continue intersection = overall_dices[class_idx]["total_intersection"] cardinality = overall_dices[class_idx]["total_cardinality"] overall_dices[class_idx]["overall_dice"] = (2 * intersection + 0.01) / (cardinality + 0.01) return { class_idx: torch.tensor(overall_dices[class_idx]["overall_dice"]).item() for class_idx in overall_dices.keys() + if class_idx != self._ignore_index } def _get_dice_averaged_over_total_wsis(self) -> dict[int, float]: @@ -248,16 +269,26 @@ def _get_dice_averaged_over_total_wsis(self) -> dict[int, float]: dict Dictionary with the dice scores averaged over all the WSIs per class """ - dices: dict[int, list[float]] = {class_idx: [] for class_idx in range(self._num_classes)} + dices: dict[int, list[float]] = { + class_idx: [] for class_idx in range(self._num_classes) if class_idx != self._ignore_index + } for wsi_name in self.wsis: self.get_wsi_score(wsi_name) for class_idx in range(self._num_classes): + if class_idx == self._ignore_index: + continue dices[class_idx].append(self.wsis[wsi_name][class_idx]["dice"].item()) - return {class_idx: sum(dices[class_idx]) / len(dices[class_idx]) for class_idx in dices.keys()} + return { + class_idx: sum(dices[class_idx]) / len(dices[class_idx]) + for class_idx in dices.keys() + if class_idx != self._ignore_index + } def 
_initialize_wsi_dict(self, wsi_name: str) -> None: self.wsis[wsi_name] = { - class_idx: {"intersection": 0, "cardinality": 0, "dice": None} for class_idx in range(self._num_classes) + class_idx: {"intersection": 0, "cardinality": 0, "dice": None} + for class_idx in range(self._num_classes) + if class_idx != self._ignore_index } def get_average_score( @@ -275,7 +306,11 @@ def get_average_score( dices = self._get_overall_dice() else: dices = self._get_dice_averaged_over_total_wsis() - avg_dict = {f"{self.name}/{self._label_to_class[idx]}": value for idx, value in dices.items()} + avg_dict = { + f"{self.name}/{self._label_to_class[idx]}": value + for idx, value in dices.items() + if idx != self._ignore_index + } return avg_dict @staticmethod diff --git a/config/callbacks/write_file_callback.yaml b/config/callbacks/write_file_callback.yaml index 3456749..4316532 100644 --- a/config/callbacks/write_file_callback.yaml +++ b/config/callbacks/write_file_callback.yaml @@ -18,4 +18,6 @@ write_file_callback: reader_class: _target_: ahcore.readers.ZarrFileImageReader _partial_: true + max_concurrent_tasks: 3 + ignore_index: null max_concurrent_queues: 3 diff --git a/config/metrics/segmentation.yaml b/config/metrics/segmentation.yaml index cac63e8..6bb097b 100644 --- a/config/metrics/segmentation.yaml +++ b/config/metrics/segmentation.yaml @@ -1,6 +1,8 @@ tile_level: _target_: ahcore.metrics.MetricFactory.for_segmentation + ignore_index: null wsi_level: _target_: ahcore.metrics.WSIMetricFactory.for_segmentation compute_overall_dice: True + ignore_index: null