Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import tempfile
import copy
import importlib.util
import math

from deepforest import main, get_data, model
from deepforest.utilities import read_file, format_geometry
from deepforest.datasets import prediction
from deepforest.visualize import plot_results
from deepforest.metrics import RecallPrecision

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
Expand Down Expand Up @@ -1167,10 +1169,47 @@ def test_predict_file_mixed_sizes(m, tmp_path):
preds = m.predict_file(csv_file=csv_path, root_dir=str(tmp_path))

assert preds.ymax.max() > 200 # The larger image should have predictions outside the 200px limit

def test_recall_not_lowered_by_unprocessed_images():
    """Recall must be computed only over images handed to the metric.

    Unprocessed images left in the ground-truth dataframe (e.g. those
    skipped by ``limit_val_batches``) must not drag the score down.
    """
    metric = RecallPrecision(label_dict={'Tree': 0})

    # Two processed images, each with a distinct box, so that any
    # cross-image matching bug would surface as a mismatch.
    box_a = torch.tensor([[10, 10, 50, 50]], dtype=torch.float32)
    box_b = torch.tensor([[60, 60, 100, 100]], dtype=torch.float32)

    predictions = [
        {'boxes': box_a, 'labels': torch.tensor([0]),
         'scores': torch.tensor([0.9])},
        {'boxes': box_b, 'labels': torch.tensor([0]),
         'scores': torch.tensor([0.85])},
    ]
    ground_truth = [
        {'boxes': box_a.clone(), 'labels': torch.tensor([0])},
        {'boxes': box_b.clone(), 'labels': torch.tensor([0])},
    ]

    metric.update(predictions, ground_truth, ['img1.jpg', 'img2.jpg'])
    results = metric.compute()

    # Only the two images actually passed to update() should be counted.
    assert metric.num_images == 2

    # Both images match their ground truth perfectly, so recall is 1.0.
    assert math.isclose(results['box_recall'], 1.0, rel_tol=1e-5), (
        f"box_recall={results['box_recall']:.2f}, expected 1.0"
    )

def test_custom_log_root(m, tmpdir):
"""Test that setting a custom log_root creates logs in the expected location"""
custom_log_dir = tmpdir.join("custom_logs")
m.config.log_root = str(custom_log_dir)

m.config.train.fast_dev_run = False

m.create_trainer(limit_train_batches=1, limit_val_batches=1, max_epochs=1)
Expand Down
Loading