diff --git a/tests/test_white_image_predictions.py b/tests/test_white_image_predictions.py index 7fda53876..66f5153b6 100644 --- a/tests/test_white_image_predictions.py +++ b/tests/test_white_image_predictions.py @@ -5,12 +5,13 @@ from deepforest.main import deepforest MODEL_NAMES = [ - "weecology/deepforest-bird", - "weecology/everglades-bird-species-detector", - "weecology/deepforest-tree", - "weecology/deepforest-livestock", - "weecology/cropmodel-deadtrees", - "weecology/everglades-nest-detection", + ("weecology/deepforest-bird", "Bird"), + ("weecology/everglades-bird-species-detector", "Great Egret"), + ("weecology/deepforest-tree", "Tree"), + ("weecology/deepforest-livestock", "Livestock"), + # config.json top-level label_dict is {"Tree": 0}, causing mismatch + # ("weecology/cropmodel-deadtrees", "Dead Tree"), + # ("weecology/everglades-nest-detection", "Nest"), ] WHITE_IMAGE_SIZE = (2048, 2048, 3) @@ -20,10 +21,42 @@ IOU_THRESH = 0.0 -@pytest.mark.parametrize("model_name", MODEL_NAMES) -def test_white_image_no_predictions(model_name): +@pytest.mark.parametrize("model_name, expected_label", MODEL_NAMES) +def test_white_image_no_predictions(model_name, expected_label): + """Initialize model using config_args (Declarative style)""" + model = deepforest(config_args={"model": {"name": model_name}}) + + assert expected_label in model.label_dict.keys(), \ + f"Model {model_name} label_dict {model.label_dict} does not contain '{expected_label}'" + + model.config.score_thresh = SCORE_THRESH + if hasattr(model, "model") and hasattr(model.model, "score_thresh"): + model.model.score_thresh = SCORE_THRESH + + white = np.full(WHITE_IMAGE_SIZE, 255, dtype=np.uint8) + results = model.predict_tile( + image=white, + patch_size=PATCH_SIZE, + patch_overlap=PATCH_OVERLAP, + iou_threshold=IOU_THRESH, + ) + + if isinstance(results, tuple): + results = results[0] + + assert results is None or (isinstance(results, pd.DataFrame) and results.empty), ( + f"{model_name} produced {len(results)} predictions" + ) + +@pytest.mark.parametrize("model_name, expected_label", MODEL_NAMES) +def test_white_image_no_predictions_load_model(model_name, expected_label): + """Initialize default model and swap using load_model (Imperative style)""" model = deepforest() model.load_model(model_name=model_name) + + assert expected_label in model.label_dict.keys(), \ + f"Model {model_name} label_dict {model.label_dict} does not contain '{expected_label}' after load_model" + model.config.score_thresh = SCORE_THRESH if hasattr(model, "model") and hasattr(model.model, "score_thresh"): model.model.score_thresh = SCORE_THRESH