Skip to content

Commit a37e8ea

Browse files
authored
Merge pull request #360 from DoubleML/s-add-binary-outcome-plr
Add option to for binary outcomes in PLR
2 parents 815f651 + a5bef26 commit a37e8ea

File tree

4 files changed

+264
-10
lines changed

4 files changed

+264
-10
lines changed

doubleml/plm/plr.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ def __init__(
9696
self._is_cluster_data = self._dml_data.is_cluster_data
9797
valid_scores = ["IV-type", "partialling out"]
9898
_check_score(self.score, valid_scores, allow_callable=True)
99+
if self.score == "IV-type" and obj_dml_data.binary_outcome:
100+
raise ValueError("For score = 'IV-type', additive probability models (binary outcomes) are not supported.")
99101

100-
_ = self._check_learner(ml_l, "ml_l", regressor=True, classifier=False)
102+
ml_l_is_classifier = self._check_learner(ml_l, "ml_l", regressor=True, classifier=True)
101103
ml_m_is_classifier = self._check_learner(ml_m, "ml_m", regressor=True, classifier=True)
102104
self._learner = {"ml_l": ml_l, "ml_m": ml_m}
103105

@@ -117,7 +119,20 @@ def __init__(
117119
warnings.warn(("For score = 'IV-type', learners ml_l and ml_g should be specified. Set ml_g = clone(ml_l)."))
118120
self._learner["ml_g"] = clone(ml_l)
119121

120-
self._predict_method = {"ml_l": "predict"}
122+
if ml_l_is_classifier:
123+
if obj_dml_data.binary_outcome:
124+
self._predict_method = {"ml_l": "predict_proba"}
125+
warnings.warn(
126+
f"The ml_l learner {str(ml_l)} was identified as classifier. Fitting an additive probability model."
127+
)
128+
else:
129+
raise ValueError(
130+
f"The ml_l learner {str(ml_l)} was identified as classifier "
131+
"but the outcome variable is not binary with values 0 and 1."
132+
)
133+
else:
134+
self._predict_method = {"ml_l": "predict"}
135+
121136
if "ml_g" in self._learner:
122137
self._predict_method["ml_g"] = "predict"
123138
if ml_m_is_classifier:

doubleml/plm/tests/_utils_plr_manual.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ def fit_nuisance_plr_classifier(
153153
y, x, d, learner_l, learner_m, learner_g, smpls, fit_g=True, l_params=None, m_params=None, g_params=None
154154
):
155155
ml_l = clone(learner_l)
156-
l_hat = fit_predict(y, x, ml_l, l_params, smpls)
156+
if is_classifier(learner_l):
157+
l_hat = fit_predict_proba(y, x, ml_l, l_params, smpls)
158+
else:
159+
l_hat = fit_predict(y, x, ml_l, l_params, smpls)
157160

158161
ml_m = clone(learner_m)
159162
m_hat = fit_predict_proba(d, x, ml_m, m_params, smpls)
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
import math
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
from sklearn.base import clone
7+
from sklearn.ensemble import RandomForestClassifier
8+
from sklearn.linear_model import LogisticRegression
9+
10+
import doubleml as dml
11+
12+
from ...tests._utils import draw_smpls
13+
from ._utils_plr_manual import boot_plr, fit_plr, fit_sensitivity_elements_plr
14+
15+
16+
@pytest.fixture(
17+
scope="module", params=[RandomForestClassifier(max_depth=2, n_estimators=10), LogisticRegression(max_iter=1000)]
18+
)
19+
def learner_binary(request):
20+
return request.param
21+
22+
23+
@pytest.fixture(scope="module", params=["partialling out"])
24+
def score(request):
25+
return request.param
26+
27+
28+
@pytest.fixture(scope="module")
29+
def generate_binary_data():
30+
"""Generate synthetic data with binary outcome"""
31+
np.random.seed(42)
32+
n = 500
33+
p = 5
34+
35+
# Generate covariates
36+
X = np.random.normal(0, 1, size=(n, p))
37+
38+
# Generate treatment
39+
d_prob = 1 / (1 + np.exp(-(X[:, 0] + X[:, 1] + np.random.normal(0, 1, n))))
40+
d = np.random.binomial(1, d_prob)
41+
42+
# Generate binary outcome with treatment effect
43+
theta_true = 0.5 # true treatment effect
44+
y_prob = 1 / (1 + np.exp(-(X[:, 0] + X[:, 2] + theta_true * d + np.random.normal(0, 0.5, n))))
45+
y = np.random.binomial(1, y_prob)
46+
47+
# Combine into DataFrame
48+
data = pd.DataFrame({"y": y, "d": d, **{f"X{i+1}": X[:, i] for i in range(p)}})
49+
50+
return data
51+
52+
53+
@pytest.mark.ci
54+
def test_dml_plr_binary_warnings(generate_binary_data, learner_binary, score):
55+
data = generate_binary_data
56+
obj_dml_data = dml.DoubleMLData(data, "y", ["d"])
57+
msg = "The ml_l learner .+ was identified as classifier. Fitting an additive probability model."
58+
with pytest.warns(UserWarning, match=msg):
59+
_ = dml.DoubleMLPLR(obj_dml_data, clone(learner_binary), clone(learner_binary), score=score)
60+
61+
62+
@pytest.mark.ci
63+
def test_dml_plr_binary_exceptions(generate_binary_data, learner_binary, score):
64+
data = generate_binary_data
65+
obj_dml_data = dml.DoubleMLData(data, "X1", ["d"])
66+
msg = "The ml_l learner .+ was identified as classifier but the outcome variable is not binary with values 0 and 1."
67+
with pytest.raises(ValueError, match=msg):
68+
_ = dml.DoubleMLPLR(obj_dml_data, clone(learner_binary), clone(learner_binary), score=score)
69+
70+
# IV-type not possible with binary outcome
71+
obj_dml_data = dml.DoubleMLData(data, "y", ["d"])
72+
msg = r"For score = 'IV-type', additive probability models \(binary outcomes\) are not supported."
73+
with pytest.raises(ValueError, match=msg):
74+
_ = dml.DoubleMLPLR(obj_dml_data, clone(learner_binary), clone(learner_binary), score="IV-type")
75+
76+
77+
@pytest.fixture(scope="module")
78+
def dml_plr_binary_fixture(generate_binary_data, learner_binary, score):
79+
boot_methods = ["normal"]
80+
n_folds = 2
81+
n_rep_boot = 502
82+
83+
# collect data
84+
data = generate_binary_data
85+
x_cols = data.columns[data.columns.str.startswith("X")].tolist()
86+
87+
# Set machine learning methods for m & g
88+
ml_l = clone(learner_binary)
89+
ml_m = clone(learner_binary)
90+
91+
np.random.seed(3141)
92+
obj_dml_data = dml.DoubleMLData(data, "y", ["d"], x_cols)
93+
dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_l, ml_m, n_folds=n_folds, score=score)
94+
dml_plr_obj.fit()
95+
96+
np.random.seed(3141)
97+
y = data["y"].values
98+
x = data.loc[:, x_cols].values
99+
d = data["d"].values
100+
n_obs = len(y)
101+
all_smpls = draw_smpls(n_obs, n_folds)
102+
103+
res_manual = fit_plr(y, x, d, clone(learner_binary), clone(learner_binary), clone(learner_binary), all_smpls, score)
104+
105+
np.random.seed(3141)
106+
# test with external nuisance predictions
107+
dml_plr_obj_ext = dml.DoubleMLPLR(obj_dml_data, ml_l, ml_m, n_folds, score=score)
108+
109+
# synchronize the sample splitting
110+
dml_plr_obj_ext.set_sample_splitting(all_smpls=all_smpls)
111+
prediction_dict = {
112+
"d": {
113+
"ml_l": dml_plr_obj.predictions["ml_l"].reshape(-1, 1),
114+
"ml_m": dml_plr_obj.predictions["ml_m"].reshape(-1, 1),
115+
}
116+
}
117+
dml_plr_obj_ext.fit(external_predictions=prediction_dict)
118+
119+
res_dict = {
120+
"coef": dml_plr_obj.coef.item(),
121+
"coef_manual": res_manual["theta"],
122+
"coef_ext": dml_plr_obj_ext.coef.item(),
123+
"se": dml_plr_obj.se.item(),
124+
"se_manual": res_manual["se"],
125+
"se_ext": dml_plr_obj_ext.se.item(),
126+
"boot_methods": boot_methods,
127+
}
128+
129+
for bootstrap in boot_methods:
130+
np.random.seed(3141)
131+
boot_t_stat = boot_plr(
132+
y,
133+
d,
134+
res_manual["thetas"],
135+
res_manual["ses"],
136+
res_manual["all_l_hat"],
137+
res_manual["all_m_hat"],
138+
res_manual["all_g_hat"],
139+
all_smpls,
140+
score,
141+
bootstrap,
142+
n_rep_boot,
143+
)
144+
145+
np.random.seed(3141)
146+
dml_plr_obj.bootstrap(method=bootstrap, n_rep_boot=n_rep_boot)
147+
np.random.seed(3141)
148+
dml_plr_obj_ext.bootstrap(method=bootstrap, n_rep_boot=n_rep_boot)
149+
res_dict["boot_t_stat" + bootstrap] = dml_plr_obj.boot_t_stat
150+
res_dict["boot_t_stat" + bootstrap + "_manual"] = boot_t_stat.reshape(-1, 1, 1)
151+
res_dict["boot_t_stat" + bootstrap + "_ext"] = dml_plr_obj_ext.boot_t_stat
152+
153+
# sensitivity tests
154+
res_dict["sensitivity_elements"] = dml_plr_obj.sensitivity_elements
155+
res_dict["sensitivity_elements_manual"] = fit_sensitivity_elements_plr(
156+
y, d.reshape(-1, 1), all_coef=dml_plr_obj.all_coef, predictions=dml_plr_obj.predictions, score=score, n_rep=1
157+
)
158+
# check if sensitivity score with rho=0 gives equal asymptotic standard deviation
159+
dml_plr_obj.sensitivity_analysis(rho=0.0)
160+
res_dict["sensitivity_ses"] = dml_plr_obj.sensitivity_params["se"]
161+
162+
return res_dict
163+
164+
165+
@pytest.mark.ci
166+
def test_dml_plr_binary_coef(dml_plr_binary_fixture):
167+
assert math.isclose(dml_plr_binary_fixture["coef"], dml_plr_binary_fixture["coef_manual"], rel_tol=1e-9, abs_tol=1e-4)
168+
assert math.isclose(dml_plr_binary_fixture["coef"], dml_plr_binary_fixture["coef_ext"], rel_tol=1e-9, abs_tol=1e-4)
169+
170+
171+
@pytest.mark.ci
172+
def test_dml_plr_binary_se(dml_plr_binary_fixture):
173+
assert math.isclose(dml_plr_binary_fixture["se"], dml_plr_binary_fixture["se_manual"], rel_tol=1e-9, abs_tol=1e-4)
174+
assert math.isclose(dml_plr_binary_fixture["se"], dml_plr_binary_fixture["se_ext"], rel_tol=1e-9, abs_tol=1e-4)
175+
176+
177+
@pytest.mark.ci
178+
def test_dml_plr_binary_boot(dml_plr_binary_fixture):
179+
for bootstrap in dml_plr_binary_fixture["boot_methods"]:
180+
assert np.allclose(
181+
dml_plr_binary_fixture["boot_t_stat" + bootstrap],
182+
dml_plr_binary_fixture["boot_t_stat" + bootstrap + "_manual"],
183+
rtol=1e-9,
184+
atol=1e-4,
185+
)
186+
assert np.allclose(
187+
dml_plr_binary_fixture["boot_t_stat" + bootstrap],
188+
dml_plr_binary_fixture["boot_t_stat" + bootstrap + "_ext"],
189+
rtol=1e-9,
190+
atol=1e-4,
191+
)
192+
193+
194+
@pytest.mark.ci
195+
def test_dml_plr_binary_sensitivity(dml_plr_binary_fixture):
196+
sensitivity_element_names = ["sigma2", "nu2", "psi_sigma2", "psi_nu2"]
197+
for sensitivity_element in sensitivity_element_names:
198+
assert np.allclose(
199+
dml_plr_binary_fixture["sensitivity_elements"][sensitivity_element],
200+
dml_plr_binary_fixture["sensitivity_elements_manual"][sensitivity_element],
201+
)
202+
203+
204+
@pytest.mark.ci
205+
def test_dml_plr_binary_sensitivity_rho0(dml_plr_binary_fixture):
206+
assert np.allclose(dml_plr_binary_fixture["se"], dml_plr_binary_fixture["sensitivity_ses"]["lower"], rtol=1e-9, atol=1e-4)
207+
assert np.allclose(dml_plr_binary_fixture["se"], dml_plr_binary_fixture["sensitivity_ses"]["upper"], rtol=1e-9, atol=1e-4)
208+
209+
210+
@pytest.fixture(scope="module", params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
211+
def cov_type(request):
212+
return request.param
213+
214+
215+
@pytest.mark.ci
216+
def test_dml_plr_binary_cate_gate(score, cov_type, generate_binary_data):
217+
n = 12
218+
219+
# Use generated binary data
220+
data = generate_binary_data.head(n)
221+
x_cols = data.columns[data.columns.str.startswith("X")].tolist()
222+
223+
obj_dml_data = dml.DoubleMLData(data, "y", ["d"], x_cols)
224+
ml_l = LogisticRegression(max_iter=1000)
225+
ml_m = LogisticRegression(max_iter=1000)
226+
227+
dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_l, ml_m, n_folds=2, score=score)
228+
dml_plr_obj.fit()
229+
230+
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 3)))
231+
cate = dml_plr_obj.cate(random_basis, cov_type=cov_type)
232+
assert isinstance(cate, dml.DoubleMLBLP)
233+
assert isinstance(cate.confint(), pd.DataFrame)
234+
assert cate.blp_model.cov_type == cov_type
235+
236+
groups_1 = pd.DataFrame(np.column_stack([data["X1"] <= 0, data["X1"] > 0.2]), columns=["Group 1", "Group 2"])
237+
msg = "At least one group effect is estimated with less than 6 observations."
238+
with pytest.warns(UserWarning, match=msg):
239+
gate_1 = dml_plr_obj.gate(groups_1, cov_type=cov_type)
240+
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
241+
assert isinstance(gate_1.confint(), pd.DataFrame)
242+
assert all(gate_1.confint().index == groups_1.columns.tolist())
243+
assert gate_1.blp_model.cov_type == cov_type

