@@ -29,7 +29,8 @@ class XYBaseDataModule(LightningDataModule):
 
     Args:
         batch_size (int): The batch size for data loading. Default is 1.
-        train_split (float): The ratio of training data to total data and of test data to (validation + test) data. Default is 0.85.
+        test_split (float): The ratio of test data to total data. Default is 0.1.
+        validation_split (float): The ratio of validation data to total data. Default is 0.05.
         reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None.
         prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test".
         data_limit (Optional[int]): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None.
@@ -45,7 +46,8 @@ class XYBaseDataModule(LightningDataModule):
     Attributes:
         READER (DataReader): The data reader class to use.
         reader (DataReader): An instance of the data reader class.
-        train_split (float): The ratio of training data to total data.
+        test_split (float): The ratio of test data to total data.
+        validation_split (float): The ratio of validation data to total data.
         batch_size (int): The batch size for data loading.
         prediction_kind (str): The kind of prediction to be performed.
         data_limit (Optional[int]): The maximum number of data samples to load.
@@ -68,7 +70,8 @@ class XYBaseDataModule(LightningDataModule):
     def __init__(
         self,
         batch_size: int = 1,
-        train_split: float = 0.85,
+        test_split: Optional[float] = 0.1,
+        validation_split: Optional[float] = 0.05,
         reader_kwargs: Optional[dict] = None,
         prediction_kind: str = "test",
         data_limit: Optional[int] = None,
@@ -86,7 +89,9 @@ def __init__(
         if reader_kwargs is None:
             reader_kwargs = dict()
         self.reader = self.READER(**reader_kwargs)
-        self.train_split = train_split
+        self.test_split = test_split
+        self.validation_split = validation_split
+
         self.batch_size = batch_size
         self.prediction_kind = prediction_kind
         self.data_limit = data_limit
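With the new interface the training fraction is no longer passed explicitly: it is whatever remains after `test_split` and `validation_split`. A minimal sketch of that arithmetic with the defaults shown in this diff (plain floats only, not the actual data module):

```python
# Defaults from the new signature above.
test_split, validation_split = 0.1, 0.05

# Implied train fraction; matches the old train_split default of 0.85
# (up to floating-point rounding).
train_fraction = 1 - test_split - validation_split
print(train_fraction)  # ~0.85
```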
@@ -1083,16 +1088,17 @@ def get_train_val_splits_given_test(
 
             return folds
 
-        # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
-        test_size = ((1 - self.train_split) ** 2) / self.train_split
-
         if len(labels_list_trainval[0]) > 1:
             splitter = MultilabelStratifiedShuffleSplit(
-                n_splits=1, test_size=test_size, random_state=seed
+                n_splits=1,
+                test_size=self.validation_split / (1 - self.test_split),
+                random_state=seed,
             )
         else:
             splitter = StratifiedShuffleSplit(
-                n_splits=1, test_size=test_size, random_state=seed
+                n_splits=1,
+                test_size=self.validation_split / (1 - self.test_split),
+                random_state=seed,
             )
 
         train_indices, validation_indices = next(
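The rescaling `validation_split / (1 - test_split)` accounts for the fact that the splitter only sees the train+validation pool left over after the test set has been removed, so the validation set still ends up being `validation_split` of the full dataset. A standalone sketch of the arithmetic with made-up data sizes and labels (not the module's actual data):

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

test_split, validation_split = 0.1, 0.05          # defaults from the diff
n_total = 1000
n_trainval = int(n_total * (1 - test_split))      # 900 samples remain after the test split

X = np.zeros((n_trainval, 1))                     # dummy features
y = np.array([0, 1] * (n_trainval // 2))          # dummy binary labels for stratification

splitter = StratifiedShuffleSplit(
    n_splits=1,
    test_size=validation_split / (1 - test_split),  # 0.05 / 0.9 ≈ 0.0556 of the pool
    random_state=0,
)
train_idx, val_idx = next(splitter.split(X, y))
print(len(val_idx) / n_total)                     # ≈ 0.05, i.e. validation_split of the full dataset
```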