diff --git a/openavmkit/tuning.py b/openavmkit/tuning.py index 1e7706e..2c4c594 100644 --- a/openavmkit/tuning.py +++ b/openavmkit/tuning.py @@ -485,6 +485,16 @@ def _catboost_rolling_origin_cv( def _lightgbm_rolling_origin_cv(X, y, params, n_splits=5, random_state=42, cat_vars=None): + n_samples = len(X) + n_splits = min(n_splits, n_samples) + if n_splits < 2: + import warnings + warnings.warn( + f"Not enough samples ({n_samples}) for cross-validation with n_splits={n_splits}. " + "Returning penalty MAPE of 1.0.", + UserWarning, + ) + return 1.0 kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state) mape_scores = []