From a0d0c5b24097e076fb4bba9e5444d62dc45bd605 Mon Sep 17 00:00:00 2001 From: Josh Veitch-Michaelis Date: Tue, 13 Jan 2026 13:55:04 -0500 Subject: [PATCH] use torchmetrics for evaluation --- src/deepforest/main.py | 199 +++++++++++++++----------------------- src/deepforest/metrics.py | 150 ++++++++++++++++++++++++++++ tests/test_main.py | 66 ++++++++++--- 3 files changed, 280 insertions(+), 135 deletions(-) create mode 100644 src/deepforest/metrics.py diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 45de757ea..0ee0eb862 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -3,12 +3,10 @@ import os import warnings -import geopandas as gpd import numpy as np import pandas as pd import pytorch_lightning as pl import torch -from lightning_fabric.utilities.exceptions import MisconfigurationException from omegaconf import DictConfig, OmegaConf from PIL import Image from pytorch_lightning.callbacks import LearningRateMonitor @@ -19,6 +17,7 @@ from deepforest import evaluate as evaluate_iou from deepforest import predict, utilities from deepforest.datasets import prediction, training +from deepforest.metrics import RecallPrecision class deepforest(pl.LightningModule): @@ -70,24 +69,15 @@ def __init__( self.existing_train_dataloader = existing_train_dataloader self.existing_val_dataloader = existing_val_dataloader - # Metrics - self.iou_metric = IntersectionOverUnion( - class_metrics=True, iou_threshold=self.config.validation.iou_threshold - ) - self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval") - - # Empty frame accuracy - self.empty_frame_accuracy = BinaryAccuracy() - - # Create a default trainer. - self.create_trainer() - self.model = model self.original_batch_structure = [] if self.model is None: self.create_model() + # Create a default trainer. 
+ self.create_trainer() + # Add user supplied transforms if transforms is None: self.transforms = None @@ -98,6 +88,25 @@ def __init__( {"config": OmegaConf.to_container(self.config, resolve=True)} ) + def setup_metrics(self): + # Guard against initialization before a validation csv_file is set + if not self.config.validation.csv_file: + return + + # Metrics + self.iou_metric = IntersectionOverUnion( + class_metrics=True, iou_threshold=self.config.validation.iou_threshold + ) + self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval") + + # Empty frame accuracy + self.empty_frame_accuracy = BinaryAccuracy() + + self.precision_recall_metric = RecallPrecision( + csv_file=self.config.validation.csv_file, + label_dict=self.label_dict, + ) + def load_model(self, model_name=None, revision=None): """Loads a model that has already been pretrained for a specific task, like tree crown detection. @@ -190,6 +199,10 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs): callbacks: Optional list of callbacks **kwargs: Additional trainer arguments """ + + # Setup metrics which may have changed if the config was modified + self.setup_metrics() + if callbacks is None: callbacks = [] # If val data is passed, monitor learning rate and setup classification metrics @@ -704,15 +717,10 @@ def validation_step(self, batch, batch_idx): losses = sum(loss_dict.values()) # Log losses - try: - for key, value in loss_dict.items(): - self.log( - f"val_{key}", value.detach(), on_epoch=True, batch_size=len(images) - ) + for key, value in loss_dict.items(): + self.log(f"val_{key}", value.detach(), on_epoch=True, batch_size=len(images)) - self.log("val_loss", losses.detach(), on_epoch=True, batch_size=len(images)) - except MisconfigurationException: - pass + self.log("val_loss", losses.detach(), on_epoch=True, batch_size=len(images)) # In eval model, return predictions to calculate prediction metrics self.model.eval() @@ -723,13 +731,29 @@ def validation_step(self, batch, 
batch_idx): # Remove empty targets and corresponding predictions filtered_preds = [] filtered_targets = [] + for i, target in enumerate(targets): - if target["boxes"].shape[0] > 0: + # Empty frame accuracy + is_empty_frame = target["boxes"].numel() == 0 or torch.all( + target["boxes"] == 0 + ) + if is_empty_frame: + # 0 indicates empty frame or prediction + device = target["boxes"].device + self.empty_frame_accuracy.update( + torch.tensor([min(len(preds[i]["boxes"]), 1)], device=device), + torch.tensor([0.0], device=device), + ) + else: + # Non-empty frames go to all metrics filtered_preds.append(preds[i]) filtered_targets.append(target) + # IoU and mAP metrics need preds/targets to exist self.iou_metric.update(filtered_preds, filtered_targets) self.mAP_metric.update(filtered_preds, filtered_targets) + # Precision recall metric can handle empty frames internally + self.precision_recall_metric.update(preds, image_names) # Log the predictions if you want to use them for evaluation logs for i, result in enumerate(preds): @@ -799,49 +823,38 @@ def calculate_empty_frame_accuracy(self, ground_df, predictions_df): # Calculate accuracy using metric self.empty_frame_accuracy.update(predictions, gt) empty_accuracy = self.empty_frame_accuracy.compute() + self.empty_frame_accuracy.reset() # Log empty frame accuracy - try: - self.log("empty_frame_accuracy", empty_accuracy) - except MisconfigurationException: - pass + self.log("empty_frame_accuracy", empty_accuracy) return empty_accuracy - def log_epoch_metrics(self): + def _compute_epoch_metrics(self) -> dict: + """Compute metrics and return a Lightning-loggable dictionary. + + This function is called automatically at the end of validation. + """ + metrics = {} + + # IoU and mAP if len(self.iou_metric.groundtruth_labels) > 0: - output = self.iou_metric.compute() + metrics.update(self.iou_metric.compute()) # Lightning bug: claims this is a warning but it's not. 
See issue #16218 in Lightning-AI/pytorch-lightning - try: - self.log_dict(output) - except Exception: - pass - - self.iou_metric.reset() output = self.mAP_metric.compute() - # Keep only overall mAP; drop extra map_* and classes clutter - if isinstance(output, dict): - # Remove classes entry if present - if "classes" in output: - output.pop("classes", None) - # Reduce to only overall 'map' and map_50 if available - output = {k: v for k, v in output.items() if k in ["map", "map_50"]} - try: - self.log_dict(output) - except MisconfigurationException: - pass - self.mAP_metric.reset() - - # Log empty frame accuracy if it has been updated - if self.empty_frame_accuracy._update_called: - empty_accuracy = self.empty_frame_accuracy.compute() + # Remove classes from output dict + output = {key: value for key, value in output.items() if not key == "classes"} + metrics.update(output) + + # Box recall/precision + metrics.update(self.precision_recall_metric.compute()) + + # Empty frame accuracy + if self.empty_frame_accuracy.update_called: + metrics["empty_frame_accuracy"] = self.empty_frame_accuracy.compute() - # Log empty frame accuracy - try: - self.log("empty_frame_accuracy", empty_accuracy) - except MisconfigurationException: - pass + return metrics def on_validation_epoch_end(self): """Compute metrics and predictions at the end of the validation @@ -850,23 +863,16 @@ def on_validation_epoch_end(self): return # Log epoch metrics - self.log_epoch_metrics() - if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0: - if len(self.predictions) > 0: - predictions = pd.concat(self.predictions) - else: - predictions = pd.DataFrame() - - results = self.__evaluate__( - self.config.validation.csv_file, - root_dir=self.config.validation.root_dir, - predictions=predictions, - ) + metrics = self._compute_epoch_metrics() + self.log_dict(metrics) - self.__evaluation_logs__(results) - - return results + # Manual reset. 
Lightning does not do this automatically + # unless we log the metric objects directly + self.precision_recall_metric.reset() + self.iou_metric.reset() + self.mAP_metric.reset() + self.empty_frame_accuracy.reset() def predict_step(self, batch, batch_idx): """Predict a batch of images with the deepforest model. If batch is a @@ -1040,8 +1046,6 @@ def __evaluate__( empty_accuracy = self.calculate_empty_frame_accuracy(ground_df, predictions) results["empty_frame_accuracy"] = empty_accuracy - self.__evaluation_logs__(results) - return results def evaluate( @@ -1079,52 +1083,3 @@ def evaluate( root_dir=root_dir, predictions=predictions, ) - - def __evaluation_logs__(self, results): - """Log metrics from evaluation results.""" - # Log metrics - for key, value in results.items(): - if type(value) in [ - pd.DataFrame, - gpd.GeoDataFrame, - utilities.DeepForest_DataFrame, - ]: - pass - elif value is None: - pass - else: - try: - self.log(key, value) - except MisconfigurationException: - pass - - # Log each key value pair of the results dict - if results["class_recall"] is not None and self.config.num_classes > 1: - for key, value in results.items(): - if key in ["class_recall"]: - for _, row in value.iterrows(): - try: - self.log( - "{}_Recall".format( - self.numeric_to_label_dict[row["label"]] - ), - row["recall"], - ) - self.log( - "{}_Precision".format( - self.numeric_to_label_dict[row["label"]] - ), - row["precision"], - ) - except MisconfigurationException: - pass - elif key in ["predictions", "results", "ground_df"]: - # Don't log dataframes of predictions or IoU results per epoch - pass - elif value is None: - pass - else: - try: - self.log(key, value) - except MisconfigurationException: - pass diff --git a/src/deepforest/metrics.py b/src/deepforest/metrics.py new file mode 100644 index 000000000..b6553bf99 --- /dev/null +++ b/src/deepforest/metrics.py @@ -0,0 +1,150 @@ +import warnings + +import geopandas as gpd +import pandas as pd +import torch +from torch import 
Tensor +from torchmetrics import Metric + +from deepforest import utilities +from deepforest.evaluate import __evaluate_wrapper__ + + +class RecallPrecision(Metric): + """DeepForest box recall and precision metric. + + This class is a thin wrapper around evaluate_boxes to compute box + recall and precision during validation. In multi-GPU environments, + each rank runs evaluation on the full (gathered) dataset. + """ + + boxes: list[Tensor] + labels: list[Tensor] + scores: list[Tensor] + image_indices: list[Tensor] + + def __init__( + self, + csv_file: str, + label_dict: dict, + task="box", + iou_threshold: float = 0.4, + **kwargs, + ) -> None: + """This metric performs DeepForest's box recall and precision + evaluation. + + Args: + csv_file (str): Path to CSV file with ground truth boxes. + label_dict (dict): Dictionary mapping string labels to numeric labels. + iou_threshold (float, optional): IOU threshold for evaluation. Defaults to 0.4. + """ + super().__init__(**kwargs) + + self.csv_file = csv_file + self.iou_threshold = iou_threshold + self.label_dict = label_dict + + if task != "box": + raise NotImplementedError("Only 'box' task is currently supported.") + + # Create image path index mappings. This is necessary + # as we can't sync strings across multiple GPUs by default. + ground_df = utilities.read_file(csv_file) + unique_paths = sorted(ground_df["image_path"].unique()) + self.path_to_index = {path: idx for idx, path in enumerate(unique_paths)} + self.index_to_path = {idx: path for path, idx in self.path_to_index.items()} + + self.add_state("boxes", default=[], dist_reduce_fx=None) + self.add_state("labels", default=[], dist_reduce_fx=None) + self.add_state("scores", default=[], dist_reduce_fx=None) + self.add_state("image_indices", default=[], dist_reduce_fx=None) + + def update(self, preds: list[dict[str, Tensor]], image_names: list[str]) -> None: + """Update the metric with new predictions. 
+ + Args: + preds (list[dict[str, Tensor]]): List of prediction dictionaries from the model. + image_names (list): List of image names corresponding to the predictions. + """ + for pred, image_name in zip(preds, image_names, strict=True): + # Look up image index; skip if not in ground truth + if image_name not in self.path_to_index: + warnings.warn( + f"Image '{image_name}' not found in ground truth CSV. Skipping.", + stacklevel=2, + ) + continue + + image_idx = self.path_to_index[image_name] + num_boxes = len(pred["boxes"]) + + # Store predictions and image indices as tensors + self.boxes.append(pred["boxes"].detach()) + self.labels.append(pred["labels"].detach()) + self.scores.append(pred["scores"].detach()) + # Create a 1D tensor with one index per box + self.image_indices.append( + torch.full((num_boxes,), image_idx, dtype=torch.long).to( + pred["boxes"].device + ) + ) + + def compute(self) -> dict[str, float]: + """Computes the recall/precision metrics.""" + + ground_df = utilities.read_file(self.csv_file) + numeric_to_label_dict = {v: k for k, v in self.label_dict.items()} + ground_df["label"] = ground_df.label.apply(lambda x: self.label_dict[x]) + + predictions = pd.DataFrame() + if self.boxes: + combined = { + "boxes": torch.cat(self.boxes), + "labels": torch.cat(self.labels), + "scores": torch.cat(self.scores), + } + predictions = utilities.format_geometry(combined) + if predictions is None: + predictions = pd.DataFrame() # Reset to empty DataFrame + else: + # Expand image names to one entry per box + predictions["image_path"] = [ + self.index_to_path[int(idx.item())] + for idx in torch.cat(self.image_indices) + ] + + results = __evaluate_wrapper__( + predictions=predictions, + ground_df=ground_df, + iou_threshold=self.iou_threshold, + numeric_to_label_dict=numeric_to_label_dict, + ) + + filtered_results = {} + + # Extract per-class recall/precision for multi class prediction only. 
+ if len(self.label_dict) > 1: + if "class_recall" in results and results["class_recall"] is not None: + for _, row in results["class_recall"].iterrows(): + filtered_results[ + "{}_Recall".format(numeric_to_label_dict[row["label"]]) + ] = row["recall"] + filtered_results[ + "{}_Precision".format(numeric_to_label_dict[row["label"]]) + ] = row["precision"] + + # Filter out values that cannot be logged + for key, value in results.items(): + if isinstance(value, (pd.DataFrame, gpd.GeoDataFrame)): + pass + elif value is None: + pass + else: + filtered_results[key] = value + + return filtered_results + + def reset(self) -> None: + """Reset metric state.""" + super().reset() diff --git a/tests/test_main.py b/tests/test_main.py index 30db1d52f..cd4b46f46 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -29,6 +29,7 @@ # Just download once. from .conftest import download_release from unittest.mock import Mock +import unittest.mock as mock ALL_ARCHITECTURES = ["retinanet", "DeformableDetr"] @@ -212,7 +213,8 @@ def test_validation_step(m): batch = next(iter(val_dataloader)) m.predictions = [] m.targets = {} - val_loss = m.validation_step(batch, 0) + with mock.patch.object(m, 'log') as _: + val_loss = m.validation_step(batch, 0) assert val_loss != 0 def test_validation_step_empty(m_without_release): @@ -226,7 +228,8 @@ def test_validation_step_empty(m_without_release): batch = next(iter(val_dataloader)) m.predictions = [] m.targets = {} - val_predictions = m.validation_step(batch, 0) + with mock.patch.object(m, 'log') as _: + _ = m.validation_step(batch, 0) assert m.iou_metric.compute()["iou"] == 0 def test_validate(m): @@ -558,7 +561,6 @@ def on_train_end(self, trainer, pl_module): trainer = Trainer(fast_dev_run=True) trainer.fit(m, train_ds) - def test_over_score_thresh(m): """A user might want to change the config after model training and update the score thresh""" img = get_data("OSBS_029.png") @@ -573,13 +575,42 @@ def test_over_score_thresh(m): assert 
m.model.score_thresh == high_thresh assert not m.model.score_thresh == original_score_thresh +def test_logged_metrics(m, tmp_path): + """Test that all expected metrics are logged during training.""" -def test_iou_metric(m): - results = m.trainer.validate(m) - keys = ['val_classification', 'val_bbox_regression', 'iou', 'iou/cl_0'] - for x in keys: - assert x in list(results[0].keys()) + # Create an empty frame using an existing test image + ground_df = pd.read_csv(get_data("example.csv")) + empty_frame = ground_df.iloc[0:1].copy() + empty_frame.loc[:, ["xmin", "ymin", "xmax", "ymax"]] = 0 + empty_frame.loc[:, "image_path"] = "OSBS_029.png" + # Place empty frame first so it's processed in fast_dev_run mode + validation_df = pd.concat([empty_frame, ground_df]) + validation_df.to_csv(tmp_path / "validation.csv", index=False) + m.config.validation.csv_file = tmp_path / "validation.csv" + m.create_trainer() + m.trainer.fit(m) + + logged = m.trainer.logged_metrics + + # Torchmetrics + training metrics + metrics = [ + 'train_loss_step', + 'train_classification_step', # RetinaNet specific + 'train_bbox_regression_step', # RetinaNet specific + 'val_loss', + 'val_classification', + 'val_bbox_regression', + 'iou', + 'map', + 'map_50', + 'map_75', + 'box_precision', + 'box_recall', + 'empty_frame_accuracy' + ] + for metric in metrics: + assert metric in logged, f"Expected metric '{metric}' not found in logged metrics." 
def test_config_args(m): assert not m.config.num_classes == 2 @@ -921,13 +952,19 @@ def test_epoch_evaluation_end(m, tmp_path): boxes["label"] = "Tree" m.predictions = [predictions] boxes.to_csv(tmp_path / "predictions.csv", index=False) - m.config.validation.csv_file = str(tmp_path / "predictions.csv") + + m.config.validation.csv_file = tmp_path / "predictions.csv" m.config.validation.root_dir = str(tmp_path) + # Recreate metrics after changing validation csv_file + m.setup_metrics() + m.precision_recall_metric.update(preds, ["test"]) - results = m.on_validation_epoch_end() + with mock.patch.object(m, 'log_dict') as mock_log: + m.on_validation_epoch_end() + logged_metrics = mock_log.call_args[0][0] - assert results["box_precision"] == 1.0 - assert results["box_recall"] == 1.0 + assert logged_metrics["box_precision"] == 1.0 + assert logged_metrics["box_recall"] == 1.0 def test_epoch_evaluation_end_empty(m): """If the model returns an empty prediction, the metrics should not fail""" @@ -944,7 +981,9 @@ def test_epoch_evaluation_end_empty(m): boxes = format_geometry(preds[0]) boxes["image_path"] = "test" m.predictions = [boxes] - m.on_validation_epoch_end() + + with mock.patch.object(m, 'log_dict') as _: + m.on_validation_epoch_end() def test_empty_frame_accuracy_all_empty_with_predictions(m, tmp_path): """Test empty frame accuracy when all frames are empty but model predicts objects. @@ -960,6 +999,7 @@ def test_empty_frame_accuracy_all_empty_with_predictions(m, tmp_path): m.config.validation["csv_file"] = str(tmp_path / "ground_truth.csv") m.config.validation["root_dir"] = os.path.dirname(get_data("testfile_deepforest.csv")) + # Recreate metrics after changing validation csv_file m.create_trainer() results = m.trainer.validate(m)