|
3 | 3 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor |
4 | 4 |
|
5 | 5 | from doubleml import DoubleMLAPO, DoubleMLCVAR, DoubleMLIIVM, DoubleMLIRM, DoubleMLLPQ, DoubleMLPLIV, DoubleMLPLR, DoubleMLPQ |
| 6 | +from doubleml.data import DoubleMLPanelData |
| 7 | +from doubleml.did import DoubleMLDIDBinary, DoubleMLDIDCSBinary |
| 8 | +from doubleml.did.datasets import make_did_CS2021 |
6 | 9 | from doubleml.irm.datasets import make_iivm_data, make_irm_data |
7 | 10 | from doubleml.plm.datasets import make_pliv_CHS2015, make_plr_CCDDHNR2018 |
8 | 11 |
|
|
# Simulated data for the IRM-type and IIVM-type models
dml_data_irm = make_irm_data(n_obs=1000)
dml_data_iivm = make_iivm_data(n_obs=2000)

# Create DID data: generated with n_obs=500 over 4 periods (2 pre-treatment)
# and float-valued time stamps, then wrapped as panel data
df_did = make_did_CS2021(
    n_obs=500,
    dgp_type=1,
    n_pre_treat_periods=2,
    n_periods=4,
    time_type="float",
)
dml_data_did = DoubleMLPanelData(
    df_did,
    y_col="y",
    d_cols="d",
    id_col="id",
    t_col="t",
    x_cols=["Z1", "Z2", "Z3", "Z4"],
)

# Shallow random forests reused as nuisance learners by the models in this module
reg_learner = RandomForestRegressor(max_depth=2)
class_learner = RandomForestClassifier(max_depth=2)
23 | 30 |
|
|
# Fit the CVaR and APO models; store_models=True keeps the fitted learners
# so the tests can inspect them afterwards
dml_cvar.fit(store_models=True)
dml_apo.fit(store_models=True)

# DID models
# Both estimators receive exactly the same configuration, so the keyword
# arguments are assembled once and unpacked into each constructor.
_did_model_kwargs = {
    "obj_dml_data": dml_data_did,
    "ml_g": reg_learner,
    "ml_m": class_learner,
    "g_value": 2,
    "t_value_pre": 0,
    "t_value_eval": 1,
    "score": "observational",
    "n_folds": n_folds,
}
dml_did_binary = DoubleMLDIDBinary(**_did_model_kwargs)
dml_did_cs_binary = DoubleMLDIDCSBinary(**_did_model_kwargs)

# Override n_estimators for exactly one nuisance learner per model
# ("ml_g0" and "ml_g_d0_t0" respectively, both for treatment variable "d")
dml_did_binary.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test})
dml_did_cs_binary.set_ml_nuisance_params("ml_g_d0_t0", "d", {"n_estimators": n_est_test})

# Fit with store_models=True so the fitted learners remain accessible
dml_did_binary.fit(store_models=True)
dml_did_cs_binary.fit(store_models=True)
| 81 | + |
# nonlinear models
# Both quantile models use the classifier for ml_g as well as ml_m.
dml_pq = DoubleMLPQ(
    dml_data_irm,
    ml_g=class_learner,
    ml_m=class_learner,
    n_folds=n_folds,
)
dml_lpq = DoubleMLLPQ(
    dml_data_iivm,
    ml_g=class_learner,
    ml_m=class_learner,
    n_folds=n_folds,
)
@@ -108,3 +144,13 @@ def test_lpq_params(): |
108 | 144 | assert all(param == n_est_default for param in param_list_2) |
109 | 145 | param_list_2 = [dml_lpq.models["ml_g_du_z1"]["d"][0][fold].n_estimators for fold in range(n_folds)] |
110 | 146 | assert all(param == n_est_default for param in param_list_2) |
| 147 | + |
| 148 | + |
@pytest.mark.ci
def test_did_binary_params():
    """Check the nuisance learners ``ml_g0`` and ``ml_g1`` of ``dml_did_binary``.

    The comparison is delegated to ``_assert_nuisance_params``, which is
    defined outside the visible part of this module.
    NOTE(review): presumably the first learner name is expected to carry the
    overridden ``n_estimators`` and the second the default -- confirm against
    the helper.
    """
    learner_names = ("ml_g0", "ml_g1")
    _assert_nuisance_params(dml_did_binary, *learner_names)
| 152 | + |
| 153 | + |
@pytest.mark.ci
def test_did_cs_binary_params():
    """Check the learners ``ml_g_d0_t0`` and ``ml_g_d1_t0`` of ``dml_did_cs_binary``.

    The comparison is delegated to ``_assert_nuisance_params``, which is
    defined outside the visible part of this module.
    NOTE(review): presumably the first learner name is expected to carry the
    overridden ``n_estimators`` and the second the default -- confirm against
    the helper.
    """
    learner_names = ("ml_g_d0_t0", "ml_g_d1_t0")
    _assert_nuisance_params(dml_did_cs_binary, *learner_names)
0 commit comments