Skip to content

Commit 81799fd

Browse files
committed
update tests for overfitting
1 parent e764d95 commit 81799fd

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

doubleml/plm/tests/test_lplr_tune_ml_models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def score(request):
1717
return request.param
1818

1919

20-
@pytest.fixture(scope="module", params=[DecisionTreeRegressor(random_state=567), None])
20+
@pytest.fixture(scope="module", params=[DecisionTreeRegressor(random_state=567, max_depth=None, min_samples_split=2), None])
2121
def ml_a(request):
2222
return request.param
2323

@@ -27,11 +27,11 @@ def ml_a(request):
2727
def test_doubleml_lplr_optuna_tune(sampler_name, optuna_sampler, score, ml_a):
2828
np.random.seed(3141)
2929
alpha = 0.5
30-
dml_data = make_lplr_LZZ2020(n_obs=500, dim_x=15, alpha=alpha)
30+
dml_data = make_lplr_LZZ2020(n_obs=200, dim_x=15, alpha=alpha)
3131

32-
ml_M = DecisionTreeClassifier(random_state=123, max_leaf_nodes=50)
33-
ml_t = DecisionTreeRegressor(random_state=234, max_leaf_nodes=50)
34-
ml_m = DecisionTreeRegressor(random_state=456, max_leaf_nodes=50)
32+
ml_M = DecisionTreeClassifier(random_state=123, max_depth=None, min_samples_split=2)
33+
ml_t = DecisionTreeRegressor(random_state=234, max_depth=None, min_samples_split=2)
34+
ml_m = DecisionTreeRegressor(random_state=456, max_depth=None, min_samples_split=2)
3535

3636
dml_lplr = dml.DoubleMLLPLR(
3737
dml_data,

doubleml/tests/_utils_tune_optuna.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _basic_optuna_settings(additional=None):
2323
def _small_tree_params(trial):
2424
return {
2525
"max_depth": trial.suggest_int("max_depth", 1, 10),
26-
"min_samples_leaf": trial.suggest_int("min_samples_leaf", 2, 100),
26+
"min_samples_leaf": trial.suggest_int("min_samples_leaf", 5, 20),
2727
"max_leaf_nodes": trial.suggest_int("max_leaf_nodes", 2, 20),
2828
}
2929

0 commit comments

Comments
 (0)