We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 77e96e8 commit 7668858Copy full SHA for 7668858
chebai/preprocessing/datasets/base.py
@@ -1027,15 +1027,13 @@ def get_test_split(
1027
1028
labels_list = df["labels"].tolist()
1029
1030
- test_size = 1 - self.train_split - (1 - self.train_split) ** 2
1031
-
1032
if len(labels_list[0]) > 1:
1033
splitter = MultilabelStratifiedShuffleSplit(
1034
- n_splits=1, test_size=test_size, random_state=seed
+ n_splits=1, test_size=self.test_split, random_state=seed
1035
)
1036
else:
1037
splitter = StratifiedShuffleSplit(
1038
1039
1040
1041
train_indices, test_indices = next(splitter.split(labels_list, labels_list))
0 commit comments