Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions openavmkit/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,20 @@ def _tune_lightgbm(
dict: Best hyperparameters found by Optuna.
"""

# Bound search space by training-fold size to prevent memorisation on thin datasets.
# Each CV fold trains on roughly (n_splits-1)/n_splits of the data.
n_train_per_fold = int(len(X) * (n_splits - 1) / n_splits)
# num_leaves: cap at n_train_per_fold // 4 so each leaf covers ~4+ samples on average.
# This prevents the tuner from selecting thousands of leaves from a few hundred rows.
max_num_leaves = max(8, min(2048, n_train_per_fold // 4))
# min_data_in_leaf: upper bound must not exceed training fold size or every split is illegal.
max_min_data_in_leaf = max(2, min(500, n_train_per_fold // 4))
if verbose and max_num_leaves < 64:
print(
f" [tune_lightgbm] thin dataset (n_train_per_fold={n_train_per_fold}): "
f"num_leaves capped at {max_num_leaves}, min_data_in_leaf capped at {max_min_data_in_leaf}"
)

def objective(trial):
"""Objective function for Optuna to optimize LightGBM hyperparameters."""
params = {
Expand All @@ -132,12 +146,12 @@ def objective(trial):
"learning_rate", 0.0001, 0.1, log=True
),
"max_bin": trial.suggest_int("max_bin", 64, 1024),
"num_leaves": trial.suggest_int("num_leaves", 64, 2048),
"num_leaves": trial.suggest_int("num_leaves", min(64, max_num_leaves), max_num_leaves),
"max_depth": trial.suggest_int("max_depth", 5, 15),
"min_gain_to_split": trial.suggest_float(
"min_gain_to_split", 1e-4, 50, log=True
),
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 20, 500),
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", min(20, max_min_data_in_leaf), max_min_data_in_leaf),
"feature_fraction": trial.suggest_float(
"feature_fraction", 0.4, 0.9, log=False
),
Expand Down
Loading