Skip to content

Commit 92b1f70

Browse files
committed
add did binary models to set nuisance params test
1 parent ef71205 commit 92b1f70

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

doubleml/tests/test_set_ml_nuisance_params_models.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
44

55
from doubleml import DoubleMLAPO, DoubleMLCVAR, DoubleMLIIVM, DoubleMLIRM, DoubleMLLPQ, DoubleMLPLIV, DoubleMLPLR, DoubleMLPQ
6+
from doubleml.data import DoubleMLPanelData
7+
from doubleml.did import DoubleMLDIDBinary, DoubleMLDIDCSBinary
8+
from doubleml.did.datasets import make_did_CS2021
69
from doubleml.irm.datasets import make_iivm_data, make_irm_data
710
from doubleml.plm.datasets import make_pliv_CHS2015, make_plr_CCDDHNR2018
811

@@ -18,6 +21,10 @@
1821
dml_data_irm = make_irm_data(n_obs=1000)
1922
dml_data_iivm = make_iivm_data(n_obs=2000)
2023

24+
# Create DID data
25+
df_did = make_did_CS2021(n_obs=500, dgp_type=1, n_pre_treat_periods=2, n_periods=4, time_type="float")
26+
dml_data_did = DoubleMLPanelData(df_did, y_col="y", d_cols="d", id_col="id", t_col="t", x_cols=["Z1", "Z2", "Z3", "Z4"])
27+
2128
reg_learner = RandomForestRegressor(max_depth=2)
2229
class_learner = RandomForestClassifier(max_depth=2)
2330

@@ -43,6 +50,35 @@
4350
dml_cvar.fit(store_models=True)
4451
dml_apo.fit(store_models=True)
4552

53+
# DID models
54+
dml_did_binary = DoubleMLDIDBinary(
55+
obj_dml_data=dml_data_did,
56+
ml_g=reg_learner,
57+
ml_m=class_learner,
58+
g_value=2,
59+
t_value_pre=0,
60+
t_value_eval=1,
61+
score="observational",
62+
n_folds=n_folds,
63+
)
64+
65+
dml_did_cs_binary = DoubleMLDIDCSBinary(
66+
obj_dml_data=dml_data_did,
67+
ml_g=reg_learner,
68+
ml_m=class_learner,
69+
g_value=2,
70+
t_value_pre=0,
71+
t_value_eval=1,
72+
score="observational",
73+
n_folds=n_folds,
74+
)
75+
76+
dml_did_binary.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test})
77+
dml_did_cs_binary.set_ml_nuisance_params("ml_g_d0_t0", "d", {"n_estimators": n_est_test})
78+
79+
dml_did_binary.fit(store_models=True)
80+
dml_did_cs_binary.fit(store_models=True)
81+
4682
# nonlinear models
4783
dml_pq = DoubleMLPQ(dml_data_irm, ml_g=class_learner, ml_m=class_learner, n_folds=n_folds)
4884
dml_lpq = DoubleMLLPQ(dml_data_iivm, ml_g=class_learner, ml_m=class_learner, n_folds=n_folds)
@@ -108,3 +144,13 @@ def test_lpq_params():
108144
assert all(param == n_est_default for param in param_list_2)
109145
param_list_2 = [dml_lpq.models["ml_g_du_z1"]["d"][0][fold].n_estimators for fold in range(n_folds)]
110146
assert all(param == n_est_default for param in param_list_2)
147+
148+
149+
@pytest.mark.ci
150+
def test_did_binary_params():
151+
_assert_nuisance_params(dml_did_binary, "ml_g0", "ml_g1")
152+
153+
154+
@pytest.mark.ci
155+
def test_did_cs_binary_params():
156+
_assert_nuisance_params(dml_did_cs_binary, "ml_g_d0_t0", "ml_g_d1_t0")

0 commit comments

Comments
 (0)