Skip to content

Commit f8a7369

Browse files
committed
update test
1 parent cdbda92 commit f8a7369

File tree

1 file changed

+1
-16
lines changed

1 file changed

+1
-16
lines changed

doubleml/tests/test_set_ml_nuisance_params.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def test_set_params_then_tune_combination(fresh_irm_model):
213213

214214
# Define tuning grid - only tune n_estimators, min_samples_split, not all manually set parameters
215215
par_grid = {"ml_g": {"n_estimators": [10, 20], "min_samples_split": [2, 10]}, "ml_m": {"n_estimators": [15, 25]}}
216-
tune_res = dml_irm.tune(par_grid, return_tune_res=True)
216+
dml_irm.tune(par_grid, return_tune_res=False)
217217

218218
# Verify consistency across folds and repetitions
219219
for rep in range(n_rep):
@@ -236,18 +236,3 @@ def test_set_params_then_tune_combination(fresh_irm_model):
236236
# min_samples_split should be overwritten by tuning for ml_g learners
237237
assert fold_g0_params["min_samples_split"] in [2, 10]
238238
assert fold_g1_params["min_samples_split"] in [2, 10]
239-
240-
# Check that manually set max_depth is preserved in best estimators
241-
for fold in range(n_folds):
242-
# Check if tune_res contains GridSearchCV objects
243-
if hasattr(tune_res[0]["tune_res"]["g0_tune"][fold], "best_estimator_"):
244-
best_estimator_g0 = tune_res[0]["tune_res"]["g0_tune"][fold].best_estimator_
245-
assert best_estimator_g0.max_depth == 3
246-
247-
if hasattr(tune_res[0]["tune_res"]["g1_tune"][fold], "best_estimator_"):
248-
best_estimator_g1 = tune_res[0]["tune_res"]["g1_tune"][fold].best_estimator_
249-
assert best_estimator_g1.max_depth == 3
250-
251-
if hasattr(tune_res[0]["tune_res"]["m_tune"][fold], "best_estimator_"):
252-
best_estimator_m = tune_res[0]["tune_res"]["m_tune"][fold].best_estimator_
253-
assert best_estimator_m.max_depth == 2

0 commit comments

Comments
 (0)