
Commit cdbda92

add test for tuning with set parameters

1 parent 92b1f70

File tree

1 file changed: +53 -0 lines changed


doubleml/tests/test_set_ml_nuisance_params.py

Lines changed: 53 additions & 0 deletions
@@ -198,3 +198,56 @@ def test_error_cases(fresh_irm_model):
     invalid_nested = [[{"n_estimators": 50}], [{"n_estimators": 60}]]  # Only 1 fold, should be 2
     with pytest.raises(AssertionError):
         dml_irm.set_ml_nuisance_params("ml_g0", "d", invalid_nested)
+
+
+@pytest.mark.ci
+def test_set_params_then_tune_combination(fresh_irm_model):
+    """Test that manually set parameters are preserved and combined with tuned parameters."""
+    dml_irm = fresh_irm_model
+
+    # Set initial parameters that should be preserved after tuning
+    initial_params = {"max_depth": 3, "min_samples_split": 5}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_params)
+    dml_irm.set_ml_nuisance_params("ml_g1", "d", initial_params)
+    dml_irm.set_ml_nuisance_params("ml_m", "d", {"max_depth": 2})
+
+    # Define the tuning grid: tune n_estimators and min_samples_split, but not max_depth
+    par_grid = {"ml_g": {"n_estimators": [10, 20], "min_samples_split": [2, 10]}, "ml_m": {"n_estimators": [15, 25]}}
+    tune_res = dml_irm.tune(par_grid, return_tune_res=True)
+
+    # Verify consistency across folds and repetitions
+    for rep in range(n_rep):
+        for fold in range(n_folds):
+            # Every fold should hold the same combination of manually set and tuned parameters
+            fold_g0_params = dml_irm.params["ml_g0"]["d"][rep][fold]
+            fold_g1_params = dml_irm.params["ml_g1"]["d"][rep][fold]
+            fold_m_params = dml_irm.params["ml_m"]["d"][rep][fold]
+
+            # Manually set parameters that are not tuned should be preserved
+            assert fold_g0_params["max_depth"] == 3
+            assert fold_g1_params["max_depth"] == 3
+            assert fold_m_params["max_depth"] == 2
+
+            # Tuned parameters should overwrite manually set ones
+            assert fold_g0_params["n_estimators"] in [10, 20]
+            assert fold_g1_params["n_estimators"] in [10, 20]
+            assert fold_m_params["n_estimators"] in [15, 25]
+
+            # min_samples_split should be overwritten by tuning for the ml_g learners
+            assert fold_g0_params["min_samples_split"] in [2, 10]
+            assert fold_g1_params["min_samples_split"] in [2, 10]
+
+    # Check that the manually set max_depth is preserved in the best estimators
+    for fold in range(n_folds):
+        # tune_res may contain GridSearchCV objects
+        if hasattr(tune_res[0]["tune_res"]["g0_tune"][fold], "best_estimator_"):
+            best_estimator_g0 = tune_res[0]["tune_res"]["g0_tune"][fold].best_estimator_
+            assert best_estimator_g0.max_depth == 3
+
+        if hasattr(tune_res[0]["tune_res"]["g1_tune"][fold], "best_estimator_"):
+            best_estimator_g1 = tune_res[0]["tune_res"]["g1_tune"][fold].best_estimator_
+            assert best_estimator_g1.max_depth == 3
+
+        if hasattr(tune_res[0]["tune_res"]["m_tune"][fold], "best_estimator_"):
+            best_estimator_m = tune_res[0]["tune_res"]["m_tune"][fold].best_estimator_
+            assert best_estimator_m.max_depth == 2
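
Note: the hunk references a fresh_irm_model fixture and module-level n_rep/n_folds constants that live elsewhere in the test module. As a rough sketch of what that setup might look like, assuming a small simulated IRM dataset and random-forest learners (the actual fixture in the repository may differ):

# Hypothetical setup for illustration only; names and values here are
# assumptions, not taken from the repository.
import numpy as np
import pytest
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

import doubleml as dml
from doubleml.datasets import make_irm_data

n_folds = 2
n_rep = 1


@pytest.fixture
def fresh_irm_model():
    # Small simulated IRM dataset; make_irm_data returns a DoubleMLData object by default
    np.random.seed(3141)
    obj_dml_data = make_irm_data(n_obs=200, dim_x=5)
    return dml.DoubleMLIRM(
        obj_dml_data,
        ml_g=RandomForestRegressor(),
        ml_m=RandomForestClassifier(),
        n_folds=n_folds,
        n_rep=n_rep,
    )

With a setup along these lines, dml_irm.tune(par_grid, return_tune_res=True) runs a per-fold grid search for each nuisance learner, and the new test checks that the resulting fold-level parameter dictionaries merge the manually set values with the tuned ones.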
