2222import pandas as pd
2323import requests
2424import torch
25- import yaml
2625from iterstrat .ml_stratifiers import (
2726 MultilabelStratifiedKFold ,
2827 MultilabelStratifiedShuffleSplit ,
2928)
30- from torch .utils .data import DataLoader
3129
3230from chebai .preprocessing import reader as dr
3331from chebai .preprocessing .datasets .base import XYBaseDataModule
@@ -363,36 +361,51 @@ def setup_processed(self):
363361 self ._chebi_version_train_obj .setup ()
364362
365363 def get_test_split (self , df : pd .DataFrame , seed : int = None ):
364+ """
365+ Split the input DataFrame into training and testing sets based on multilabel stratified sampling.
366+
367+ This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels
368+ in the training and testing sets is approximately the same. The split is based on the "labels" column
369+ in the DataFrame.
370+
371+ Parameters:
372+ ----------
373+ df : pd.DataFrame
374+ The input DataFrame containing the data to be split. It must contain a column named "labels"
375+ with the multilabel data.
376+
377+ seed : int, optional
378+ The random seed to be used for reproducibility. Default is None.
379+
380+ Returns:
381+ -------
382+ df_train : pd.DataFrame
383+ The training set split from the input DataFrame.
384+
385+ df_test : pd.DataFrame
386+ The testing set split from the input DataFrame.
387+ """
366388 print ("\n Get test data split" )
367389
368- # df_list = df.values.tolist()
369- # df_list = [row[1] for row in df_list]
370390 labels_list = df ["labels" ].tolist ()
371391
372392 test_size = 1 - self .train_split - (1 - self .train_split ) ** 2
373393 msss = MultilabelStratifiedShuffleSplit (
374394 n_splits = 1 , test_size = test_size , random_state = seed
375395 )
376396
377- train_split = []
378- test_split = []
379- for train_split , test_split in msss .split (
380- labels_list ,
381- labels_list ,
382- ):
383- train_split = train_split
384- test_split = test_split
385- break
386- df_train = df .iloc [train_split ]
387- df_test = df .iloc [test_split ]
397+ train_indices , test_indices = next (msss .split (labels_list , labels_list ))
398+
399+ df_train = df .iloc [train_indices ]
400+ df_test = df .iloc [test_indices ]
388401 return df_train , df_test
389402
390403 def get_train_val_splits_given_test (
391404 self , df : pd .DataFrame , test_df : pd .DataFrame , seed : int = None
392405 ):
393406 """
394407 Split the dataset into train and validation sets, given a test set.
395- Use test set (e.g., loaded from another chebi version or generated in get_test_split), avoid overlap
408+ Use test set (e.g., loaded from another chebi version or generated in get_test_split), to avoid overlap
396409
397410 Args:
398411 df (pd.DataFrame): The original dataset.
@@ -404,12 +417,11 @@ def get_train_val_splits_given_test(
404417 """
405418 print (f"Split dataset into train / val with given test set" )
406419
407- df_trainval = df
408420 test_ids = test_df ["ident" ].tolist ()
409- mask = [ trainval_id not in test_ids for trainval_id in df_trainval [ "ident" ]]
410- df_trainval = df_trainval [mask ]
411- # df_trainval_list = df_trainval.values.tolist()
412- # df_trainval_list = [row[3:] for row in df_trainval_list ]
421+ # ---- list comprehension degrades performance, dataframe operations are faster
422+ # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"] ]
423+ # df_trainval = df_trainval[mask]
424+ df_trainval = df [ ~ df [ "ident" ]. isin ( test_ids ) ]
413425 labels_list_trainval = df_trainval ["labels" ].tolist ()
414426
415427 if self .use_inner_cross_validation :
@@ -437,16 +449,13 @@ def get_train_val_splits_given_test(
437449 msss = MultilabelStratifiedShuffleSplit (
438450 n_splits = 1 , test_size = test_size , random_state = seed
439451 )
440- train_split = []
441- validation_split = []
442- for train_split , validation_split in msss .split (
443- labels_list_trainval , labels_list_trainval
444- ):
445- train_split = train_split
446- validation_split = validation_split
447452
448- df_validation = df_trainval .iloc [validation_split ]
449- df_train = df_trainval .iloc [train_split ]
453+ train_indices , validation_indices = next (
454+ msss .split (labels_list_trainval , labels_list_trainval )
455+ )
456+
457+ df_validation = df_trainval .iloc [validation_indices ]
458+ df_train = df_trainval .iloc [train_indices ]
450459 return df_train , df_validation
451460
452461 @property
@@ -632,7 +641,7 @@ def _get_dynamic_splits(self):
632641 data_chebi_version = torch .load (os .path .join (self .processed_dir , filename ))
633642 except FileNotFoundError :
634643 raise FileNotFoundError (
635- f"File { filename } doesn't exists. "
644+ f"File data.pt doesn't exists. "
636645 f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
637646 )
638647
@@ -654,7 +663,7 @@ def _get_dynamic_splits(self):
654663 )
655664 except FileNotFoundError :
656665 raise FileNotFoundError (
657- f"File { filename_train } doesn't exists related to chebi_version_train { self .chebi_version_train } ."
666+ f"File data.pt doesn't exists related to chebi_version_train { self .chebi_version_train } ."
658667 f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
659668 )
660669
0 commit comments