29 changes: 26 additions & 3 deletions keras/src/trainers/data_adapters/py_dataset_adapter.py
@@ -40,6 +40,12 @@ class PyDataset:
multiprocessed setting.
Reduce this value to reduce the CPU memory consumption of
your dataset. Defaults to 10.
shuffle: Whether to shuffle the sample ordering at the end of
    each epoch. When calling `model.fit(..., shuffle=True)`, the
    training loop automatically calls `on_epoch_end()` at each
    epoch boundary, allowing datasets to implement custom
    shuffling logic. Defaults to `False`.

Notes:

@@ -52,6 +58,9 @@ class PyDataset:
over the dataset. They are not being used by the `PyDataset` class
directly. When you are manually iterating over a `PyDataset`,
no parallelism is applied.
- `shuffle=False` keeps the sample order fixed across epochs.
    For distributed or deterministic training, prefer
    `shuffle=False` and manage the sample order externally.

Example:

@@ -66,10 +75,17 @@ class PyDataset:

class CIFAR10PyDataset(keras.utils.PyDataset):

-def __init__(self, x_set, y_set, batch_size, **kwargs):
+def __init__(self, x_set, y_set, batch_size, shuffle=False, **kwargs):
super().__init__(**kwargs)
self.x, self.y = x_set, y_set
self.batch_size = batch_size
self.shuffle = shuffle
# create index array for shuffling
self.indices = np.arange(len(self.x))
Collaborator comment: In the example, also do `np.random.shuffle(self.indices)` here when `shuffle` is `True`, since the best practice when shuffling is to do it for every epoch, not just epoch >= 1.

# Shuffle at initialization too, so the first epoch is also shuffled
if self.shuffle:
np.random.shuffle(self.indices)


def __len__(self):
# Return number of batches.
@@ -81,12 +97,19 @@ def __getitem__(self, idx):
# Cap upper bound at array length; the last batch may be smaller
# if the total number of items is not a multiple of batch size.
high = min(low + self.batch_size, len(self.x))
-batch_x = self.x[low:high]
-batch_y = self.y[low:high]
+# Retrieve a batch using shuffled indices
+batch_indices = self.indices[low:high]
+batch_x = self.x[batch_indices]
+batch_y = self.y[batch_indices]

return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)

def on_epoch_end(self):
# Shuffle indices at the end of each epoch if enabled
if self.shuffle:
np.random.shuffle(self.indices)
```
"""
