Commit cd03023 (1 parent: 8c9dfe1)

minor changes in data split code

- removed list comprehension from the data split logic; used dataframe operations instead, as they are faster
- removed the loop over msss.split, as there is no need for it; used `next` instead

1 file changed (+41, -32 lines)


chebai/preprocessing/datasets/chebi.py

Lines changed: 41 additions & 32 deletions
@@ -22,12 +22,10 @@
 import pandas as pd
 import requests
 import torch
-import yaml
 from iterstrat.ml_stratifiers import (
     MultilabelStratifiedKFold,
     MultilabelStratifiedShuffleSplit,
 )
-from torch.utils.data import DataLoader
 
 from chebai.preprocessing import reader as dr
 from chebai.preprocessing.datasets.base import XYBaseDataModule
@@ -363,36 +361,51 @@ def setup_processed(self):
         self._chebi_version_train_obj.setup()
 
     def get_test_split(self, df: pd.DataFrame, seed: int = None):
+        """
+        Split the input DataFrame into training and testing sets based on multilabel stratified sampling.
+
+        This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels
+        in the training and testing sets is approximately the same. The split is based on the "labels" column
+        in the DataFrame.
+
+        Parameters:
+        ----------
+        df : pd.DataFrame
+            The input DataFrame containing the data to be split. It must contain a column named "labels"
+            with the multilabel data.
+
+        seed : int, optional
+            The random seed to be used for reproducibility. Default is None.
+
+        Returns:
+        -------
+        df_train : pd.DataFrame
+            The training set split from the input DataFrame.
+
+        df_test : pd.DataFrame
+            The testing set split from the input DataFrame.
+        """
         print("\nGet test data split")
 
-        # df_list = df.values.tolist()
-        # df_list = [row[1] for row in df_list]
         labels_list = df["labels"].tolist()
 
         test_size = 1 - self.train_split - (1 - self.train_split) ** 2
         msss = MultilabelStratifiedShuffleSplit(
             n_splits=1, test_size=test_size, random_state=seed
         )
 
-        train_split = []
-        test_split = []
-        for train_split, test_split in msss.split(
-            labels_list,
-            labels_list,
-        ):
-            train_split = train_split
-            test_split = test_split
-            break
-        df_train = df.iloc[train_split]
-        df_test = df.iloc[test_split]
+        train_indices, test_indices = next(msss.split(labels_list, labels_list))
+
+        df_train = df.iloc[train_indices]
+        df_test = df.iloc[test_indices]
         return df_train, df_test
 
     def get_train_val_splits_given_test(
         self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None
     ):
         """
         Split the dataset into train and validation sets, given a test set.
-        Use test set (e.g., loaded from another chebi version or generated in get_test_split), avoid overlap
+        Use test set (e.g., loaded from another chebi version or generated in get_test_split), to avoid overlap
 
         Args:
             df (pd.DataFrame): The original dataset.
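
Note on the `next` change in the hunk above: `MultilabelStratifiedShuffleSplit.split` returns a generator, and with `n_splits=1` it yields exactly one `(train, test)` pair of index arrays, so `next` replaces the old loop-and-break. A minimal sketch, not taken from the repository; the toy label matrix is made up for illustration:

```python
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

# Invented multilabel indicator matrix, purely for demonstration.
y = np.array([[1, 0], [1, 0], [0, 1], [0, 1], [1, 1], [1, 1], [1, 0], [0, 1]])

msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)

# split() is a generator; with n_splits=1 it produces a single
# (train_indices, test_indices) pair, which next() pulls out directly.
train_idx, test_idx = next(msss.split(y, y))
print(len(train_idx), len(test_idx))  # roughly a 75/25 split of the 8 samples
```

The removed loop captured the first pair and broke immediately, so the behaviour of `get_test_split` is unchanged; only the dead assignments and the `break` go away.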
@@ -404,12 +417,11 @@ def get_train_val_splits_given_test(
         """
         print(f"Split dataset into train / val with given test set")
 
-        df_trainval = df
         test_ids = test_df["ident"].tolist()
-        mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]]
-        df_trainval = df_trainval[mask]
-        # df_trainval_list = df_trainval.values.tolist()
-        # df_trainval_list = [row[3:] for row in df_trainval_list]
+        # ---- list comprehension degrades performance, dataframe operations are faster
+        # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]]
+        # df_trainval = df_trainval[mask]
+        df_trainval = df[~df["ident"].isin(test_ids)]
         labels_list_trainval = df_trainval["labels"].tolist()
 
         if self.use_inner_cross_validation:
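
Note on the `isin` change in the hunk above: a pandas boolean mask pushes the membership test into vectorized code instead of a per-row Python loop. A minimal sketch with made-up data (not from the repository), showing that both forms select the same rows:

```python
import pandas as pd

# Invented toy frame; the real one carries "ident" and "labels" columns among others.
df = pd.DataFrame({"ident": [10, 11, 12, 13, 14], "labels": [[1], [0], [1], [0], [1]]})
test_ids = [11, 14]

# Old approach: Python-level loop building a boolean list row by row.
mask = [ident not in test_ids for ident in df["ident"]]
df_loop = df[mask]

# New approach: single vectorized membership test on the column.
df_isin = df[~df["ident"].isin(test_ids)]

assert df_loop.equals(df_isin)  # identical result; idents 10, 12, 13 remain
```

Converting `test_ids` to a Python set would also avoid the per-element scan inside the loop, but `Series.isin` keeps the whole filter inside pandas and removes the Python-level iteration entirely.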
@@ -437,16 +449,13 @@
         msss = MultilabelStratifiedShuffleSplit(
             n_splits=1, test_size=test_size, random_state=seed
         )
-        train_split = []
-        validation_split = []
-        for train_split, validation_split in msss.split(
-            labels_list_trainval, labels_list_trainval
-        ):
-            train_split = train_split
-            validation_split = validation_split
 
-        df_validation = df_trainval.iloc[validation_split]
-        df_train = df_trainval.iloc[train_split]
+        train_indices, validation_indices = next(
+            msss.split(labels_list_trainval, labels_list_trainval)
+        )
+
+        df_validation = df_trainval.iloc[validation_indices]
+        df_train = df_trainval.iloc[train_indices]
         return df_train, df_validation
 
     @property
@@ -632,7 +641,7 @@ def _get_dynamic_splits(self):
             data_chebi_version = torch.load(os.path.join(self.processed_dir, filename))
         except FileNotFoundError:
             raise FileNotFoundError(
-                f"File {filename} doesn't exists. "
+                f"File data.pt doesn't exists. "
                 f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
             )
 
@@ -654,7 +663,7 @@ def _get_dynamic_splits(self):
             )
         except FileNotFoundError:
             raise FileNotFoundError(
-                f"File {filename_train} doesn't exists related to chebi_version_train {self.chebi_version_train}."
+                f"File data.pt doesn't exists related to chebi_version_train {self.chebi_version_train}."
                 f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
             )
 