From ba5be8df652e199851e5f96eb549d1a5965721c7 Mon Sep 17 00:00:00 2001 From: kueltzen Date: Fri, 25 Oct 2024 13:41:42 +0200 Subject: [PATCH 1/8] Removed resetting of loss in fit of ModnetModel. --- modnet/models/vanilla.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modnet/models/vanilla.py b/modnet/models/vanilla.py index 56ebd7f6..a0237336 100644 --- a/modnet/models/vanilla.py +++ b/modnet/models/vanilla.py @@ -401,14 +401,11 @@ def fit( targ = prop[0] if self.multi_label: y_inner = np.stack(val_data.df_targets[targ].values) - if loss is None: - loss = "binary_crossentropy" else: y_inner = tf.keras.utils.to_categorical( val_data.df_targets[targ].values, num_classes=self.num_classes[targ], ) - loss = "categorical_crossentropy" else: y_inner = val_data.df_targets[prop].values.astype( np.float64, copy=False From 235d2ea5655781674726ee39285858c86b20fb35 Mon Sep 17 00:00:00 2001 From: kueltzen Date: Fri, 25 Oct 2024 18:01:39 +0200 Subject: [PATCH 2/8] Added function for creating shuffled, (stratified) validation data. --- modnet/models/vanilla.py | 48 +++++++++++++++++++++++++++++++++++++- modnet/tests/test_model.py | 6 ++--- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/modnet/models/vanilla.py b/modnet/models/vanilla.py index a0237336..f8a92e9f 100644 --- a/modnet/models/vanilla.py +++ b/modnet/models/vanilla.py @@ -412,6 +412,19 @@ def fit( ) val_y.append(y_inner) validation_data = (val_x, val_y) + elif val_fraction > 0: + str_col = ( + [a_idx for a_idx, a in enumerate(y) if not isinstance(a, np.float64)][0] + if max(self.num_classes.values()) >= 2 + else None + ) + x, y, validation_data = generate_shuffled_and_stratified_val_data( + x=x, + y=y, + val_fraction=val_fraction, + classification=max(self.num_classes.values()) >= 2, + str_col=str_col, + ) else: validation_data = None @@ -424,7 +437,7 @@ def fit( # Optionally set up print callback if verbose: - if val_fraction > 0 or validation_data: + if validation_data: if self._multi_target and val_key is not None: val_metric_key = f"val_{val_key}_mae" else: @@ -1531,3 +1544,36 @@ def validate_model( def map_validate_model(kwargs): return validate_model(**kwargs) + + +def generate_shuffled_and_stratified_val_data( + x: np.ndarray, + y: list, + val_fraction: float, + classification: bool, + str_col: int | None, +): + """ + Generate validation data that is shuffled and, if classification, stratified. + """ + if classification: + if isinstance(y[str_col][0], list) or isinstance(y[str_col][0], np.ndarray): + ycv = np.argmax(y[str_col], axis=1) + else: + ycv = y[str_col] + train_idx, val_idx = train_test_split( + range(len(x)), + test_size=val_fraction, + random_state=42, + shuffle=True, + stratify=ycv, + ) + else: + train_idx, val_idx = train_test_split( + range(len(x)), test_size=val_fraction, random_state=42, shuffle=True + ) + return ( + x[train_idx], + [t[train_idx] for t in y], + (x[val_idx], [t[val_idx] for t in y]), + ) diff --git a/modnet/tests/test_model.py b/modnet/tests/test_model.py index c4853e07..ff9095d3 100644 --- a/modnet/tests/test_model.py +++ b/modnet/tests/test_model.py @@ -20,7 +20,7 @@ def test_train_small_model_single_target(subset_moddata, tf_session): n_feat=10, ) - model.fit(data, epochs=2) + model.fit(data, epochs=2, val_fraction=0.15) model.predict(data) assert not np.isnan(model.evaluate(data)) @@ -50,7 +50,7 @@ def is_metal(egap): n_feat=10, ) - model.fit(data, epochs=2) + model.fit(data, epochs=2, val_fraction=0.15) assert not np.isnan(model.evaluate(data)) @@ -71,7 +71,7 @@ def test_train_small_model_multi_target(subset_moddata, tf_session): n_feat=10, ) - model.fit(data, epochs=2) + model.fit(data, epochs=2, val_fraction=0.15) model.predict(data) assert not np.isnan(model.evaluate(data)) From 8ad5dddc120b536eaeae6f433b809a9e3a62ed90 Mon Sep 17 00:00:00 2001 From: kueltzen Date: Fri, 25 Oct 2024 19:06:55 +0200 Subject: [PATCH 3/8] Replaced train_test_split in fit_preset and FitGenetic with custom function. --- modnet/hyper_opt/fit_genetic.py | 13 ++++++--- modnet/models/__init__.py | 9 ++++-- modnet/models/ensemble.py | 12 ++++++-- modnet/models/vanilla.py | 52 +++++++++++++++++++-------------- 4 files changed, 55 insertions(+), 31 deletions(-) diff --git a/modnet/hyper_opt/fit_genetic.py b/modnet/hyper_opt/fit_genetic.py index 03fae50d..eef82d2c 100644 --- a/modnet/hyper_opt/fit_genetic.py +++ b/modnet/hyper_opt/fit_genetic.py @@ -4,9 +4,12 @@ from typing import List, Optional, Dict, Union, Callable import numpy as np import tensorflow as tf -from sklearn.model_selection import train_test_split from modnet.preprocessing import MODData -from modnet.models import MODNetModel, EnsembleMODNetModel +from modnet.models import ( + MODNetModel, + EnsembleMODNetModel, + generate_shuffled_and_stratified_val_split, +) from modnet.utils import LOG import multiprocessing import tqdm @@ -456,8 +459,10 @@ def function_fitness( ) if not nested: splits = [ - train_test_split( - range(len(self.train_data.df_featurized)), test_size=val_fraction + generate_shuffled_and_stratified_val_split( + y=self.train_data.df_targets.values, + val_fraction=val_fraction, + classification=max(self.num_classes.values()) >= 2, ) ] n_splits = 1 diff --git a/modnet/models/__init__.py b/modnet/models/__init__.py index a48034ba..40580a26 100644 --- a/modnet/models/__init__.py +++ b/modnet/models/__init__.py @@ -1,6 +1,6 @@ import warnings -from .vanilla import MODNetModel +from .vanilla import MODNetModel, generate_shuffled_and_stratified_val_split try: from .bayesian import BayesianMODNetModel @@ -14,4 +14,9 @@ from .ensemble import EnsembleMODNetModel -__all__ = ("MODNetModel", "BayesianMODNetModel", "EnsembleMODNetModel") +__all__ = ( + "MODNetModel", + "BayesianMODNetModel", + "EnsembleMODNetModel", + "generate_shuffled_and_stratified_val_split", +) diff --git a/modnet/models/ensemble.py b/modnet/models/ensemble.py index a87954b3..8263a66a 100644 --- a/modnet/models/ensemble.py +++ b/modnet/models/ensemble.py @@ -13,9 +13,11 @@ import tensorflow as tf from sklearn.utils import resample -from sklearn.model_selection import train_test_split -from modnet.models.vanilla import MODNetModel +from modnet.models.vanilla import ( + MODNetModel, + generate_shuffled_and_stratified_val_split, +) from modnet import __version__ from modnet.utils import LOG from modnet.preprocessing import MODData @@ -306,7 +308,11 @@ def fit_preset( ) if not nested: splits = [ - train_test_split(range(len(data.df_featurized)), test_size=val_fraction) + generate_shuffled_and_stratified_val_split( + y=data.df_targets.values, + val_fraction=val_fraction, + classification=classification, + ) ] n_splits = 1 else: diff --git a/modnet/models/vanilla.py b/modnet/models/vanilla.py index f8a92e9f..140c9a84 100644 --- a/modnet/models/vanilla.py +++ b/modnet/models/vanilla.py @@ -25,7 +25,7 @@ import tqdm -__all__ = ("MODNetModel",) +__all__ = ("MODNetModel", "generate_shuffled_and_stratified_val_split") class MODNetModel: @@ -413,17 +413,11 @@ def fit( val_y.append(y_inner) validation_data = (val_x, val_y) elif val_fraction > 0: - str_col = ( - [a_idx for a_idx, a in enumerate(y) if not isinstance(a, np.float64)][0] - if max(self.num_classes.values()) >= 2 - else None - ) x, y, validation_data = generate_shuffled_and_stratified_val_data( x=x, y=y, val_fraction=val_fraction, classification=max(self.num_classes.values()) >= 2, - str_col=str_col, ) else: validation_data = None @@ -587,7 +581,11 @@ def fit_preset( ) if not nested: splits = [ - train_test_split(range(len(data.df_featurized)), test_size=val_fraction) + generate_shuffled_and_stratified_val_split( + y=data.df_targets.values, + val_fraction=val_fraction, + classification=classification, + ) ] n_splits = 1 else: @@ -1546,32 +1544,42 @@ def map_validate_model(kwargs): return validate_model(**kwargs) -def generate_shuffled_and_stratified_val_data( - x: np.ndarray, - y: list, - val_fraction: float, - classification: bool, - str_col: int | None, +def generate_shuffled_and_stratified_val_split( + y: list | np.ndarray, val_fraction: float, classification: bool ): """ - Generate validation data that is shuffled and, if classification, stratified. + Generate train validation split that is shuffled, reproducible and, if classification, stratified. """ if classification: - if isinstance(y[str_col][0], list) or isinstance(y[str_col][0], np.ndarray): - ycv = np.argmax(y[str_col], axis=1) + if isinstance(y[0][0], list) or isinstance(y[0][0], np.ndarray): + ycv = np.argmax(y[0], axis=1) else: - ycv = y[str_col] - train_idx, val_idx = train_test_split( - range(len(x)), + ycv = y[0] + return train_test_split( + range(len(y[0])), test_size=val_fraction, random_state=42, shuffle=True, stratify=ycv, ) else: - train_idx, val_idx = train_test_split( - range(len(x)), test_size=val_fraction, random_state=42, shuffle=True + return train_test_split( + range(len(y[0])), test_size=val_fraction, random_state=42, shuffle=True ) + + +def generate_shuffled_and_stratified_val_data( + x: np.ndarray, + y: list, + val_fraction: float, + classification: bool, +): + """ + Generate train and validation data that is shuffled, reproducible and, if classification, stratified. + """ + train_idx, val_idx = generate_shuffled_and_stratified_val_split( + y=y, val_fraction=val_fraction, classification=classification + ) return ( x[train_idx], [t[train_idx] for t in y], From 2c21ce288f38edc25f69f0a85627aca53e3fbfaa Mon Sep 17 00:00:00 2001 From: kueltzen Date: Mon, 28 Oct 2024 07:52:04 +0100 Subject: [PATCH 4/8] Added custom validation data generation to DeprecaedMODNetModel. --- modnet/models/vanilla.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/modnet/models/vanilla.py b/modnet/models/vanilla.py index 140c9a84..6f4c66ca 100644 --- a/modnet/models/vanilla.py +++ b/modnet/models/vanilla.py @@ -1305,6 +1305,13 @@ def fit( ) val_y.append(y_inner) validation_data = (val_x, val_y) + elif val_fraction > 0: + x, y, validation_data = generate_shuffled_and_stratified_val_data( + x=x, + y=y, + val_fraction=val_fraction, + classification=max(self.num_classes.values()) >= 2, + ) else: validation_data = None @@ -1315,7 +1322,7 @@ def fit( # Optionally set up print callback if verbose: - if val_fraction > 0 or validation_data: + if validation_data: if self._multi_target and val_key is not None: val_metric_key = f"val_{val_key}_mae" else: From 2f9fcd634ae503d63a7267d8b8c5eef67a20dde8 Mon Sep 17 00:00:00 2001 From: kueltzen Date: Mon, 28 Oct 2024 08:47:21 +0100 Subject: [PATCH 5/8] Added stratification column index. --- modnet/hyper_opt/fit_genetic.py | 10 ++++++++++ modnet/models/vanilla.py | 21 ++++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/modnet/hyper_opt/fit_genetic.py b/modnet/hyper_opt/fit_genetic.py index eef82d2c..a1713ac4 100644 --- a/modnet/hyper_opt/fit_genetic.py +++ b/modnet/hyper_opt/fit_genetic.py @@ -458,11 +458,21 @@ def function_fitness( classification=max(self.num_classes.values()) >= 2, ) if not nested: + str_col = ( + [ + col_idx + for col_idx, col in enumerate(self.train_data.df_targets.columns) + if self.num_classes[col] >= 2 + ][0] + if max(self.num_classes.values()) >= 2 + else None + ) # TODO different nestings of targets # TODO not all properties of df_targets may be learned! splits = [ generate_shuffled_and_stratified_val_split( y=self.train_data.df_targets.values, val_fraction=val_fraction, classification=max(self.num_classes.values()) >= 2, + str_col=str_col, ) ] n_splits = 1 diff --git a/modnet/models/vanilla.py b/modnet/models/vanilla.py index 6f4c66ca..4a7a1cba 100644 --- a/modnet/models/vanilla.py +++ b/modnet/models/vanilla.py @@ -413,11 +413,21 @@ def fit( val_y.append(y_inner) validation_data = (val_x, val_y) elif val_fraction > 0: + str_col = ( + [ + prop_idx + for prop_idx, prop in enumerate(self.targets_groups) + if self.num_classes[prop[0]] >= 2 + ][0] + if max(self.num_classes.values()) >= 2 + else None + ) x, y, validation_data = generate_shuffled_and_stratified_val_data( x=x, y=y, val_fraction=val_fraction, classification=max(self.num_classes.values()) >= 2, + str_col=str_col, ) else: validation_data = None @@ -1552,18 +1562,18 @@ def map_validate_model(kwargs): def generate_shuffled_and_stratified_val_split( - y: list | np.ndarray, val_fraction: float, classification: bool + y: list | np.ndarray, val_fraction: float, classification: bool, str_col: int | None ): """ Generate train validation split that is shuffled, reproducible and, if classification, stratified. """ if classification: - if isinstance(y[0][0], list) or isinstance(y[0][0], np.ndarray): + if isinstance(y[str_col][0], list) or isinstance(y[str_col][0], np.ndarray): ycv = np.argmax(y[0], axis=1) else: - ycv = y[0] + ycv = y[str_col] return train_test_split( - range(len(y[0])), + range(len(y[str_col])), test_size=val_fraction, random_state=42, shuffle=True, @@ -1580,12 +1590,13 @@ def generate_shuffled_and_stratified_val_data( y: list, val_fraction: float, classification: bool, + str_col: int | None, ): """ Generate train and validation data that is shuffled, reproducible and, if classification, stratified. """ train_idx, val_idx = generate_shuffled_and_stratified_val_split( - y=y, val_fraction=val_fraction, classification=classification + y=y, val_fraction=val_fraction, classification=classification, str_col=str_col ) return ( x[train_idx], From 61a8821f76b9d71950693ed4e441c6379c0a21cd Mon Sep 17 00:00:00 2001 From: kueltzen Date: Mon, 18 Nov 2024 13:56:59 +0100 Subject: [PATCH 6/8] Removed determination of target for stratification, using always first target instead. --- modnet/hyper_opt/fit_genetic.py | 10 ---------- modnet/models/vanilla.py | 21 +++++---------------- 2 files changed, 5 insertions(+), 26 deletions(-) diff --git a/modnet/hyper_opt/fit_genetic.py b/modnet/hyper_opt/fit_genetic.py index e6624807..2913798f 100644 --- a/modnet/hyper_opt/fit_genetic.py +++ b/modnet/hyper_opt/fit_genetic.py @@ -459,21 +459,11 @@ def function_fitness( classification=max(self.num_classes.values()) >= 2, ) if not nested: - str_col = ( - [ - col_idx - for col_idx, col in enumerate(self.train_data.df_targets.columns) - if self.num_classes[col] >= 2 - ][0] - if max(self.num_classes.values()) >= 2 - else None - ) # TODO different nestings of targets # TODO not all properties of df_targets may be learned! splits = [ generate_shuffled_and_stratified_val_split( y=self.train_data.df_targets.values, val_fraction=val_fraction, classification=max(self.num_classes.values()) >= 2, - str_col=str_col, ) ] n_splits = 1 diff --git a/modnet/models/vanilla.py b/modnet/models/vanilla.py index 4a7a1cba..6f4c66ca 100644 --- a/modnet/models/vanilla.py +++ b/modnet/models/vanilla.py @@ -413,21 +413,11 @@ def fit( val_y.append(y_inner) validation_data = (val_x, val_y) elif val_fraction > 0: - str_col = ( - [ - prop_idx - for prop_idx, prop in enumerate(self.targets_groups) - if self.num_classes[prop[0]] >= 2 - ][0] - if max(self.num_classes.values()) >= 2 - else None - ) x, y, validation_data = generate_shuffled_and_stratified_val_data( x=x, y=y, val_fraction=val_fraction, classification=max(self.num_classes.values()) >= 2, - str_col=str_col, ) else: validation_data = None @@ -1562,18 +1552,18 @@ def map_validate_model(kwargs): def generate_shuffled_and_stratified_val_split( - y: list | np.ndarray, val_fraction: float, classification: bool, str_col: int | None + y: list | np.ndarray, val_fraction: float, classification: bool ): """ Generate train validation split that is shuffled, reproducible and, if classification, stratified. """ if classification: - if isinstance(y[str_col][0], list) or isinstance(y[str_col][0], np.ndarray): + if isinstance(y[0][0], list) or isinstance(y[0][0], np.ndarray): ycv = np.argmax(y[0], axis=1) else: - ycv = y[str_col] + ycv = y[0] return train_test_split( - range(len(y[str_col])), + range(len(y[0])), test_size=val_fraction, random_state=42, shuffle=True, @@ -1590,13 +1580,12 @@ def generate_shuffled_and_stratified_val_data( y: list, val_fraction: float, classification: bool, - str_col: int | None, ): """ Generate train and validation data that is shuffled, reproducible and, if classification, stratified. """ train_idx, val_idx = generate_shuffled_and_stratified_val_split( - y=y, val_fraction=val_fraction, classification=classification, str_col=str_col + y=y, val_fraction=val_fraction, classification=classification ) return ( x[train_idx], From eaa5f6e33eb6809e9d75c992ddd8178005e9e65f Mon Sep 17 00:00:00 2001 From: kueltzen Date: Mon, 18 Nov 2024 14:19:52 +0100 Subject: [PATCH 7/8] Type changes --- modnet/models/vanilla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modnet/models/vanilla.py b/modnet/models/vanilla.py index 6f4c66ca..6ea0beb7 100644 --- a/modnet/models/vanilla.py +++ b/modnet/models/vanilla.py @@ -1552,7 +1552,7 @@ def map_validate_model(kwargs): def generate_shuffled_and_stratified_val_split( - y: list | np.ndarray, val_fraction: float, classification: bool + y: np.ndarray, val_fraction: float, classification: bool ): """ Generate train validation split that is shuffled, reproducible and, if classification, stratified. @@ -1585,7 +1585,7 @@ def generate_shuffled_and_stratified_val_data( Generate train and validation data that is shuffled, reproducible and, if classification, stratified. """ train_idx, val_idx = generate_shuffled_and_stratified_val_split( - y=y, val_fraction=val_fraction, classification=classification + y=np.array(y), val_fraction=val_fraction, classification=classification ) return ( x[train_idx], From 0b9cbd31b662d4e9ea9697cf7393c617fe694599 Mon Sep 17 00:00:00 2001 From: kueltzen Date: Wed, 18 Dec 2024 17:34:52 +0100 Subject: [PATCH 8/8] Moved generate_shuffled_and_stratified_val_split to modnet.utils, corrected axes of y in generate_shuffled_and_stratified_val_split and generate_shuffled_and_stratified_val_data --- modnet/hyper_opt/fit_genetic.py | 3 +-- modnet/models/__init__.py | 3 +-- modnet/models/ensemble.py | 3 +-- modnet/models/vanilla.py | 36 +++++++-------------------------- modnet/utils.py | 28 +++++++++++++++++++++++++ 5 files changed, 38 insertions(+), 35 deletions(-) diff --git a/modnet/hyper_opt/fit_genetic.py b/modnet/hyper_opt/fit_genetic.py index 2913798f..cfa98796 100644 --- a/modnet/hyper_opt/fit_genetic.py +++ b/modnet/hyper_opt/fit_genetic.py @@ -8,9 +8,8 @@ from modnet.models import ( MODNetModel, EnsembleMODNetModel, - generate_shuffled_and_stratified_val_split, ) -from modnet.utils import LOG +from modnet.utils import LOG, generate_shuffled_and_stratified_val_split import multiprocessing import tqdm diff --git a/modnet/models/__init__.py b/modnet/models/__init__.py index 40580a26..f5fdc14f 100644 --- a/modnet/models/__init__.py +++ b/modnet/models/__init__.py @@ -1,6 +1,6 @@ import warnings -from .vanilla import MODNetModel, generate_shuffled_and_stratified_val_split +from .vanilla import MODNetModel try: from .bayesian import BayesianMODNetModel @@ -18,5 +18,4 @@ "MODNetModel", "BayesianMODNetModel", "EnsembleMODNetModel", - "generate_shuffled_and_stratified_val_split", ) diff --git a/modnet/models/ensemble.py b/modnet/models/ensemble.py index 06ee7972..51bcd0b0 100644 --- a/modnet/models/ensemble.py +++ b/modnet/models/ensemble.py @@ -16,10 +16,9 @@ from modnet.models.vanilla import ( MODNetModel, - generate_shuffled_and_stratified_val_split, ) from modnet import __version__ -from modnet.utils import LOG +from modnet.utils import LOG, generate_shuffled_and_stratified_val_split from modnet.preprocessing import MODData __all__ = ("EnsembleMODNetModel",) diff --git a/modnet/models/vanilla.py b/modnet/models/vanilla.py index 6ea0beb7..50521b6d 100644 --- a/modnet/models/vanilla.py +++ b/modnet/models/vanilla.py @@ -13,19 +13,18 @@ import numpy as np import warnings from sklearn.preprocessing import StandardScaler, MinMaxScaler -from sklearn.model_selection import train_test_split from sklearn.metrics import mean_absolute_error, mean_squared_error, roc_auc_score from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline import tensorflow as tf from modnet.preprocessing import MODData -from modnet.utils import LOG +from modnet.utils import LOG, generate_shuffled_and_stratified_val_split from modnet import __version__ import tqdm -__all__ = ("MODNetModel", "generate_shuffled_and_stratified_val_split") +__all__ = "MODNetModel" class MODNetModel: @@ -1551,41 +1550,20 @@ def map_validate_model(kwargs): return validate_model(**kwargs) -def generate_shuffled_and_stratified_val_split( - y: np.ndarray, val_fraction: float, classification: bool -): - """ - Generate train validation split that is shuffled, reproducible and, if classification, stratified. - """ - if classification: - if isinstance(y[0][0], list) or isinstance(y[0][0], np.ndarray): - ycv = np.argmax(y[0], axis=1) - else: - ycv = y[0] - return train_test_split( - range(len(y[0])), - test_size=val_fraction, - random_state=42, - shuffle=True, - stratify=ycv, - ) - else: - return train_test_split( - range(len(y[0])), test_size=val_fraction, random_state=42, shuffle=True - ) - - def generate_shuffled_and_stratified_val_data( x: np.ndarray, - y: list, + y: list[np.ndarray], val_fraction: float, classification: bool, ): """ Generate train and validation data that is shuffled, reproducible and, if classification, stratified. + Please note for classification tasks that stratification is performed on first target. + y: list of 2D array with combined dimensions of (n_targets, n_samples, 1 or n_classes) """ + y_split = np.array([y[0]]).swapaxes(1, 0) train_idx, val_idx = generate_shuffled_and_stratified_val_split( - y=np.array(y), val_fraction=val_fraction, classification=classification + y=y_split, val_fraction=val_fraction, classification=classification ) return ( x[train_idx], diff --git a/modnet/utils.py b/modnet/utils.py index 69c00183..af002ab7 100644 --- a/modnet/utils.py +++ b/modnet/utils.py @@ -1,4 +1,6 @@ import logging +import numpy as np +from sklearn.model_selection import train_test_split import sys LOG = logging.getLogger("modnet") @@ -28,3 +30,29 @@ def get_hash_of_file(fname, algo="sha512"): fb = f.read(block_size) return _hash.hexdigest() + + +def generate_shuffled_and_stratified_val_split( + y: np.ndarray, val_fraction: float, classification: bool +): + """ + Generate train validation split that is shuffled, reproducible and, if classification, stratified. + Please note for classification tasks that stratification is performed on first target. + y: np.ndarray (n_samples, n_targets) or (n_samples, n_targets, n_classes) + """ + if classification: + if isinstance(y[0][0], list) or isinstance(y[0][0], np.ndarray): + ycv = np.argmax(y[:, 0], axis=1) + else: + ycv = y[:, 0] + return train_test_split( + range(len(y)), + test_size=val_fraction, + random_state=42, + shuffle=True, + stratify=ycv, + ) + else: + return train_test_split( + range(len(y)), test_size=val_fraction, random_state=42, shuffle=True + )