diff --git a/tests/test_main.py b/tests/test_main.py index 9edf61c3c..562969525 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -10,11 +10,13 @@ import tempfile import copy import importlib.util +import math from deepforest import main, get_data, model from deepforest.utilities import read_file, format_geometry from deepforest.datasets import prediction from deepforest.visualize import plot_results +from deepforest.metrics import RecallPrecision from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback @@ -1167,10 +1169,47 @@ def test_predict_file_mixed_sizes(m, tmp_path): preds = m.predict_file(csv_file=csv_path, root_dir=str(tmp_path)) assert preds.ymax.max() > 200 # The larger image should have predictions outside the 200px limit + +def test_recall_not_lowered_by_unprocessed_images(): + """This test checks that recall is only computed for images that were + passed to the metric and ignores unprocessed images in the ground truth + dataframe.""" + + label_dict = {'Tree': 0} + metric = RecallPrecision(label_dict=label_dict) + + # Simulate limit_val_batches: only 2 images processed + # Using different boxes for each image to catch matching bugs + preds = [ + {'boxes': torch.tensor([[10, 10, 50, 50]], dtype=torch.float32), + 'labels': torch.tensor([0]), 'scores': torch.tensor([0.9])}, + {'boxes': torch.tensor([[60, 60, 100, 100]], dtype=torch.float32), + 'labels': torch.tensor([0]), 'scores': torch.tensor([0.85])} + ] + + targets = [ + {'boxes': torch.tensor([[10, 10, 50, 50]], dtype=torch.float32), + 'labels': torch.tensor([0])}, + {'boxes': torch.tensor([[60, 60, 100, 100]], dtype=torch.float32), + 'labels': torch.tensor([0])} + ] + + metric.update(preds, targets, ['img1.jpg', 'img2.jpg']) + results = metric.compute() + + # Verify only 2 images were processed (not affected by any external ground truth) + assert metric.num_images == 2 + + # With perfect matches, recall should be 1.0 (2/2 images matched) + assert math.isclose(results['box_recall'], 1.0, rel_tol=1e-5), ( + f"box_recall={results['box_recall']:.2f}, expected 1.0" + ) + def test_custom_log_root(m, tmpdir): """Test that setting a custom log_root creates logs in the expected location""" custom_log_dir = tmpdir.join("custom_logs") m.config.log_root = str(custom_log_dir) + m.config.train.fast_dev_run = False m.create_trainer(limit_train_batches=1, limit_val_batches=1, max_epochs=1)