Skip to content

Commit 961e596

Browse files
nandwalritikmariosaskolhoestq
authored
Added stratify option to train_test_split function. (#4322)
* Add stratify option to train_test_split * Add utility functions for performing stratified split * Removed unused import * Add suggested changes * Remove unused import from splits.py * Add example usage of train_test_split with stratify arg to docstring * Add test cases to test stratified_train_test_split * Move stratify functions to utils/stratify.py and refactor code. * Fix test cases according to ClassLabel class * Add changes for error handling and recommended changes * Add error handling for KeyErr for stratify_by_column arg * Add tests for checking error handling in stratified train_test_split * Removed unwanted imports * Remove `import datasets` * Update src/datasets/arrow_dataset.py Co-authored-by: Mario Šaško <mario@huggingface.co> Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
1 parent ea4df89 commit 961e596

File tree

3 files changed

+226
-4
lines changed

3 files changed

+226
-4
lines changed

src/datasets/arrow_dataset.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
from .utils.file_utils import _retry, estimate_dataset_size
9494
from .utils.info_utils import is_small_dataset
9595
from .utils.py_utils import convert_file_size_to_int, temporary_assignment, unique_values
96+
from .utils.stratify import stratified_shuffle_split_generate_indices
9697
from .utils.typing import PathLike
9798

9899

@@ -3255,6 +3256,7 @@ def train_test_split(
32553256
test_size: Union[float, int, None] = None,
32563257
train_size: Union[float, int, None] = None,
32573258
shuffle: bool = True,
3259+
stratify_by_column: Optional[str] = None,
32583260
seed: Optional[int] = None,
32593261
generator: Optional[np.random.Generator] = None,
32603262
keep_in_memory: bool = False,
@@ -3281,6 +3283,7 @@ def train_test_split(
32813283
If int, represents the absolute number of train samples.
32823284
If None, the value is automatically set to the complement of the test size.
32833285
shuffle (:obj:`bool`, optional, default `True`): Whether or not to shuffle the data before splitting.
3286+
stratify_by_column (:obj:`str`, optional, default `None`): The column name of labels to be used to perform stratified split of data.
32843287
seed (:obj:`int`, optional): A seed to initialize the default BitGenerator if ``generator=None``.
32853288
If None, then fresh, unpredictable entropy will be pulled from the OS.
32863289
If an int or array_like[ints] is passed, then it will be passed to SeedSequence to derive the initial BitGenerator state.
@@ -3320,6 +3323,24 @@ def train_test_split(
33203323
33213324
# set a seed
33223325
>>> ds = ds.train_test_split(test_size=0.2, seed=42)
3326+
3327+
# stratified split
3328+
>>> ds = load_dataset("imdb",split="train")
3329+
Dataset({
3330+
features: ['text', 'label'],
3331+
num_rows: 25000
3332+
})
3333+
>>> ds = ds.train_test_split(test_size=0.2, stratify_by_column="label")
3334+
DatasetDict({
3335+
train: Dataset({
3336+
features: ['text', 'label'],
3337+
num_rows: 20000
3338+
})
3339+
test: Dataset({
3340+
features: ['text', 'label'],
3341+
num_rows: 5000
3342+
})
3343+
})
33233344
```
33243345
"""
33253346
from .dataset_dict import DatasetDict # import here because of circular dependency
@@ -3437,15 +3458,42 @@ def train_test_split(
34373458
),
34383459
}
34393460
)
3440-
34413461
if not shuffle:
3462+
if stratify_by_column is not None:
3463+
raise ValueError("Stratified train/test split is not implemented for `shuffle=False`")
34423464
train_indices = np.arange(n_train)
34433465
test_indices = np.arange(n_train, n_train + n_test)
34443466
else:
3467+
# stratified partition
3468+
if stratify_by_column is not None:
3469+
if stratify_by_column not in self.features.keys():
3470+
raise ValueError(f"Key {stratify_by_column} not found in {self.features.keys()}")
3471+
if not isinstance(self.features[stratify_by_column], ClassLabel):
3472+
raise ValueError(
3473+
f"Stratifying by column is only supported for {ClassLabel.__name__} column, and column {stratify_by_column} is {type(self.features[stratify_by_column]).__name__}."
3474+
)
3475+
try:
3476+
train_indices, test_indices = next(
3477+
stratified_shuffle_split_generate_indices(
3478+
self.with_format("numpy")[stratify_by_column], n_train, n_test, rng=generator
3479+
)
3480+
)
3481+
except Exception as error:
3482+
if str(error) == "Minimum class count error":
3483+
raise ValueError(
3484+
f"The least populated class in {stratify_by_column} column has only 1"
3485+
" member, which is too few. The minimum"
3486+
" number of groups for any class cannot"
3487+
" be less than 2."
3488+
)
3489+
else:
3490+
raise error
3491+
34453492
# random partition
3446-
permutation = generator.permutation(len(self))
3447-
test_indices = permutation[:n_test]
3448-
train_indices = permutation[n_test : (n_test + n_train)]
3493+
else:
3494+
permutation = generator.permutation(len(self))
3495+
test_indices = permutation[:n_test]
3496+
train_indices = permutation[n_test : (n_test + n_train)]
34493497

34503498
train_split = self.select(
34513499
indices=train_indices,

src/datasets/utils/stratify.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import numpy as np
2+
3+
4+
def approximate_mode(class_counts, n_draws, rng):
5+
"""Computes approximate mode of multivariate hypergeometric.
6+
This is an approximation to the mode of the multivariate
7+
hypergeometric given by class_counts and n_draws.
8+
It shouldn't be off by more than one.
9+
It is the mostly likely outcome of drawing n_draws many
10+
samples from the population given by class_counts.
11+
Args
12+
----------
13+
class_counts : ndarray of int
14+
Population per class.
15+
n_draws : int
16+
Number of draws (samples to draw) from the overall population.
17+
rng : random state
18+
Used to break ties.
19+
Returns
20+
-------
21+
sampled_classes : ndarray of int
22+
Number of samples drawn from each class.
23+
np.sum(sampled_classes) == n_draws
24+
25+
"""
26+
# this computes a bad approximation to the mode of the
27+
# multivariate hypergeometric given by class_counts and n_draws
28+
continuous = n_draws * class_counts / class_counts.sum()
29+
# floored means we don't overshoot n_samples, but probably undershoot
30+
floored = np.floor(continuous)
31+
# we add samples according to how much "left over" probability
32+
# they had, until we arrive at n_samples
33+
need_to_add = int(n_draws - floored.sum())
34+
if need_to_add > 0:
35+
remainder = continuous - floored
36+
values = np.sort(np.unique(remainder))[::-1]
37+
# add according to remainder, but break ties
38+
# randomly to avoid biases
39+
for value in values:
40+
(inds,) = np.where(remainder == value)
41+
# if we need_to_add less than what's in inds
42+
# we draw randomly from them.
43+
# if we need to add more, we add them all and
44+
# go to the next value
45+
add_now = min(len(inds), need_to_add)
46+
inds = rng.choice(inds, size=add_now, replace=False)
47+
floored[inds] += 1
48+
need_to_add -= add_now
49+
if need_to_add == 0:
50+
break
51+
return floored.astype(np.int)
52+
53+
54+
def stratified_shuffle_split_generate_indices(y, n_train, n_test, rng, n_splits=10):
55+
"""
56+
57+
Provides train/test indices to split data in train/test sets.
58+
It's reference is taken from StratifiedShuffleSplit implementation
59+
of scikit-learn library.
60+
61+
Args
62+
----------
63+
64+
n_train : int,
65+
represents the absolute number of train samples.
66+
67+
n_test : int,
68+
represents the absolute number of test samples.
69+
70+
random_state : int or RandomState instance, default=None
71+
Controls the randomness of the training and testing indices produced.
72+
Pass an int for reproducible output across multiple function calls.
73+
74+
n_splits : int, default=10
75+
Number of re-shuffling & splitting iterations.
76+
"""
77+
classes, y_indices = np.unique(y, return_inverse=True)
78+
n_classes = classes.shape[0]
79+
class_counts = np.bincount(y_indices)
80+
if np.min(class_counts) < 2:
81+
raise ValueError("Minimum class count error")
82+
if n_train < n_classes:
83+
raise ValueError(
84+
"The train_size = %d should be greater or " "equal to the number of classes = %d" % (n_train, n_classes)
85+
)
86+
if n_test < n_classes:
87+
raise ValueError(
88+
"The test_size = %d should be greater or " "equal to the number of classes = %d" % (n_test, n_classes)
89+
)
90+
class_indices = np.split(np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1])
91+
for _ in range(n_splits):
92+
n_i = approximate_mode(class_counts, n_train, rng)
93+
class_counts_remaining = class_counts - n_i
94+
t_i = approximate_mode(class_counts_remaining, n_test, rng)
95+
96+
train = []
97+
test = []
98+
99+
for i in range(n_classes):
100+
permutation = rng.permutation(class_counts[i])
101+
perm_indices_class_i = class_indices[i].take(permutation, mode="clip")
102+
train.extend(perm_indices_class_i[: n_i[i]])
103+
test.extend(perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]])
104+
train = rng.permutation(train)
105+
test = rng.permutation(test)
106+
107+
yield train, test

tests/test_arrow_dataset.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from unittest.mock import patch
1212

1313
import numpy as np
14+
import numpy.testing as npt
1415
import pandas as pd
1516
import pyarrow as pa
1617
import pytest
@@ -3553,3 +3554,69 @@ def test_task_text_classification_when_columns_removed(self):
35533554
with Dataset.from_dict(data, info=info) as dset:
35543555
with dset.map(lambda x: {"new_column": 0}, remove_columns=dset.column_names) as dset:
35553556
self.assertDictEqual(dset.features, features_after_map)
3557+
3558+
3559+
class StratifiedTest(TestCase):
3560+
def test_errors_train_test_split_stratify(self):
3561+
ys = [
3562+
np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]),
3563+
np.array([0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
3564+
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
3565+
np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]),
3566+
np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]),
3567+
]
3568+
for i in range(len(ys)):
3569+
features = Features({"text": Value("int64"), "label": ClassLabel(len(np.unique(ys[i])))})
3570+
data = {"text": np.ones(len(ys[i])), "label": ys[i]}
3571+
d1 = Dataset.from_dict(data, features=features)
3572+
3573+
# For checking stratify_by_column exist as key in self.features.keys()
3574+
if i == 0:
3575+
self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="labl")
3576+
3577+
# For checking minimum class count error
3578+
elif i == 1:
3579+
self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="label")
3580+
3581+
# For check typeof label as ClassLabel type
3582+
elif i == 2:
3583+
d1 = Dataset.from_dict(data)
3584+
self.assertRaises(ValueError, d1.train_test_split, 0.33, stratify_by_column="label")
3585+
3586+
# For checking test_size should be greater than or equal to number of classes
3587+
elif i == 3:
3588+
self.assertRaises(ValueError, d1.train_test_split, 0.30, stratify_by_column="label")
3589+
3590+
# For checking train_size should be greater than or equal to number of classes
3591+
elif i == 4:
3592+
self.assertRaises(ValueError, d1.train_test_split, 0.60, stratify_by_column="label")
3593+
3594+
def test_train_test_split_startify(self):
3595+
ys = [
3596+
np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2]),
3597+
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
3598+
np.array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] * 2),
3599+
np.array([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3]),
3600+
np.array([0] * 800 + [1] * 50),
3601+
]
3602+
for y in ys:
3603+
features = Features({"text": Value("int64"), "label": ClassLabel(len(np.unique(y)))})
3604+
data = {"text": np.ones(len(y)), "label": y}
3605+
d1 = Dataset.from_dict(data, features=features)
3606+
d1 = d1.train_test_split(test_size=0.33, stratify_by_column="label")
3607+
y = np.asanyarray(y) # To make it indexable for y[train]
3608+
test_size = np.ceil(0.33 * len(y))
3609+
train_size = len(y) - test_size
3610+
npt.assert_array_equal(np.unique(d1["train"]["label"]), np.unique(d1["test"]["label"]))
3611+
3612+
# checking classes proportion
3613+
p_train = np.bincount(np.unique(d1["train"]["label"], return_inverse=True)[1]) / float(
3614+
len(d1["train"]["label"])
3615+
)
3616+
p_test = np.bincount(np.unique(d1["test"]["label"], return_inverse=True)[1]) / float(
3617+
len(d1["test"]["label"])
3618+
)
3619+
npt.assert_array_almost_equal(p_train, p_test, 1)
3620+
assert len(d1["train"]["text"]) + len(d1["test"]["text"]) == y.size
3621+
assert len(d1["train"]["text"]) == train_size
3622+
assert len(d1["test"]["text"]) == test_size

0 commit comments

Comments
 (0)