Commit 4cc84b4

update the set_ml_nuisance_params method to merge new parameters into the existing ones
1 parent 70300a0 commit 4cc84b4

File tree

3 files changed: +286 -70 lines changed

doubleml/double_ml.py

Lines changed: 20 additions & 4 deletions
```diff
@@ -946,18 +946,34 @@ def set_ml_nuisance_params(self, learner, treat_var, params):
             )

         if params is None:
-            all_params = [None] * self.n_rep
+            new_params = [None] * self.n_rep
         elif isinstance(params, dict):
-            all_params = [[params] * self.n_folds] * self.n_rep
+            new_params = [[params] * self.n_folds] * self.n_rep
         else:
             # ToDo: Add meaningful error messages for the asserts and corresponding unit tests
             assert len(params) == self.n_rep
             assert np.all(np.array([len(x) for x in params]) == self.n_folds)
-            all_params = params
+            new_params = params

-        self._params[learner][treat_var] = all_params
+        existing_params = self._params[learner].get(treat_var, [None] * self.n_rep)
+
+        if existing_params == [None] * self.n_rep:
+            updated_params = new_params
+        elif new_params == [None] * self.n_rep:
+            updated_params = existing_params
+        else:
+            updated_params = []
+            for i_rep in range(self.n_rep):
+                rep_params = []
+                for i_fold in range(self.n_folds):
+                    existing_dict = existing_params[i_rep][i_fold]
+                    new_dict = new_params[i_rep][i_fold]
+                    updated_dict = existing_dict | new_dict
+                    rep_params.append(updated_dict)
+                updated_params.append(rep_params)
+
+        self._params[learner][treat_var] = updated_params
         return self

     @abstractmethod
```

doubleml/tests/test_set_ml_nuisance_params.py

Lines changed: 164 additions & 66 deletions
```diff
@@ -2,101 +2,199 @@
 import pytest
 from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

-from doubleml import DoubleMLCVAR, DoubleMLIIVM, DoubleMLIRM, DoubleMLLPQ, DoubleMLPLIV, DoubleMLPLR, DoubleMLPQ
-from doubleml.irm.datasets import make_iivm_data, make_irm_data
-from doubleml.plm.datasets import make_pliv_CHS2015, make_plr_CCDDHNR2018
+from doubleml import DoubleMLIRM, DoubleMLPLR
+from doubleml.irm.datasets import make_irm_data
+from doubleml.plm.datasets import make_plr_CCDDHNR2018

-# set default and test values
-n_est_default = 100
-n_est_test = 5
-n_folds = 2
-test_values = [[{"n_estimators": 5}, {"n_estimators": 5}]]
+# Test setup
+n_folds = 3
+n_rep = 2

 np.random.seed(3141)
-dml_data_plr = make_plr_CCDDHNR2018(n_obs=100)
-dml_data_pliv = make_pliv_CHS2015(n_obs=100, dim_z=1)
 dml_data_irm = make_irm_data(n_obs=1000)
-dml_data_iivm = make_iivm_data(n_obs=2000)

-reg_learner = RandomForestRegressor(max_depth=2)
-class_learner = RandomForestClassifier(max_depth=2)
+reg_learner = RandomForestRegressor(max_depth=2, n_estimators=100)
+class_learner = RandomForestClassifier(max_depth=2, n_estimators=100)

-# linear models
-dml_plr = DoubleMLPLR(dml_data_plr, reg_learner, reg_learner, n_folds=n_folds)
-dml_pliv = DoubleMLPLIV(dml_data_pliv, reg_learner, reg_learner, reg_learner, n_folds=n_folds)
-dml_irm = DoubleMLIRM(dml_data_irm, reg_learner, class_learner, n_folds=n_folds)
-dml_iivm = DoubleMLIIVM(dml_data_iivm, reg_learner, class_learner, class_learner, n_folds=n_folds)
-dml_cvar = DoubleMLCVAR(dml_data_irm, ml_g=reg_learner, ml_m=class_learner, n_folds=n_folds)

-dml_plr.set_ml_nuisance_params("ml_l", "d", {"n_estimators": n_est_test})
-dml_pliv.set_ml_nuisance_params("ml_l", "d", {"n_estimators": n_est_test})
-dml_irm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test})
-dml_iivm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test})
-dml_cvar.set_ml_nuisance_params("ml_g", "d", {"n_estimators": n_est_test})
+@pytest.fixture
+def fresh_irm_model():
+    """Create a fresh IRM model for each test."""
+    return DoubleMLIRM(dml_data_irm, reg_learner, class_learner, n_folds=n_folds, n_rep=n_rep)

-dml_plr.fit(store_models=True)
-dml_pliv.fit(store_models=True)
-dml_irm.fit(store_models=True)
-dml_iivm.fit(store_models=True)
-dml_cvar.fit(store_models=True)

-# nonlinear models
-dml_pq = DoubleMLPQ(dml_data_irm, ml_g=class_learner, ml_m=class_learner, n_folds=n_folds)
-dml_lpq = DoubleMLLPQ(dml_data_iivm, ml_g=class_learner, ml_m=class_learner, n_folds=n_folds)
+@pytest.mark.ci
+def test_set_single_params(fresh_irm_model):
+    """Test combining behavior where new parameters are merged with existing ones."""
+    dml_irm = fresh_irm_model
+
+    # Set initial parameters
+    initial_params = {"n_estimators": 50, "max_depth": 3}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_params)

-dml_pq.set_ml_nuisance_params("ml_g", "d", {"n_estimators": n_est_test})
-dml_lpq.set_ml_nuisance_params("ml_m_z", "d", {"n_estimators": n_est_test})
+    # Set additional parameters (should combine)
+    additional_params = {"min_samples_split": 5, "n_estimators": 25}  # n_estimators should be updated
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", additional_params)

-dml_pq.fit(store_models=True)
-dml_lpq.fit(store_models=True)
+    # With combining behavior, we should have all keys
+    expected_combined = {"n_estimators": 25, "max_depth": 3, "min_samples_split": 5}
+    assert dml_irm.params["ml_g0"]["d"][0][0] == expected_combined
+    assert dml_irm.params["ml_g0"]["d"][1][1] == expected_combined


-def _assert_nuisance_params(dml_obj, learner_1, learner_2):
-    assert dml_obj.params[learner_1]["d"] == test_values
-    assert dml_obj.params[learner_2]["d"][0] is None
+@pytest.mark.ci
+def test_none_params_handling(fresh_irm_model):
+    """Test handling of None parameters."""
+    dml_irm = fresh_irm_model

-    param_list_1 = [dml_obj.models[learner_1]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_test for param in param_list_1)
-    param_list_2 = [dml_obj.models[learner_2]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_default for param in param_list_2)
+    # Set initial parameters
+    initial_params = {"n_estimators": 50}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_params)

+    # Setting None should not change existing parameters
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", None)
+    assert dml_irm.params["ml_g0"]["d"][0][0] == initial_params

-@pytest.mark.ci
-def test_plr_params():
-    _assert_nuisance_params(dml_plr, "ml_l", "ml_m")
+    # Test setting None on empty parameters
+    dml_irm.set_ml_nuisance_params("ml_g1", "d", None)
+    assert dml_irm.params["ml_g1"]["d"] == [None] * n_rep


 @pytest.mark.ci
-def test_pliv_params():
-    _assert_nuisance_params(dml_pliv, "ml_l", "ml_m")
+def test_set_nested_list_params(fresh_irm_model):
+    """Test combining behavior with nested list parameters."""
+    dml_irm = fresh_irm_model
+
+    # Create initial nested parameters
+    initial_nested = [
+        [
+            {"n_estimators": 50, "max_depth": 2},
+            {"n_estimators": 60, "max_depth": 3},
+            {"n_estimators": 60, "max_depth": 3},
+        ],  # rep 0
+        [
+            {"n_estimators": 70, "max_depth": 4},
+            {"n_estimators": 80, "max_depth": 5},
+            {"n_estimators": 60, "max_depth": 3},
+        ],  # rep 1
+    ]
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_nested)
+
+    # Add additional parameters
+    additional_nested = [
+        [
+            {"min_samples_split": 2, "n_estimators": 25},
+            {"min_samples_split": 3, "n_estimators": 35},
+            {"min_samples_split": 3, "n_estimators": 35},
+        ],  # rep 0
+        [
+            {"min_samples_split": 4, "n_estimators": 45},
+            {"min_samples_split": 5, "n_estimators": 55},
+            {"min_samples_split": 3, "n_estimators": 35},
+        ],  # rep 1
+    ]
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", additional_nested)
+
+    # Verify combining: existing keys preserved, overlapping keys updated, new keys added
+    expected_combined = [
+        [
+            {"n_estimators": 25, "max_depth": 2, "min_samples_split": 2},
+            {"n_estimators": 35, "max_depth": 3, "min_samples_split": 3},
+            {"n_estimators": 35, "max_depth": 3, "min_samples_split": 3},
+        ],
+        [
+            {"n_estimators": 45, "max_depth": 4, "min_samples_split": 4},
+            {"n_estimators": 55, "max_depth": 5, "min_samples_split": 5},
+            {"n_estimators": 35, "max_depth": 3, "min_samples_split": 3},
+        ],
+    ]
+
+    assert dml_irm.params["ml_g0"]["d"] == expected_combined


 @pytest.mark.ci
-def test_irm_params():
-    _assert_nuisance_params(dml_irm, "ml_g0", "ml_g1")
+def test_multiple_learners_independence(fresh_irm_model):
+    """Test that parameters for different learners are independent."""
+    dml_irm = fresh_irm_model

+    # Set parameters for different learners
+    params_g0 = {"n_estimators": 50}
+    params_g1 = {"n_estimators": 75}
+    params_m = {"n_estimators": 100}

-@pytest.mark.ci
-def test_iivm_params():
-    _assert_nuisance_params(dml_iivm, "ml_g0", "ml_g1")
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", params_g0)
+    dml_irm.set_ml_nuisance_params("ml_g1", "d", params_g1)
+    dml_irm.set_ml_nuisance_params("ml_m", "d", params_m)

+    # Verify independence
+    assert dml_irm.params["ml_g0"]["d"][0][0] == params_g0
+    assert dml_irm.params["ml_g1"]["d"][0][0] == params_g1
+    assert dml_irm.params["ml_m"]["d"][0][0] == params_m

-@pytest.mark.ci
-def test_cvar_params():
-    _assert_nuisance_params(dml_cvar, "ml_g", "ml_m")
+    # Modify one learner, others should remain unchanged
+    new_params_g0 = {"max_depth": 3, "n_estimators": 25}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", new_params_g0)
+
+    # With combining behavior
+    expected_g0 = {"n_estimators": 25, "max_depth": 3}
+    assert dml_irm.params["ml_g0"]["d"][0][0] == expected_g0
+    assert dml_irm.params["ml_g1"]["d"][0][0] == params_g1  # unchanged
+    assert dml_irm.params["ml_m"]["d"][0][0] == params_m  # unchanged


 @pytest.mark.ci
-def test_pq_params():
-    _assert_nuisance_params(dml_pq, "ml_g", "ml_m")
+def test_multiple_treatment_variables_independence():
+    """Test that parameters for different treatment variables are independent."""
+    # Create PLR data with multiple treatment variables
+    np.random.seed(3141)
+    multi_treat_data = make_plr_CCDDHNR2018(n_obs=100)
+
+    # Add a second treatment variable for testing
+    multi_treat_data.data["d2"] = np.random.normal(0, 1, 100)
+    multi_treat_data._d_cols = ["d", "d2"]
+
+    dml_plr = DoubleMLPLR(multi_treat_data, reg_learner, reg_learner, n_folds=n_folds, n_rep=n_rep)
+
+    # Set parameters for different treatment variables
+    params_d = {"n_estimators": 50}
+    params_d2 = {"n_estimators": 75}
+
+    dml_plr.set_ml_nuisance_params("ml_l", "d", params_d)
+    dml_plr.set_ml_nuisance_params("ml_l", "d2", params_d2)
+
+    # Verify independence
+    assert dml_plr.params["ml_l"]["d"][0][0] == params_d
+    assert dml_plr.params["ml_l"]["d2"][0][0] == params_d2
+
+    # Modify one treatment variable, other should remain unchanged
+    new_params_d = {"max_depth": 3, "n_estimators": 25}
+    dml_plr.set_ml_nuisance_params("ml_l", "d", new_params_d)
+
+    # With combining behavior
+    expected_d = {"n_estimators": 25, "max_depth": 3}
+    assert dml_plr.params["ml_l"]["d"][0][0] == expected_d
+    assert dml_plr.params["ml_l"]["d2"][0][0] == params_d2  # unchanged


 @pytest.mark.ci
-def test_lpq_params():
-    _assert_nuisance_params(dml_lpq, "ml_m_z", "ml_m_d_z0")
-    param_list_2 = [dml_lpq.models["ml_m_d_z1"]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_default for param in param_list_2)
-    param_list_2 = [dml_lpq.models["ml_g_du_z0"]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_default for param in param_list_2)
-    param_list_2 = [dml_lpq.models["ml_g_du_z1"]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_default for param in param_list_2)
+def test_error_cases(fresh_irm_model):
+    """Test error handling for invalid inputs."""
+    dml_irm = fresh_irm_model
+
+    # Invalid learner
+    with pytest.raises(ValueError, match="Invalid nuisance learner"):
+        dml_irm.set_ml_nuisance_params("invalid_learner", "d", {"n_estimators": 50})
+
+    # Invalid treatment variable
+    with pytest.raises(ValueError, match="Invalid treatment variable"):
+        dml_irm.set_ml_nuisance_params("ml_g0", "invalid_treat", {"n_estimators": 50})
+
+    # Invalid nested list length (wrong n_rep)
+    invalid_nested = [[{"n_estimators": 50}, {"n_estimators": 60}]]  # Only 1 rep, should be 2
+    with pytest.raises(AssertionError):
+        dml_irm.set_ml_nuisance_params("ml_g0", "d", invalid_nested)
+
+    # Invalid nested list length (wrong n_folds)
+    invalid_nested = [[{"n_estimators": 50}], [{"n_estimators": 60}]]  # Only 1 fold per rep, should be 3
+    with pytest.raises(AssertionError):
+        dml_irm.set_ml_nuisance_params("ml_g0", "d", invalid_nested)
```
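Taken together, the tests sketch the intended end-to-end usage of the merging behavior. Below is a hedged usage sketch: the class names, dataset helper, method signature, and the `params` accessor are all taken from the diffs above, and the printed result assumes a DoubleML build that includes this commit:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

from doubleml import DoubleMLPLR
from doubleml.plm.datasets import make_plr_CCDDHNR2018

np.random.seed(3141)
dml_data = make_plr_CCDDHNR2018(n_obs=100)
dml_plr = DoubleMLPLR(
    dml_data,
    RandomForestRegressor(max_depth=2),
    RandomForestRegressor(max_depth=2),
    n_folds=3,
)

# First call seeds the parameter grid for learner "ml_l" and treatment "d".
dml_plr.set_ml_nuisance_params("ml_l", "d", {"n_estimators": 50, "max_depth": 3})

# Second call now merges instead of overwriting: "max_depth" survives,
# "n_estimators" is updated, "min_samples_split" is added.
dml_plr.set_ml_nuisance_params("ml_l", "d", {"n_estimators": 25, "min_samples_split": 5})

print(dml_plr.params["ml_l"]["d"][0][0])
# {'n_estimators': 25, 'max_depth': 3, 'min_samples_split': 5}
```

Before this commit, the second call would have replaced the stored grid wholesale (`self._params[learner][treat_var] = all_params`), leaving only `n_estimators` and `min_samples_split`.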
