Commit da9bf49

Merge pull request #355 from DoubleML/s-improve-set-nuisance-params
Update set nuisance params
2 parents 88f3c99 + f8a7369

File tree: 3 files changed, +377 −69 lines changed

doubleml/double_ml.py

Lines changed: 20 additions & 4 deletions
@@ -949,18 +949,34 @@ def set_ml_nuisance_params(self, learner, treat_var, params):
             )

         if params is None:
-            all_params = [None] * self.n_rep
+            new_params = [None] * self.n_rep
         elif isinstance(params, dict):
-            all_params = [[params] * self.n_folds] * self.n_rep
+            new_params = [[params] * self.n_folds] * self.n_rep

         else:
             # ToDo: Add meaningful error messages for the asserts and corresponding unit tests
             assert len(params) == self.n_rep
             assert np.all(np.array([len(x) for x in params]) == self.n_folds)
-            all_params = params
+            new_params = params

-        self._params[learner][treat_var] = all_params
+        existing_params = self._params[learner].get(treat_var, [None] * self.n_rep)

+        if existing_params == [None] * self.n_rep:
+            updated_params = new_params
+        elif new_params == [None] * self.n_rep:
+            updated_params = existing_params
+        else:
+            updated_params = []
+            for i_rep in range(self.n_rep):
+                rep_params = []
+                for i_fold in range(self.n_folds):
+                    existing_dict = existing_params[i_rep][i_fold]
+                    new_dict = new_params[i_rep][i_fold]
+                    updated_dict = existing_dict | new_dict
+                    rep_params.append(updated_dict)
+                updated_params.append(rep_params)
+
+        self._params[learner][treat_var] = updated_params
         return self

     @abstractmethod
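
The core of this change is the per-fold dictionary union: repeated calls to set_ml_nuisance_params now accumulate parameters instead of discarding earlier ones. A minimal standalone sketch of the merge rule (plain Python, not DoubleML's internals; the variable names are illustrative):

# Hypothetical names, mirroring the loop above: one parameter dict per repetition and fold.
n_rep, n_folds = 2, 3
existing = [[{"n_estimators": 50, "max_depth": 3} for _ in range(n_folds)] for _ in range(n_rep)]
new = [[{"n_estimators": 25, "min_samples_split": 5} for _ in range(n_folds)] for _ in range(n_rep)]

# PEP 584 dict union: keys from both sides, the right-hand operand wins on conflicts.
updated = [
    [existing[i_rep][i_fold] | new[i_rep][i_fold] for i_fold in range(n_folds)]
    for i_rep in range(n_rep)
]
assert updated[0][0] == {"n_estimators": 25, "max_depth": 3, "min_samples_split": 5}

Note that the `|` operator on dicts requires Python 3.9+; `{**existing_dict, **new_dict}` is the equivalent spelling on older interpreters.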

doubleml/tests/test_set_ml_nuisance_params.py

Lines changed: 201 additions & 65 deletions
@@ -2,101 +2,237 @@
 import pytest
 from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

-from doubleml import DoubleMLCVAR, DoubleMLIIVM, DoubleMLIRM, DoubleMLLPQ, DoubleMLPLIV, DoubleMLPLR, DoubleMLPQ
-from doubleml.irm.datasets import make_iivm_data, make_irm_data
-from doubleml.plm.datasets import make_pliv_CHS2015, make_plr_CCDDHNR2018
+from doubleml import DoubleMLIRM, DoubleMLPLR
+from doubleml.irm.datasets import make_irm_data
+from doubleml.plm.datasets import make_plr_CCDDHNR2018

-# set default and test values
-n_est_default = 100
-n_est_test = 5
-n_folds = 2
-test_values = [[{"n_estimators": 5}, {"n_estimators": 5}]]
+# Test setup
+n_folds = 3
+n_rep = 2

 np.random.seed(3141)
-dml_data_plr = make_plr_CCDDHNR2018(n_obs=100)
-dml_data_pliv = make_pliv_CHS2015(n_obs=100, dim_z=1)
 dml_data_irm = make_irm_data(n_obs=1000)
-dml_data_iivm = make_iivm_data(n_obs=2000)

-reg_learner = RandomForestRegressor(max_depth=2)
-class_learner = RandomForestClassifier(max_depth=2)
+reg_learner = RandomForestRegressor(max_depth=2, n_estimators=100)
+class_learner = RandomForestClassifier(max_depth=2, n_estimators=100)

-# linear models
-dml_plr = DoubleMLPLR(dml_data_plr, reg_learner, reg_learner, n_folds=n_folds)
-dml_pliv = DoubleMLPLIV(dml_data_pliv, reg_learner, reg_learner, reg_learner, n_folds=n_folds)
-dml_irm = DoubleMLIRM(dml_data_irm, reg_learner, class_learner, n_folds=n_folds)
-dml_iivm = DoubleMLIIVM(dml_data_iivm, reg_learner, class_learner, class_learner, n_folds=n_folds)
-dml_cvar = DoubleMLCVAR(dml_data_irm, ml_g=reg_learner, ml_m=class_learner, n_folds=n_folds)

-dml_plr.set_ml_nuisance_params("ml_l", "d", {"n_estimators": n_est_test})
-dml_pliv.set_ml_nuisance_params("ml_l", "d", {"n_estimators": n_est_test})
-dml_irm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test})
-dml_iivm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": n_est_test})
-dml_cvar.set_ml_nuisance_params("ml_g", "d", {"n_estimators": n_est_test})
+@pytest.fixture
+def fresh_irm_model():
+    """Create a fresh IRM model for each test."""
+    return DoubleMLIRM(dml_data_irm, reg_learner, class_learner, n_folds=n_folds, n_rep=n_rep)

-dml_plr.fit(store_models=True)
-dml_pliv.fit(store_models=True)
-dml_irm.fit(store_models=True)
-dml_iivm.fit(store_models=True)
-dml_cvar.fit(store_models=True)

-# nonlinear models
-dml_pq = DoubleMLPQ(dml_data_irm, ml_g=class_learner, ml_m=class_learner, n_folds=n_folds)
-dml_lpq = DoubleMLLPQ(dml_data_iivm, ml_g=class_learner, ml_m=class_learner, n_folds=n_folds)
+@pytest.mark.ci
+def test_set_single_params(fresh_irm_model):
+    """Test combining behavior where new parameters are merged with existing ones."""
+    dml_irm = fresh_irm_model

-dml_pq.set_ml_nuisance_params("ml_g", "d", {"n_estimators": n_est_test})
-dml_lpq.set_ml_nuisance_params("ml_m_z", "d", {"n_estimators": n_est_test})
+    # Set initial parameters
+    initial_params = {"n_estimators": 50, "max_depth": 3}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_params)

-dml_pq.fit(store_models=True)
-dml_lpq.fit(store_models=True)
+    # Set additional parameters (should combine)
+    additional_params = {"min_samples_split": 5, "n_estimators": 25}  # n_estimators should be updated
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", additional_params)

+    # With combining behavior, we should have all keys
+    expected_combined = {"n_estimators": 25, "max_depth": 3, "min_samples_split": 5}
+    assert dml_irm.params["ml_g0"]["d"][0][0] == expected_combined
+    assert dml_irm.params["ml_g0"]["d"][1][1] == expected_combined

-def _assert_nuisance_params(dml_obj, learner_1, learner_2):
-    assert dml_obj.params[learner_1]["d"] == test_values
-    assert dml_obj.params[learner_2]["d"][0] is None

-    param_list_1 = [dml_obj.models[learner_1]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_test for param in param_list_1)
-    param_list_2 = [dml_obj.models[learner_2]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_default for param in param_list_2)
+@pytest.mark.ci
+def test_none_params_handling(fresh_irm_model):
+    """Test handling of None parameters."""
+    dml_irm = fresh_irm_model

+    # Set initial parameters
+    initial_params = {"n_estimators": 50}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_params)

-@pytest.mark.ci
-def test_plr_params():
-    _assert_nuisance_params(dml_plr, "ml_l", "ml_m")
+    # Setting None should not change existing parameters
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", None)
+    assert dml_irm.params["ml_g0"]["d"][0][0] == initial_params
+
+    # Test setting None on empty parameters
+    dml_irm.set_ml_nuisance_params("ml_g1", "d", None)
+    assert dml_irm.params["ml_g1"]["d"] == [None] * n_rep


 @pytest.mark.ci
-def test_pliv_params():
-    _assert_nuisance_params(dml_pliv, "ml_l", "ml_m")
+def test_set_nested_list_params(fresh_irm_model):
+    """Test combining behavior with nested list parameters."""
+    dml_irm = fresh_irm_model
+
+    # Create initial nested parameters
+    initial_nested = [
+        [
+            {"n_estimators": 50, "max_depth": 2},
+            {"n_estimators": 60, "max_depth": 3},
+            {"n_estimators": 60, "max_depth": 3},
+        ],  # rep 0
+        [
+            {"n_estimators": 70, "max_depth": 4},
+            {"n_estimators": 80, "max_depth": 5},
+            {"n_estimators": 60, "max_depth": 3},
+        ],  # rep 1
+    ]
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_nested)
+
+    # Add additional parameters
+    additional_nested = [
+        [
+            {"min_samples_split": 2, "n_estimators": 25},
+            {"min_samples_split": 3, "n_estimators": 35},
+            {"min_samples_split": 3, "n_estimators": 35},
+        ],  # rep 0
+        [
+            {"min_samples_split": 4, "n_estimators": 45},
+            {"min_samples_split": 5, "n_estimators": 55},
+            {"min_samples_split": 3, "n_estimators": 35},
+        ],  # rep 1
+    ]
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", additional_nested)
+
+    # Verify combining: existing keys preserved, overlapping keys updated, new keys added
+    expected_combined = [
+        [
+            {"n_estimators": 25, "max_depth": 2, "min_samples_split": 2},
+            {"n_estimators": 35, "max_depth": 3, "min_samples_split": 3},
+            {"n_estimators": 35, "max_depth": 3, "min_samples_split": 3},
+        ],
+        [
+            {"n_estimators": 45, "max_depth": 4, "min_samples_split": 4},
+            {"n_estimators": 55, "max_depth": 5, "min_samples_split": 5},
+            {"n_estimators": 35, "max_depth": 3, "min_samples_split": 3},
+        ],
+    ]
+
+    assert dml_irm.params["ml_g0"]["d"] == expected_combined


 @pytest.mark.ci
-def test_irm_params():
-    _assert_nuisance_params(dml_irm, "ml_g0", "ml_g1")
+def test_multiple_learners_independence(fresh_irm_model):
+    """Test that parameters for different learners are independent."""
+    dml_irm = fresh_irm_model

+    # Set parameters for different learners
+    params_g0 = {"n_estimators": 50}
+    params_g1 = {"n_estimators": 75}
+    params_m = {"n_estimators": 100}

-@pytest.mark.ci
-def test_iivm_params():
-    _assert_nuisance_params(dml_iivm, "ml_g0", "ml_g1")
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", params_g0)
+    dml_irm.set_ml_nuisance_params("ml_g1", "d", params_g1)
+    dml_irm.set_ml_nuisance_params("ml_m", "d", params_m)
+
+    # Verify independence
+    assert dml_irm.params["ml_g0"]["d"][0][0] == params_g0
+    assert dml_irm.params["ml_g1"]["d"][0][0] == params_g1
+    assert dml_irm.params["ml_m"]["d"][0][0] == params_m
+
+    # Modify one learner, others should remain unchanged
+    new_params_g0 = {"max_depth": 3, "n_estimators": 25}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", new_params_g0)
+
+    # With combining behavior
+    expected_g0 = {"n_estimators": 25, "max_depth": 3}
+    assert dml_irm.params["ml_g0"]["d"][0][0] == expected_g0
+    assert dml_irm.params["ml_g1"]["d"][0][0] == params_g1  # unchanged
+    assert dml_irm.params["ml_m"]["d"][0][0] == params_m  # unchanged


 @pytest.mark.ci
-def test_cvar_params():
-    _assert_nuisance_params(dml_cvar, "ml_g", "ml_m")
+def test_multiple_treatment_variables_independence():
+    """Test that parameters for different treatment variables are independent."""
+    # Create PLR data with multiple treatment variables
+    np.random.seed(3141)
+    multi_treat_data = make_plr_CCDDHNR2018(n_obs=100)
+
+    # Add a second treatment variable for testing
+    multi_treat_data.data["d2"] = np.random.normal(0, 1, 100)
+    multi_treat_data._d_cols = ["d", "d2"]
+
+    dml_plr = DoubleMLPLR(multi_treat_data, reg_learner, reg_learner, n_folds=n_folds, n_rep=n_rep)
+
+    # Set parameters for different treatment variables
+    params_d = {"n_estimators": 50}
+    params_d2 = {"n_estimators": 75}
+
+    dml_plr.set_ml_nuisance_params("ml_l", "d", params_d)
+    dml_plr.set_ml_nuisance_params("ml_l", "d2", params_d2)
+
+    # Verify independence
+    assert dml_plr.params["ml_l"]["d"][0][0] == params_d
+    assert dml_plr.params["ml_l"]["d2"][0][0] == params_d2
+
+    # Modify one treatment variable, other should remain unchanged
+    new_params_d = {"max_depth": 3, "n_estimators": 25}
+    dml_plr.set_ml_nuisance_params("ml_l", "d", new_params_d)
+
+    # With combining behavior
+    expected_d = {"n_estimators": 25, "max_depth": 3}
+    assert dml_plr.params["ml_l"]["d"][0][0] == expected_d
+    assert dml_plr.params["ml_l"]["d2"][0][0] == params_d2  # unchanged


 @pytest.mark.ci
-def test_pq_params():
-    _assert_nuisance_params(dml_pq, "ml_g", "ml_m")
+def test_error_cases(fresh_irm_model):
+    """Test error handling for invalid inputs."""
+    dml_irm = fresh_irm_model
+
+    # Invalid learner
+    with pytest.raises(ValueError, match="Invalid nuisance learner"):
+        dml_irm.set_ml_nuisance_params("invalid_learner", "d", {"n_estimators": 50})
+
+    # Invalid treatment variable
+    with pytest.raises(ValueError, match="Invalid treatment variable"):
+        dml_irm.set_ml_nuisance_params("ml_g0", "invalid_treat", {"n_estimators": 50})
+
+    # Invalid nested list length (wrong n_rep)
+    invalid_nested = [[{"n_estimators": 50}, {"n_estimators": 60}]]  # Only 1 rep, should be 2
+    with pytest.raises(AssertionError):
+        dml_irm.set_ml_nuisance_params("ml_g0", "d", invalid_nested)
+
+    # Invalid nested list length (wrong n_folds)
+    invalid_nested = [[{"n_estimators": 50}], [{"n_estimators": 60}]]  # Only 1 fold per rep, should be 3
+    with pytest.raises(AssertionError):
+        dml_irm.set_ml_nuisance_params("ml_g0", "d", invalid_nested)


 @pytest.mark.ci
-def test_lpq_params():
-    _assert_nuisance_params(dml_lpq, "ml_m_z", "ml_m_d_z0")
-    param_list_2 = [dml_lpq.models["ml_m_d_z1"]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_default for param in param_list_2)
-    param_list_2 = [dml_lpq.models["ml_g_du_z0"]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_default for param in param_list_2)
-    param_list_2 = [dml_lpq.models["ml_g_du_z1"]["d"][0][fold].n_estimators for fold in range(n_folds)]
-    assert all(param == n_est_default for param in param_list_2)
+def test_set_params_then_tune_combination(fresh_irm_model):
+    """Test that manually set parameters are preserved and combined with tuned parameters."""
+    dml_irm = fresh_irm_model
+
+    # Set initial parameters that should be preserved after tuning
+    initial_params = {"max_depth": 3, "min_samples_split": 5}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_params)
+    dml_irm.set_ml_nuisance_params("ml_g1", "d", initial_params)
+    dml_irm.set_ml_nuisance_params("ml_m", "d", {"max_depth": 2})
+
+    # Define tuning grid: tune only n_estimators and min_samples_split, leaving the manually set max_depth untouched
+    par_grid = {"ml_g": {"n_estimators": [10, 20], "min_samples_split": [2, 10]}, "ml_m": {"n_estimators": [15, 25]}}
+    dml_irm.tune(par_grid, return_tune_res=False)
+
+    # Verify consistency across folds and repetitions
+    for rep in range(n_rep):
+        for fold in range(n_folds):
+            # All should have the same combination of manually set + tuned parameters
+            fold_g0_params = dml_irm.params["ml_g0"]["d"][rep][fold]
+            fold_g1_params = dml_irm.params["ml_g1"]["d"][rep][fold]
+            fold_m_params = dml_irm.params["ml_m"]["d"][rep][fold]
+
+            # Manually set parameters that are not tuned should be preserved
+            assert fold_g0_params["max_depth"] == 3
+            assert fold_g1_params["max_depth"] == 3
+            assert fold_m_params["max_depth"] == 2
+
+            # Tuned parameters should overwrite manually set ones
+            assert fold_g0_params["n_estimators"] in [10, 20]
+            assert fold_g1_params["n_estimators"] in [10, 20]
+            assert fold_m_params["n_estimators"] in [15, 25]
+
+            # min_samples_split should be overwritten by tuning for ml_g learners
+            assert fold_g0_params["min_samples_split"] in [2, 10]
+            assert fold_g1_params["min_samples_split"] in [2, 10]
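
Taken together, the tests pin down the new semantics: a second set_ml_nuisance_params call merges into the stored parameters, passing None leaves existing values untouched, and tuning overwrites only the keys it actually tunes. A usage sketch mirroring test_set_single_params above (illustrative, following the test file's setup rather than documenting a canonical workflow):

import numpy as np
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

from doubleml import DoubleMLIRM
from doubleml.irm.datasets import make_irm_data

np.random.seed(3141)
dml_irm = DoubleMLIRM(
    make_irm_data(n_obs=1000),
    RandomForestRegressor(max_depth=2, n_estimators=100),
    RandomForestClassifier(max_depth=2, n_estimators=100),
    n_folds=3,
    n_rep=2,
)

# Seed the nuisance parameters for learner "ml_g0" on treatment "d" ...
dml_irm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": 50, "max_depth": 3})
# ... then call again: with this commit the second call merges instead of
# overwriting, so n_estimators is updated, max_depth survives, and
# min_samples_split is added.
dml_irm.set_ml_nuisance_params("ml_g0", "d", {"n_estimators": 25, "min_samples_split": 5})

assert dml_irm.params["ml_g0"]["d"][0][0] == {
    "n_estimators": 25,
    "max_depth": 3,
    "min_samples_split": 5,
}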
