From 565ced17314ce5bec6e14b629b203202f0d57615 Mon Sep 17 00:00:00 2001 From: vicky sharma Date: Wed, 18 Feb 2026 16:25:02 +0530 Subject: [PATCH] Add detections_per_img and topk_candidates config options for RetinaNet --- src/deepforest/conf/config.yaml | 3 ++- src/deepforest/conf/schema.py | 2 ++ src/deepforest/models/retinanet.py | 4 ++++ tests/test_main.py | 22 ++++++++++++++++++++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index 8655f9281..92516dbe8 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -11,7 +11,8 @@ batch_size: 1 architecture: 'retinanet' nms_thresh: 0.05 score_thresh: 0.1 - +detections_per_img: 300 +topk_candidates: 1000 # Set model name to None to initialize from scratch model: name: 'weecology/deepforest-tree' diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index d0aa84806..eebdd03dd 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -138,6 +138,8 @@ class Config: nms_thresh: float = 0.05 score_thresh: float = 0.1 + detections_per_img: int = 300 + topk_candidates: int = 1000 model: ModelConfig = field(default_factory=ModelConfig) log_root: str = "./" diff --git a/src/deepforest/models/retinanet.py b/src/deepforest/models/retinanet.py index cfb114318..554c24780 100644 --- a/src/deepforest/models/retinanet.py +++ b/src/deepforest/models/retinanet.py @@ -199,6 +199,8 @@ def create_model( nms_thresh=self.config.nms_thresh, score_thresh=self.config.score_thresh, label_dict=label_dict, + detections_per_img=self.config.detections_per_img, + topk_candidates=self.config.topk_candidates, ) else: # Pre 2.0 compatibility, the score_threshold used to be stored under retinanet.score_thresh @@ -214,6 +216,8 @@ def create_model( label_dict=label_dict, nms_thresh=self.config.nms_thresh, score_thresh=self.config.score_thresh, + detections_per_img=self.config.detections_per_img, + topk_candidates=self.config.topk_candidates, **hf_args, ) diff --git a/tests/test_main.py b/tests/test_main.py index 5450dfce7..131815dd4 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1173,6 +1173,7 @@ def test_custom_log_root(m, tmpdir): version_dir = version_dirs[0] assert version_dir.join("hparams.yaml").exists(), "hparams.yaml not found" + def test_huggingface_model_loads_correct_label_dict(): """Regression test for #1286: HuggingFace models should load correct label_dict from config.json. @@ -1194,3 +1195,24 @@ def test_huggingface_model_loads_correct_label_dict(): actual = set(m.label_dict.keys()) assert actual == expected, f"Expected {expected}, got {actual}" + +def test_detections_per_img_and_topk_candidates_config(): + """Test that detections_per_img and topk_candidates can be configured + and are passed through to the underlying model.""" + m = main.deepforest() + + # Check default values + assert m.config.detections_per_img == 300 + assert m.config.topk_candidates == 1000 + + # Test custom values + m.config.detections_per_img = 500 + m.config.topk_candidates = 2000 + + assert m.config.detections_per_img == 500 + assert m.config.topk_candidates == 2000 + + # Verify values are passed to actual model + m.create_model() + assert m.model.detections_per_img == 500 + assert m.model.topk_candidates == 2000