diff --git a/validmind/tests/model_validation/sklearn/OverfitDiagnosis.py b/validmind/tests/model_validation/sklearn/OverfitDiagnosis.py index 0ef87f5f2..9994efd82 100644 --- a/validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +++ b/validmind/tests/model_validation/sklearn/OverfitDiagnosis.py @@ -220,6 +220,16 @@ def OverfitDiagnosis( - May not capture more subtle forms of overfitting that do not exceed the threshold. - Assumes that the binning of features adequately represents the data segments. """ + + numeric_and_categorical_feature_columns = ( + datasets[0].feature_columns_numeric + datasets[0].feature_columns_categorical + ) + + if not numeric_and_categorical_feature_columns: + raise ValueError( + "No valid numeric or categorical columns found in features_columns" + ) + is_classification = bool(datasets[0].probability_column(model)) if not metric: @@ -246,7 +256,7 @@ def OverfitDiagnosis( figures = [] results_headers = ["slice", "shape", "feature", metric] - for feature_column in datasets[0].feature_columns: + for feature_column in numeric_and_categorical_feature_columns: bins = 10 if feature_column in datasets[0].feature_columns_categorical: bins = len(train_df[feature_column].unique()) diff --git a/validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py b/validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py index f8f0b6667..591dccedb 100644 --- a/validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +++ b/validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py @@ -211,6 +211,19 @@ def WeakspotsDiagnosis( improvement. """ feature_columns = features_columns or datasets[0].feature_columns + numeric_and_categorical_columns = ( + datasets[0].feature_columns_numeric + datasets[0].feature_columns_categorical + ) + + feature_columns = [ + col for col in feature_columns if col in numeric_and_categorical_columns + ] + + if not feature_columns: + raise ValueError( + "No valid numeric or categorical columns found in features_columns" + ) + if not all(col in datasets[0].feature_columns for col in feature_columns): raise ValueError( "Column(s) provided in features_columns do not exist in the dataset"