|
2 | 2 | import pytest |
3 | 3 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor |
4 | 4 |
|
5 | | -from doubleml import DoubleMLCVAR, DoubleMLIIVM, DoubleMLIRM, DoubleMLLPQ, DoubleMLPLIV, DoubleMLPLR, DoubleMLPQ |
| 5 | +from doubleml import DoubleMLAPO, DoubleMLCVAR, DoubleMLIIVM, DoubleMLIRM, DoubleMLLPQ, DoubleMLPLIV, DoubleMLPLR, DoubleMLPQ |
6 | 6 | from doubleml.irm.datasets import make_iivm_data, make_irm_data |
7 | 7 | from doubleml.plm.datasets import make_pliv_CHS2015, make_plr_CCDDHNR2018 |
8 | 8 |
|
|
27 | 27 | dml_irm = DoubleMLIRM(dml_data_irm, reg_learner, class_learner, n_folds=n_folds) |
28 | 28 | dml_iivm = DoubleMLIIVM(dml_data_iivm, reg_learner, class_learner, class_learner, n_folds=n_folds) |
29 | 29 | dml_cvar = DoubleMLCVAR(dml_data_irm, ml_g=reg_learner, ml_m=class_learner, n_folds=n_folds) |
| 30 | +dml_apo = DoubleMLAPO(dml_data_irm, ml_g=reg_learner, ml_m=class_learner, n_folds=n_folds, treatment_level=1) |
30 | 31 |
|
31 | 32 | dml_plr.set_ml_nuisance_params("ml_l", "d", {"n_estimators": n_est_test}) |
32 | 33 | dml_pliv.set_ml_nuisance_params("ml_l", "d", {"n_estimators": n_est_test}) |
33 | 34 | dml_irm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test}) |
34 | 35 | dml_iivm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test}) |
35 | 36 | dml_cvar.set_ml_nuisance_params("ml_g", "d", {"n_estimators": n_est_test}) |
| 37 | +dml_apo.set_ml_nuisance_params("ml_g_d_lvl1", "d", {"n_estimators": n_est_test}) |
36 | 38 |
|
37 | 39 | dml_plr.fit(store_models=True) |
38 | 40 | dml_pliv.fit(store_models=True) |
39 | 41 | dml_irm.fit(store_models=True) |
40 | 42 | dml_iivm.fit(store_models=True) |
41 | 43 | dml_cvar.fit(store_models=True) |
| 44 | +dml_apo.fit(store_models=True) |
42 | 45 |
|
43 | 46 | # nonlinear models |
44 | 47 | dml_pq = DoubleMLPQ(dml_data_irm, ml_g=class_learner, ml_m=class_learner, n_folds=n_folds) |
@@ -76,6 +79,11 @@ def test_irm_params(): |
76 | 79 | _assert_nuisance_params(dml_irm, "ml_g0", "ml_g1") |
77 | 80 |
|
78 | 81 |
|
| 82 | +@pytest.mark.ci |
| 83 | +def test_apo_params(): |
| 84 | + _assert_nuisance_params(dml_apo, "ml_g_d_lvl1", "ml_m") |
| 85 | + |
| 86 | + |
79 | 87 | @pytest.mark.ci |
80 | 88 | def test_iivm_params(): |
81 | 89 | _assert_nuisance_params(dml_iivm, "ml_g0", "ml_g1") |
|
0 commit comments