Skip to content

Commit 6bfb43c

Browse files
committed
add test with binary outcome
1 parent 7a79af4 commit 6bfb43c

File tree

1 file changed

+219
-0
lines changed

1 file changed

+219
-0
lines changed
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import math
2+
3+
import numpy as np
4+
import pandas as pd
5+
import pytest
6+
from sklearn.base import clone
7+
from sklearn.ensemble import RandomForestClassifier
8+
from sklearn.linear_model import LogisticRegression
9+
10+
import doubleml as dml
11+
12+
from ...tests._utils import draw_smpls
13+
from ._utils_plr_manual import boot_plr, fit_plr, fit_sensitivity_elements_plr
14+
15+
16+
@pytest.fixture(
17+
scope="module", params=[RandomForestClassifier(max_depth=2, n_estimators=10), LogisticRegression(max_iter=1000)]
18+
)
19+
def learner_binary(request):
20+
return request.param
21+
22+
23+
@pytest.fixture(scope="module", params=["partialling out"])
24+
def score(request):
25+
return request.param
26+
27+
28+
@pytest.fixture(scope="module")
29+
def generate_binary_data():
30+
"""Generate synthetic data with binary outcome"""
31+
np.random.seed(42)
32+
n = 500
33+
p = 5
34+
35+
# Generate covariates
36+
X = np.random.normal(0, 1, size=(n, p))
37+
38+
# Generate treatment
39+
d_prob = 1 / (1 + np.exp(-(X[:, 0] + X[:, 1] + np.random.normal(0, 1, n))))
40+
d = np.random.binomial(1, d_prob)
41+
42+
# Generate binary outcome with treatment effect
43+
theta_true = 0.5 # true treatment effect
44+
y_prob = 1 / (1 + np.exp(-(X[:, 0] + X[:, 2] + theta_true * d + np.random.normal(0, 0.5, n))))
45+
y = np.random.binomial(1, y_prob)
46+
47+
# Combine into DataFrame
48+
data = pd.DataFrame({"y": y, "d": d, **{f"X{i+1}": X[:, i] for i in range(p)}})
49+
50+
return data
51+
52+
53+
@pytest.fixture(scope="module")
54+
def dml_plr_binary_fixture(generate_binary_data, learner_binary, score):
55+
boot_methods = ["normal"]
56+
n_folds = 2
57+
n_rep_boot = 502
58+
59+
# collect data
60+
data = generate_binary_data
61+
x_cols = data.columns[data.columns.str.startswith("X")].tolist()
62+
63+
# Set machine learning methods for m & g
64+
ml_l = clone(learner_binary)
65+
ml_m = clone(learner_binary)
66+
67+
np.random.seed(3141)
68+
obj_dml_data = dml.DoubleMLData(data, "y", ["d"], x_cols)
69+
dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_l, ml_m, n_folds=n_folds, score=score)
70+
dml_plr_obj.fit()
71+
72+
np.random.seed(3141)
73+
y = data["y"].values
74+
x = data.loc[:, x_cols].values
75+
d = data["d"].values
76+
n_obs = len(y)
77+
all_smpls = draw_smpls(n_obs, n_folds)
78+
79+
res_manual = fit_plr(y, x, d, clone(learner_binary), clone(learner_binary), clone(learner_binary), all_smpls, score)
80+
81+
np.random.seed(3141)
82+
# test with external nuisance predictions
83+
dml_plr_obj_ext = dml.DoubleMLPLR(obj_dml_data, ml_l, ml_m, n_folds, score=score)
84+
85+
# synchronize the sample splitting
86+
dml_plr_obj_ext.set_sample_splitting(all_smpls=all_smpls)
87+
prediction_dict = {
88+
"d": {
89+
"ml_l": dml_plr_obj.predictions["ml_l"].reshape(-1, 1),
90+
"ml_m": dml_plr_obj.predictions["ml_m"].reshape(-1, 1),
91+
}
92+
}
93+
dml_plr_obj_ext.fit(external_predictions=prediction_dict)
94+
95+
res_dict = {
96+
"coef": dml_plr_obj.coef.item(),
97+
"coef_manual": res_manual["theta"],
98+
"coef_ext": dml_plr_obj_ext.coef.item(),
99+
"se": dml_plr_obj.se.item(),
100+
"se_manual": res_manual["se"],
101+
"se_ext": dml_plr_obj_ext.se.item(),
102+
"boot_methods": boot_methods,
103+
}
104+
105+
for bootstrap in boot_methods:
106+
np.random.seed(3141)
107+
boot_t_stat = boot_plr(
108+
y,
109+
d,
110+
res_manual["thetas"],
111+
res_manual["ses"],
112+
res_manual["all_l_hat"],
113+
res_manual["all_m_hat"],
114+
res_manual["all_g_hat"],
115+
all_smpls,
116+
score,
117+
bootstrap,
118+
n_rep_boot,
119+
)
120+
121+
np.random.seed(3141)
122+
dml_plr_obj.bootstrap(method=bootstrap, n_rep_boot=n_rep_boot)
123+
np.random.seed(3141)
124+
dml_plr_obj_ext.bootstrap(method=bootstrap, n_rep_boot=n_rep_boot)
125+
res_dict["boot_t_stat" + bootstrap] = dml_plr_obj.boot_t_stat
126+
res_dict["boot_t_stat" + bootstrap + "_manual"] = boot_t_stat.reshape(-1, 1, 1)
127+
res_dict["boot_t_stat" + bootstrap + "_ext"] = dml_plr_obj_ext.boot_t_stat
128+
129+
# sensitivity tests
130+
res_dict["sensitivity_elements"] = dml_plr_obj.sensitivity_elements
131+
res_dict["sensitivity_elements_manual"] = fit_sensitivity_elements_plr(
132+
y, d.reshape(-1, 1), all_coef=dml_plr_obj.all_coef, predictions=dml_plr_obj.predictions, score=score, n_rep=1
133+
)
134+
# check if sensitivity score with rho=0 gives equal asymptotic standard deviation
135+
dml_plr_obj.sensitivity_analysis(rho=0.0)
136+
res_dict["sensitivity_ses"] = dml_plr_obj.sensitivity_params["se"]
137+
138+
return res_dict
139+
140+
141+
@pytest.mark.ci
142+
def test_dml_plr_binary_coef(dml_plr_binary_fixture):
143+
assert math.isclose(dml_plr_binary_fixture["coef"], dml_plr_binary_fixture["coef_manual"], rel_tol=1e-9, abs_tol=1e-4)
144+
assert math.isclose(dml_plr_binary_fixture["coef"], dml_plr_binary_fixture["coef_ext"], rel_tol=1e-9, abs_tol=1e-4)
145+
146+
147+
@pytest.mark.ci
148+
def test_dml_plr_binary_se(dml_plr_binary_fixture):
149+
assert math.isclose(dml_plr_binary_fixture["se"], dml_plr_binary_fixture["se_manual"], rel_tol=1e-9, abs_tol=1e-4)
150+
assert math.isclose(dml_plr_binary_fixture["se"], dml_plr_binary_fixture["se_ext"], rel_tol=1e-9, abs_tol=1e-4)
151+
152+
153+
@pytest.mark.ci
154+
def test_dml_plr_binary_boot(dml_plr_binary_fixture):
155+
for bootstrap in dml_plr_binary_fixture["boot_methods"]:
156+
assert np.allclose(
157+
dml_plr_binary_fixture["boot_t_stat" + bootstrap],
158+
dml_plr_binary_fixture["boot_t_stat" + bootstrap + "_manual"],
159+
rtol=1e-9,
160+
atol=1e-4,
161+
)
162+
assert np.allclose(
163+
dml_plr_binary_fixture["boot_t_stat" + bootstrap],
164+
dml_plr_binary_fixture["boot_t_stat" + bootstrap + "_ext"],
165+
rtol=1e-9,
166+
atol=1e-4,
167+
)
168+
169+
170+
@pytest.mark.ci
171+
def test_dml_plr_binary_sensitivity(dml_plr_binary_fixture):
172+
sensitivity_element_names = ["sigma2", "nu2", "psi_sigma2", "psi_nu2"]
173+
for sensitivity_element in sensitivity_element_names:
174+
assert np.allclose(
175+
dml_plr_binary_fixture["sensitivity_elements"][sensitivity_element],
176+
dml_plr_binary_fixture["sensitivity_elements_manual"][sensitivity_element],
177+
)
178+
179+
180+
@pytest.mark.ci
181+
def test_dml_plr_binary_sensitivity_rho0(dml_plr_binary_fixture):
182+
assert np.allclose(dml_plr_binary_fixture["se"], dml_plr_binary_fixture["sensitivity_ses"]["lower"], rtol=1e-9, atol=1e-4)
183+
assert np.allclose(dml_plr_binary_fixture["se"], dml_plr_binary_fixture["sensitivity_ses"]["upper"], rtol=1e-9, atol=1e-4)
184+
185+
186+
@pytest.fixture(scope="module", params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
187+
def cov_type(request):
188+
return request.param
189+
190+
191+
@pytest.mark.ci
192+
def test_dml_plr_binary_cate_gate(score, cov_type, generate_binary_data):
193+
n = 12
194+
195+
# Use generated binary data
196+
data = generate_binary_data.head(n)
197+
x_cols = data.columns[data.columns.str.startswith("X")].tolist()
198+
199+
obj_dml_data = dml.DoubleMLData(data, "y", ["d"], x_cols)
200+
ml_l = LogisticRegression(max_iter=1000)
201+
ml_m = LogisticRegression(max_iter=1000)
202+
203+
dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_l, ml_m, n_folds=2, score=score)
204+
dml_plr_obj.fit()
205+
206+
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 3)))
207+
cate = dml_plr_obj.cate(random_basis, cov_type=cov_type)
208+
assert isinstance(cate, dml.DoubleMLBLP)
209+
assert isinstance(cate.confint(), pd.DataFrame)
210+
assert cate.blp_model.cov_type == cov_type
211+
212+
groups_1 = pd.DataFrame(np.column_stack([data["X1"] <= 0, data["X1"] > 0.2]), columns=["Group 1", "Group 2"])
213+
msg = "At least one group effect is estimated with less than 6 observations."
214+
with pytest.warns(UserWarning, match=msg):
215+
gate_1 = dml_plr_obj.gate(groups_1, cov_type=cov_type)
216+
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
217+
assert isinstance(gate_1.confint(), pd.DataFrame)
218+
assert all(gate_1.confint().index == groups_1.columns.tolist())
219+
assert gate_1.blp_model.cov_type == cov_type

0 commit comments

Comments
 (0)