Skip to content

Commit 33beb68

Browse files
committed
add exception if outcome is not binary when classifier is used
1 parent 2a9b516 commit 33beb68

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

doubleml/plm/plr.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(
9797
valid_scores = ["IV-type", "partialling out"]
9898
_check_score(self.score, valid_scores, allow_callable=True)
9999

100-
_ = self._check_learner(ml_l, "ml_l", regressor=True, classifier=False)
100+
ml_l_is_classifier = self._check_learner(ml_l, "ml_l", regressor=True, classifier=True)
101101
ml_m_is_classifier = self._check_learner(ml_m, "ml_m", regressor=True, classifier=True)
102102
self._learner = {"ml_l": ml_l, "ml_m": ml_m}
103103

@@ -117,7 +117,17 @@ def __init__(
117117
warnings.warn(("For score = 'IV-type', learners ml_l and ml_g should be specified. Set ml_g = clone(ml_l)."))
118118
self._learner["ml_g"] = clone(ml_l)
119119

120-
self._predict_method = {"ml_l": "predict"}
120+
if ml_l_is_classifier:
121+
if obj_dml_data.binary_outcome:
122+
# ml_l is the learner identified as a classifier in this branch, so its own entry
# (not ml_g's, which is set separately further down) must request predict_proba.
self._predict_method = {"ml_l": "predict_proba"}
123+
else:
124+
raise ValueError(
125+
f"The ml_l learner {str(ml_l)} was identified as classifier "
126+
"but the outcome variable is not binary with values 0 and 1."
127+
)
128+
else:
129+
self._predict_method = {"ml_l": "predict"}
130+
121131
if "ml_g" in self._learner:
122132
self._predict_method["ml_g"] = "predict"
123133
if ml_m_is_classifier:

doubleml/plm/tests/test_plr_binary_outcome.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,20 @@ def generate_binary_data():
5454
def test_dml_plr_binary_warnings(generate_binary_data, learner_binary, score):
5555
data = generate_binary_data
5656
obj_dml_data = dml.DoubleMLData(data, "y", ["d"])
57-
msg = "The ml_m learner {str(ml_m)} was identified as classifier ' \
58-
'but at least one treatment variable is not binary with values 0 and 1.'"
57+
# pytest's match= is a regular expression searched against the message; use a wildcard
# for the learner repr, since a "{str(ml_l)}" placeholder in a plain string is never interpolated.
msg = "The ml_l learner .+ was identified as classifier"
5958
with pytest.warns(UserWarning, match=msg):
6059
_ = dml.DoubleMLPLR(obj_dml_data, clone(learner_binary), clone(learner_binary), score=score)
6160

6261

62+
@pytest.mark.ci
63+
def test_dml_plr_binary_exceptions(generate_binary_data, learner_binary, score):
64+
data = generate_binary_data
65+
obj_dml_data = dml.DoubleMLData(data, "X1", ["d"])
66+
msg = "The ml_l learner .+ was identified as classifier but the outcome variable is not binary with values 0 and 1."
67+
with pytest.raises(ValueError, match=msg):
68+
_ = dml.DoubleMLPLR(obj_dml_data, clone(learner_binary), clone(learner_binary), score=score)
69+
70+
6371
@pytest.fixture(scope="module")
6472
def dml_plr_binary_fixture(generate_binary_data, learner_binary, score):
6573
boot_methods = ["normal"]

0 commit comments

Comments
 (0)