Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/deepforest/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ batch_size: 1
architecture: 'retinanet'
nms_thresh: 0.05
score_thresh: 0.1

detections_per_img: 300
topk_candidates: 1000
# Set model name to None to initialize from scratch
model:
name: 'weecology/deepforest-tree'
Expand Down
2 changes: 2 additions & 0 deletions src/deepforest/conf/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class Config:

nms_thresh: float = 0.05
score_thresh: float = 0.1
detections_per_img: int = 300
topk_candidates: int = 1000
model: ModelConfig = field(default_factory=ModelConfig)

log_root: str = "./"
Expand Down
4 changes: 4 additions & 0 deletions src/deepforest/models/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ def create_model(
nms_thresh=self.config.nms_thresh,
score_thresh=self.config.score_thresh,
label_dict=label_dict,
detections_per_img=self.config.detections_per_img,
topk_candidates=self.config.topk_candidates,
)
else:
# Pre 2.0 compatibility, the score_threshold used to be stored under retinanet.score_thresh
Expand All @@ -214,6 +216,8 @@ def create_model(
label_dict=label_dict,
nms_thresh=self.config.nms_thresh,
score_thresh=self.config.score_thresh,
detections_per_img=self.config.detections_per_img,
topk_candidates=self.config.topk_candidates,
**hf_args,
)

Expand Down
22 changes: 22 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,7 @@ def test_custom_log_root(m, tmpdir):
version_dir = version_dirs[0]
assert version_dir.join("hparams.yaml").exists(), "hparams.yaml not found"


def test_huggingface_model_loads_correct_label_dict():
"""Regression test for #1286:
HuggingFace models should load correct label_dict from config.json.
Expand All @@ -1194,3 +1195,24 @@ def test_huggingface_model_loads_correct_label_dict():

actual = set(m.label_dict.keys())
assert actual == expected, f"Expected {expected}, got {actual}"

def test_detections_per_img_and_topk_candidates_config():
"""Test that detections_per_img and topk_candidates can be configured
and are passed through to the underlying model."""
m = main.deepforest()

# Check default values
assert m.config.detections_per_img == 300
assert m.config.topk_candidates == 1000

# Test custom values
m.config.detections_per_img = 500
m.config.topk_candidates = 2000

assert m.config.detections_per_img == 500
assert m.config.topk_candidates == 2000

# Verify values are passed to actual model
m.create_model()
assert m.model.detections_per_img == 500
assert m.model.topk_candidates == 2000