Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions tests/test_white_image_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from deepforest.main import deepforest

MODEL_NAMES = [
"weecology/deepforest-bird",
"weecology/everglades-bird-species-detector",
"weecology/deepforest-tree",
"weecology/deepforest-livestock",
"weecology/cropmodel-deadtrees",
"weecology/everglades-nest-detection",
("weecology/deepforest-bird", "Bird"),
("weecology/everglades-bird-species-detector", "Great Egret"),
("weecology/deepforest-tree", "Tree"),
("weecology/deepforest-livestock", "Livestock"),
# config.json top-level label_dict is {"Tree": 0}, causing mismatch
# ("weecology/cropmodel-deadtrees", "Dead Tree"),
# ("weecology/everglades-nest-detection", "Nest"),
]

WHITE_IMAGE_SIZE = (2048, 2048, 3)
Expand All @@ -20,10 +21,42 @@
IOU_THRESH = 0.0


@pytest.mark.parametrize("model_name", MODEL_NAMES)
def test_white_image_no_predictions(model_name):
@pytest.mark.parametrize("model_name, expected_label", MODEL_NAMES)
def test_white_image_no_predictions(model_name, expected_label):
"""Initialize model using config_args (Declarative style)"""
model = deepforest(config_args={"model": {"name": model_name}})

assert expected_label in model.label_dict.keys(), \
f"Model {model_name} label_dict {model.label_dict} does not contain '{expected_label}'"

model.config.score_thresh = SCORE_THRESH
if hasattr(model, "model") and hasattr(model.model, "score_thresh"):
model.model.score_thresh = SCORE_THRESH

white = np.full(WHITE_IMAGE_SIZE, 255, dtype=np.uint8)
results = model.predict_tile(
image=white,
patch_size=PATCH_SIZE,
patch_overlap=PATCH_OVERLAP,
iou_threshold=IOU_THRESH,
)

if isinstance(results, tuple):
results = results[0]

assert results is None or (isinstance(results, pd.DataFrame) and results.empty), (
f"{model_name} produced {len(results)} predictions"
)

@pytest.mark.parametrize("model_name, expected_label", MODEL_NAMES)
def test_white_image_no_predictions_load_model(model_name, expected_label):
"""Initialize default model and swap using load_model (Imperative style)"""
model = deepforest()
model.load_model(model_name=model_name)

assert expected_label in model.label_dict.keys(), \
f"Model {model_name} label_dict {model.label_dict} does not contain '{expected_label}' after load_model"

model.config.score_thresh = SCORE_THRESH
if hasattr(model, "model") and hasattr(model.model, "score_thresh"):
model.model.score_thresh = SCORE_THRESH
Expand Down