14 changes: 9 additions & 5 deletions modnet/hyper_opt/fit_genetic.py
@@ -4,10 +4,12 @@
from typing import List, Optional, Dict, Union, Callable
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from modnet.preprocessing import MODData
from modnet.models import MODNetModel, EnsembleMODNetModel
from modnet.utils import LOG
from modnet.models import (
MODNetModel,
EnsembleMODNetModel,
)
from modnet.utils import LOG, generate_shuffled_and_stratified_val_split
import multiprocessing
import tqdm

@@ -457,8 +459,10 @@ def function_fitness(
)
if not nested:
splits = [
train_test_split(
range(len(self.train_data.df_featurized)), test_size=val_fraction
generate_shuffled_and_stratified_val_split(
y=self.train_data.df_targets.values,
val_fraction=val_fraction,
classification=max(self.num_classes.values()) >= 2,
)
]
n_splits = 1
6 changes: 5 additions & 1 deletion modnet/models/__init__.py
@@ -14,4 +14,8 @@

from .ensemble import EnsembleMODNetModel

__all__ = ("MODNetModel", "BayesianMODNetModel", "EnsembleMODNetModel")
__all__ = (
"MODNetModel",
"BayesianMODNetModel",
"EnsembleMODNetModel",
)
13 changes: 9 additions & 4 deletions modnet/models/ensemble.py
@@ -13,11 +13,12 @@
import tensorflow as tf

from sklearn.utils import resample
from sklearn.model_selection import train_test_split

from modnet.models.vanilla import MODNetModel
from modnet.models.vanilla import (
MODNetModel,
)
from modnet import __version__
from modnet.utils import LOG
from modnet.utils import LOG, generate_shuffled_and_stratified_val_split
from modnet.preprocessing import MODData

__all__ = ("EnsembleMODNetModel",)
@@ -323,7 +324,11 @@ def fit_preset(
)
if not nested:
splits = [
train_test_split(range(len(data.df_featurized)), test_size=val_fraction)
generate_shuffled_and_stratified_val_split(
y=data.df_targets.values,
val_fraction=val_fraction,
classification=classification,
)
]
n_splits = 1
else:
54 changes: 45 additions & 9 deletions modnet/models/vanilla.py
@@ -13,19 +13,18 @@
import numpy as np
import warnings
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, roc_auc_score
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
import tensorflow as tf

from modnet.preprocessing import MODData
from modnet.utils import LOG
from modnet.utils import LOG, generate_shuffled_and_stratified_val_split
from modnet import __version__

import tqdm

__all__ = ("MODNetModel",)
__all__ = "MODNetModel"


class MODNetModel:
@@ -401,20 +400,24 @@ def fit(
targ = prop[0]
if self.multi_label:
y_inner = np.stack(val_data.df_targets[targ].values)
if loss is None:
loss = "binary_crossentropy"
else:
y_inner = tf.keras.utils.to_categorical(
val_data.df_targets[targ].values,
num_classes=self.num_classes[targ],
)
loss = "categorical_crossentropy"
else:
y_inner = val_data.df_targets[prop].values.astype(
np.float64, copy=False
)
val_y.append(y_inner)
validation_data = (val_x, val_y)
elif val_fraction > 0:
x, y, validation_data = generate_shuffled_and_stratified_val_data(
x=x,
y=y,
val_fraction=val_fraction,
classification=max(self.num_classes.values()) >= 2,
)
else:
validation_data = None

@@ -427,7 +430,7 @@

# Optionally set up print callback
if verbose:
if val_fraction > 0 or validation_data:
if validation_data:
if self._multi_target and val_key is not None:
val_metric_key = f"val_{val_key}_mae"
else:
@@ -577,7 +580,11 @@ def fit_preset(
)
if not nested:
splits = [
train_test_split(range(len(data.df_featurized)), test_size=val_fraction)
generate_shuffled_and_stratified_val_split(
y=data.df_targets.values,
val_fraction=val_fraction,
classification=classification,
)
]
n_splits = 1
else:
@@ -1297,6 +1304,13 @@ def fit(
)
val_y.append(y_inner)
validation_data = (val_x, val_y)
elif val_fraction > 0:
x, y, validation_data = generate_shuffled_and_stratified_val_data(
x=x,
y=y,
val_fraction=val_fraction,
classification=max(self.num_classes.values()) >= 2,
)
else:
validation_data = None

@@ -1307,7 +1321,7 @@

# Optionally set up print callback
if verbose:
if val_fraction > 0 or validation_data:
if validation_data:
if self._multi_target and val_key is not None:
val_metric_key = f"val_{val_key}_mae"
else:
@@ -1537,3 +1551,25 @@ def validate_model(

def map_validate_model(kwargs):
return validate_model(**kwargs)


def generate_shuffled_and_stratified_val_data(
x: np.ndarray,
y: list[np.ndarray],
val_fraction: float,
classification: bool,
):
"""
Generate train and validation data that are shuffled, reproducible and, if classification, stratified.
Please note that for classification tasks, stratification is performed on the first target.
y: list of 2D arrays with combined dimensions (n_targets, n_samples, 1 or n_classes)
"""
y_split = np.array([y[0]]).swapaxes(1, 0)
train_idx, val_idx = generate_shuffled_and_stratified_val_split(
y=y_split, val_fraction=val_fraction, classification=classification
)
return (
x[train_idx],
[t[train_idx] for t in y],
(x[val_idx], [t[val_idx] for t in y]),
)

Owner:
This doesn't seem correct? y[train_idx] should work, right? It would pick the wrong columns now.

Contributor Author:
The y of MODNetModel.fit() and of generate_shuffled_and_stratified_val_data has the first two dimensions (n_target_groups, n_samples), while the y of generate_shuffled_and_stratified_val_split that is used elsewhere (data.df_targets.values) has the first two dimensions (n_samples, n_targets).
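To make the layout from the thread concrete, here is a minimal, self-contained sketch (the shapes, values, and index lists are invented for illustration) of why the row indices must be applied to each per-target-group array rather than to y itself:

import numpy as np

# Hypothetical y as passed to MODNetModel.fit(): a Python list with one array
# per target group, each of shape (n_samples, 1) or (n_samples, n_classes).
y = [
    np.arange(6, dtype=float).reshape(6, 1),   # regression target, (6, 1)
    np.eye(3)[[0, 1, 2, 0, 1, 2]],             # one-hot classification target, (6, 3)
]
train_idx, val_idx = [0, 2, 3, 5], [1, 4]

# y[train_idx] fails: a Python list cannot be indexed with a list of indices.
# Indexing each per-target array keeps the (n_target_groups, n_samples, ...) layout.
y_train = [t[train_idx] for t in y]
y_val = [t[val_idx] for t in y]

assert y_train[0].shape == (4, 1)
assert y_val[1].shape == (2, 3)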
6 changes: 3 additions & 3 deletions modnet/tests/test_model.py
@@ -20,7 +20,7 @@ def test_train_small_model_single_target(subset_moddata, tf_session):
n_feat=10,
)

model.fit(data, epochs=2)
model.fit(data, epochs=2, val_fraction=0.15)
model.predict(data)
assert not np.isnan(model.evaluate(data))

@@ -50,7 +50,7 @@ def is_metal(egap):
n_feat=10,
)

model.fit(data, epochs=2)
model.fit(data, epochs=2, val_fraction=0.15)
assert not np.isnan(model.evaluate(data))


@@ -71,7 +71,7 @@ def test_train_small_model_multi_target(subset_moddata, tf_session):
n_feat=10,
)

model.fit(data, epochs=2)
model.fit(data, epochs=2, val_fraction=0.15)
model.predict(data)
assert not np.isnan(model.evaluate(data))

28 changes: 28 additions & 0 deletions modnet/utils.py
@@ -1,4 +1,6 @@
import logging
import numpy as np
from sklearn.model_selection import train_test_split
import sys

LOG = logging.getLogger("modnet")
@@ -28,3 +30,29 @@ def get_hash_of_file(fname, algo="sha512"):
fb = f.read(block_size)

return _hash.hexdigest()


def generate_shuffled_and_stratified_val_split(
y: np.ndarray, val_fraction: float, classification: bool
):
"""
Generate a train/validation split that is shuffled, reproducible and, if classification, stratified.
Please note that for classification tasks, stratification is performed on the first target.
y: np.ndarray of shape (n_samples, n_targets) or (n_samples, n_targets, n_classes)
"""
if classification:
if isinstance(y[0][0], (list, np.ndarray)):
ycv = np.argmax(y[:, 0], axis=1)
else:
ycv = y[:, 0]
return train_test_split(
range(len(y)),
test_size=val_fraction,
random_state=42,
shuffle=True,
stratify=ycv,
)
else:
return train_test_split(
range(len(y)), test_size=val_fraction, random_state=42, shuffle=True
)