@@ -213,7 +213,7 @@ def test_set_params_then_tune_combination(fresh_irm_model):
 
     # Define tuning grid - tune only n_estimators and min_samples_split, leaving the manually set parameters untouched
     par_grid = {"ml_g": {"n_estimators": [10, 20], "min_samples_split": [2, 10]}, "ml_m": {"n_estimators": [15, 25]}}
-    tune_res = dml_irm.tune(par_grid, return_tune_res=True)
+    dml_irm.tune(par_grid, return_tune_res=False)
 
     # Verify consistency across folds and repetitions
     for rep in range(n_rep):
@@ -236,18 +236,3 @@ def test_set_params_then_tune_combination(fresh_irm_model):
             # min_samples_split should be overwritten by tuning for ml_g learners
             assert fold_g0_params["min_samples_split"] in [2, 10]
             assert fold_g1_params["min_samples_split"] in [2, 10]
-
-    # Check that manually set max_depth is preserved in best estimators
-    for fold in range(n_folds):
-        # Check if tune_res contains GridSearchCV objects
-        if hasattr(tune_res[0]["tune_res"]["g0_tune"][fold], "best_estimator_"):
-            best_estimator_g0 = tune_res[0]["tune_res"]["g0_tune"][fold].best_estimator_
-            assert best_estimator_g0.max_depth == 3
-
-        if hasattr(tune_res[0]["tune_res"]["g1_tune"][fold], "best_estimator_"):
-            best_estimator_g1 = tune_res[0]["tune_res"]["g1_tune"][fold].best_estimator_
-            assert best_estimator_g1.max_depth == 3
-
-        if hasattr(tune_res[0]["tune_res"]["m_tune"][fold], "best_estimator_"):
-            best_estimator_m = tune_res[0]["tune_res"]["m_tune"][fold].best_estimator_
-            assert best_estimator_m.max_depth == 2
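For context, the pattern this test exercises is: pin some hyperparameters manually via `set_ml_nuisance_params`, then tune a disjoint subset via `tune` and check that the two coexist. Below is a minimal, self-contained sketch of that flow against the public DoubleML API; the data size, seed, and learner settings are illustrative assumptions, not values taken from the `fresh_irm_model` fixture.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

import doubleml as dml
from doubleml.datasets import make_irm_data

# Assumed setup: n_obs, dim_x, seed, and depths are illustrative, not the fixture's.
np.random.seed(3141)
dml_data = make_irm_data(n_obs=500, dim_x=5)

ml_g = RandomForestRegressor(max_depth=3)   # outcome regression
ml_m = RandomForestClassifier(max_depth=2)  # propensity score
dml_irm = dml.DoubleMLIRM(dml_data, ml_g=ml_g, ml_m=ml_m, n_folds=2, n_rep=1)

# Manually pin max_depth for each nuisance learner ('d' is the treatment
# variable name produced by make_irm_data).
dml_irm.set_ml_nuisance_params("ml_g0", "d", {"max_depth": 3})
dml_irm.set_ml_nuisance_params("ml_g1", "d", {"max_depth": 3})
dml_irm.set_ml_nuisance_params("ml_m", "d", {"max_depth": 2})

# Tune only a subset of the hyperparameters, as in the test above.
par_grid = {
    "ml_g": {"n_estimators": [10, 20], "min_samples_split": [2, 10]},
    "ml_m": {"n_estimators": [15, 25]},
}
dml_irm.tune(par_grid, return_tune_res=False)

# The tuned parameters are stored per learner/treatment/repetition/fold.
print(dml_irm.params["ml_g0"]["d"][0][0])
```

With `return_tune_res=True`, `tune` additionally returns the raw tuning results, e.g. the per-fold objects the removed block indexed as `tune_res[0]["tune_res"]["g0_tune"][fold]`; with `return_tune_res=False` the tuned values are only written onto the model object, which is what the surviving per-fold assertions read.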