Skip to content

Commit ef71205

Browse files
committed
add apo test
1 parent 4cc84b4 commit ef71205

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

doubleml/tests/test_set_ml_nuisance_params_models.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
44

5-
from doubleml import DoubleMLCVAR, DoubleMLIIVM, DoubleMLIRM, DoubleMLLPQ, DoubleMLPLIV, DoubleMLPLR, DoubleMLPQ
5+
from doubleml import DoubleMLAPO, DoubleMLCVAR, DoubleMLIIVM, DoubleMLIRM, DoubleMLLPQ, DoubleMLPLIV, DoubleMLPLR, DoubleMLPQ
66
from doubleml.irm.datasets import make_iivm_data, make_irm_data
77
from doubleml.plm.datasets import make_pliv_CHS2015, make_plr_CCDDHNR2018
88

@@ -27,18 +27,21 @@
2727
dml_irm = DoubleMLIRM(dml_data_irm, reg_learner, class_learner, n_folds=n_folds)
2828
dml_iivm = DoubleMLIIVM(dml_data_iivm, reg_learner, class_learner, class_learner, n_folds=n_folds)
2929
dml_cvar = DoubleMLCVAR(dml_data_irm, ml_g=reg_learner, ml_m=class_learner, n_folds=n_folds)
30+
dml_apo = DoubleMLAPO(dml_data_irm, ml_g=reg_learner, ml_m=class_learner, n_folds=n_folds, treatment_level=1)
3031

3132
dml_plr.set_ml_nuisance_params("ml_l", "d", {"n_estimators": n_est_test})
3233
dml_pliv.set_ml_nuisance_params("ml_l", "d", {"n_estimators": n_est_test})
3334
dml_irm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test})
3435
dml_iivm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test})
3536
dml_cvar.set_ml_nuisance_params("ml_g", "d", {"n_estimators": n_est_test})
37+
dml_apo.set_ml_nuisance_params("ml_g_d_lvl1", "d", {"n_estimators": n_est_test})
3638

3739
dml_plr.fit(store_models=True)
3840
dml_pliv.fit(store_models=True)
3941
dml_irm.fit(store_models=True)
4042
dml_iivm.fit(store_models=True)
4143
dml_cvar.fit(store_models=True)
44+
dml_apo.fit(store_models=True)
4245

4346
# nonlinear models
4447
dml_pq = DoubleMLPQ(dml_data_irm, ml_g=class_learner, ml_m=class_learner, n_folds=n_folds)
@@ -76,6 +79,11 @@ def test_irm_params():
7679
_assert_nuisance_params(dml_irm, "ml_g0", "ml_g1")
7780

7881

82+
@pytest.mark.ci
83+
def test_apo_params():
84+
_assert_nuisance_params(dml_apo, "ml_g_d_lvl1", "ml_m")
85+
86+
7987
@pytest.mark.ci
8088
def test_iivm_params():
8189
_assert_nuisance_params(dml_iivm, "ml_g0", "ml_g1")

0 commit comments

Comments
 (0)