@@ -198,3 +198,56 @@ def test_error_cases(fresh_irm_model):
     invalid_nested = [[{"n_estimators": 50}], [{"n_estimators": 60}]]  # Only 1 fold, should be 2
     with pytest.raises(AssertionError):
         dml_irm.set_ml_nuisance_params("ml_g0", "d", invalid_nested)
+
+
+@pytest.mark.ci
+def test_set_params_then_tune_combination(fresh_irm_model):
+    """Test that manually set parameters are preserved and combined with tuned parameters."""
+    dml_irm = fresh_irm_model
+
+    # Set initial parameters that should be preserved after tuning
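+    # A plain dict broadcasts the same parameter values to every fold and repetition.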
+    initial_params = {"max_depth": 3, "min_samples_split": 5}
+    dml_irm.set_ml_nuisance_params("ml_g0", "d", initial_params)
+    dml_irm.set_ml_nuisance_params("ml_g1", "d", initial_params)
+    dml_irm.set_ml_nuisance_params("ml_m", "d", {"max_depth": 2})
+
+    # Define the tuning grid: tune n_estimators and min_samples_split, but leave max_depth untouched
+    par_grid = {"ml_g": {"n_estimators": [10, 20], "min_samples_split": [2, 10]}, "ml_m": {"n_estimators": [15, 25]}}
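+    # tune() defaults to tune_on_folds=False: tuning runs once on the full sample
+    # and the tuned values are then written to every fold and repetition.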
+    tune_res = dml_irm.tune(par_grid, return_tune_res=True)
+
+    # Verify consistency across folds and repetitions
+    for rep in range(n_rep):
+        for fold in range(n_folds):
+            # All should have the same combination of manually set + tuned parameters
+            fold_g0_params = dml_irm.params["ml_g0"]["d"][rep][fold]
+            fold_g1_params = dml_irm.params["ml_g1"]["d"][rep][fold]
+            fold_m_params = dml_irm.params["ml_m"]["d"][rep][fold]
+
+            # Manually set parameters that are not tuned should be preserved
+            assert fold_g0_params["max_depth"] == 3
+            assert fold_g1_params["max_depth"] == 3
+            assert fold_m_params["max_depth"] == 2
+
+            # n_estimators was never set manually, so tuning supplies it from the grid
+            assert fold_g0_params["n_estimators"] in [10, 20]
+            assert fold_g1_params["n_estimators"] in [10, 20]
+            assert fold_m_params["n_estimators"] in [15, 25]
+
+            # min_samples_split should be overwritten by tuning for the ml_g learners
+            assert fold_g0_params["min_samples_split"] in [2, 10]
+            assert fold_g1_params["min_samples_split"] in [2, 10]
+
+    # Check that the manually set max_depth is preserved in the fitted search objects.
+    # Iterate over the tuning results directly: with the default tune_on_folds=False
+    # there is a single result per learner rather than one per fold, so indexing by
+    # range(n_folds) would run past the end of the list.
+    for tune_key, expected_depth in [("g0_tune", 3), ("g1_tune", 3), ("m_tune", 2)]:
+        for fold_res in tune_res[0]["tune_res"][tune_key]:
+            # Only GridSearchCV-like objects expose best_estimator_
+            if hasattr(fold_res, "best_estimator_"):
+                assert fold_res.best_estimator_.max_depth == expected_depth
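
For context, the test relies on a `fresh_irm_model` fixture and module-level `n_folds`/`n_rep` constants defined elsewhere in this test module. A minimal sketch of what that setup might look like follows; the seed, data size, and learner choices are assumptions for illustration, not the repository's actual fixture:

import numpy as np
import pytest
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

import doubleml as dml
from doubleml.datasets import make_irm_data

n_folds = 2  # assumed; matches the "should be 2" comment above
n_rep = 2  # assumed


@pytest.fixture
def fresh_irm_model():
    # Hypothetical fixture: a small IRM model with random-forest learners, so that
    # max_depth, min_samples_split and n_estimators are all valid parameters.
    np.random.seed(3141)
    df = make_irm_data(n_obs=200, dim_x=5, return_type="DataFrame")
    dml_data = dml.DoubleMLData(df, y_col="y", d_cols="d")
    ml_g = RandomForestRegressor(n_estimators=5)
    ml_m = RandomForestClassifier(n_estimators=5)
    return dml.DoubleMLIRM(dml_data, ml_g, ml_m, n_folds=n_folds, n_rep=n_rep)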