Skip to content

Commit 77e96e8

Browse files
committed
add test and validation split parameters
1 parent 2ca72e9 commit 77e96e8

File tree

1 file changed

+15
-9
lines changed
  • chebai/preprocessing/datasets

1 file changed

+15
-9
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,8 @@ class XYBaseDataModule(LightningDataModule):
2929
3030
Args:
3131
batch_size (int): The batch size for data loading. Default is 1.
32-
train_split (float): The ratio of training data to total data and of test data to (validation + test) data. Default is 0.85.
32+
test_split (float): The ratio of test data to total data. Default is 0.1.
33+
validation_split (float): The ratio of validation data to total data. Default is 0.05.
3334
reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None.
3435
prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test".
3536
data_limit (Optional[int]): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None.
@@ -45,7 +46,8 @@ class XYBaseDataModule(LightningDataModule):
4546
Attributes:
4647
READER (DataReader): The data reader class to use.
4748
reader (DataReader): An instance of the data reader class.
48-
train_split (float): The ratio of training data to total data.
49+
test_split (float): The ratio of test data to total data.
50+
validation_split (float): The ratio of validation data to total data.
4951
batch_size (int): The batch size for data loading.
5052
prediction_kind (str): The kind of prediction to be performed.
5153
data_limit (Optional[int]): The maximum number of data samples to load.
@@ -68,7 +70,8 @@ class XYBaseDataModule(LightningDataModule):
6870
def __init__(
6971
self,
7072
batch_size: int = 1,
71-
train_split: float = 0.85,
73+
test_split: Optional[float] = 0.1,
74+
validation_split: Optional[float] = 0.05,
7275
reader_kwargs: Optional[dict] = None,
7376
prediction_kind: str = "test",
7477
data_limit: Optional[int] = None,
@@ -86,7 +89,9 @@ def __init__(
8689
if reader_kwargs is None:
8790
reader_kwargs = dict()
8891
self.reader = self.READER(**reader_kwargs)
89-
self.train_split = train_split
92+
self.test_split = test_split
93+
self.validation_split = validation_split
94+
9095
self.batch_size = batch_size
9196
self.prediction_kind = prediction_kind
9297
self.data_limit = data_limit
@@ -1083,16 +1088,17 @@ def get_train_val_splits_given_test(
10831088

10841089
return folds
10851090

1086-
# scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split)
1087-
test_size = ((1 - self.train_split) ** 2) / self.train_split
1088-
10891091
if len(labels_list_trainval[0]) > 1:
10901092
splitter = MultilabelStratifiedShuffleSplit(
1091-
n_splits=1, test_size=test_size, random_state=seed
1093+
n_splits=1,
1094+
test_size=self.validation_split / (1 - self.test_split),
1095+
random_state=seed,
10921096
)
10931097
else:
10941098
splitter = StratifiedShuffleSplit(
1095-
n_splits=1, test_size=test_size, random_state=seed
1099+
n_splits=1,
1100+
test_size=self.validation_split / (1 - self.test_split),
1101+
random_state=seed,
10961102
)
10971103

10981104
train_indices, validation_indices = next(

0 commit comments

Comments
 (0)