doubleml/tests/test_exceptions.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,6 @@ def predict(self, X):
986986
@pytest.mark.ci
987987
def test_doubleml_exception_learner():
988988
err_msg_prefix = "Invalid learner provided for ml_l: "
989-
warn_msg_prefix = "Learner provided for ml_l is probably invalid: "
990989

991990
msg = err_msg_prefix + "provide an instance of a learner instead of a class."
992991
with pytest.raises(TypeError, match=msg):
@@ -1005,12 +1004,6 @@ def test_doubleml_exception_learner():
10051004
with pytest.warns(UserWarning):
10061005
_ = DoubleMLIRM(dml_data_irm, Lasso(), _DummyNoClassifier())
10071006

1008-
# ToDo: Currently for ml_l (and others) we only check whether the learner can be identified as regressor. However,
1009-
# we do not check whether it can instead be identified as classifier, which could be used to throw an error.
1010-
msg = warn_msg_prefix + r"LogisticRegression\(\) is \(probably\) no regressor."
1011-
with pytest.warns(UserWarning, match=msg):
1012-
_ = DoubleMLPLR(dml_data, LogisticRegression(), Lasso())
1013-
10141007
# we allow classifiers for ml_m in PLR, but only for binary treatment variables
10151008
msg = (
10161009
r"The ml_m learner LogisticRegression\(\) was identified as classifier "

0 commit comments

Comments
 (0)