diff --git a/keras/src/trainers/data_adapters/py_dataset_adapter.py b/keras/src/trainers/data_adapters/py_dataset_adapter.py index 18865af026cf..88177a39eca3 100644 --- a/keras/src/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -40,6 +40,12 @@ class PyDataset: multiprocessed setting. Reduce this value to reduce the CPU memory consumption of your dataset. Defaults to 10. + shuffle: Whether to shuffle the sample ordering at the end of + each epoch. Shuffling is implemented by the dataset itself: + the training loop automatically calls `on_epoch_end()` at + each epoch boundary, and datasets that enable shuffling + should reorder their samples there (see the example below). + Defaults to `False`. Notes: @@ -52,6 +58,9 @@ class PyDataset: over the dataset. They are not being used by the `PyDataset` class directly. When you are manually iterating over a `PyDataset`, no parallelism is applied. + - `shuffle=False` keeps the sample order fixed across epochs. + For distributed or deterministic training, prefer + `shuffle=False` and manage the order externally. Example: @@ -66,10 +75,17 @@ class CIFAR10PyDataset(keras.utils.PyDataset): - def __init__(self, x_set, y_set, batch_size, **kwargs): + def __init__(self, x_set, y_set, batch_size, shuffle=False, **kwargs): super().__init__(**kwargs) self.x, self.y = x_set, y_set self.batch_size = batch_size + self.shuffle = shuffle + # Create index array for shuffling + self.indices = np.arange(len(self.x)) + # Shuffle once at initialization when shuffle=True + if self.shuffle: + np.random.shuffle(self.indices) + def __len__(self): # Return number of batches. @@ -81,12 +97,19 @@ def __getitem__(self, idx): # Cap upper bound at array length; the last batch may be smaller # if the total number of items is not a multiple of batch size. 
high = min(low + self.batch_size, len(self.x)) - batch_x = self.x[low:high] - batch_y = self.y[low:high] + # Retrieve a batch using shuffled indices + batch_indices = self.indices[low:high] + batch_x = self.x[batch_indices] + batch_y = self.y[batch_indices] return np.array([ resize(imread(file_name), (200, 200)) for file_name in batch_x]), np.array(batch_y) + + def on_epoch_end(self): + # Shuffle indices at the end of each epoch if enabled + if self.shuffle: + np.random.shuffle(self.indices) ``` """