Skip to content

Commit 7668858

Browse files
committed
fix test splits
1 parent 77e96e8 commit 7668858

File tree

1 file changed

+2
-4
lines changed
  • chebai/preprocessing/datasets

1 file changed

+2
-4
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,15 +1027,13 @@ def get_test_split(
10271027

10281028
labels_list = df["labels"].tolist()
10291029

1030-
test_size = 1 - self.train_split - (1 - self.train_split) ** 2
1031-
10321030
if len(labels_list[0]) > 1:
10331031
splitter = MultilabelStratifiedShuffleSplit(
1034-
n_splits=1, test_size=test_size, random_state=seed
1032+
n_splits=1, test_size=self.test_split, random_state=seed
10351033
)
10361034
else:
10371035
splitter = StratifiedShuffleSplit(
1038-
n_splits=1, test_size=test_size, random_state=seed
1036+
n_splits=1, test_size=self.test_split, random_state=seed
10391037
)
10401038

10411039
train_indices, test_indices = next(splitter.split(labels_list, labels_list))

0 commit comments

Comments
 (0)