Commit f52f958

Sadeep Jayasumana authored and copybara-github committed
Adds reshuffle_each_iteration argument to deterministic_data.create_dataset().
This argument is passed to `tf.data.Dataset.shuffle()` and controls whether the dataset is reshuffled each time it is iterated over. The default value is `None`, which is the same as the default value of `reshuffle_each_iteration` in `tf.data.Dataset.shuffle()`.

This change is being made to support the use of `deterministic_data.create_dataset()` in evaluation loops that need to access the same evaluation data batches in each iteration of the dataset without reshuffling before each iteration/epoch over the dataset. This is useful, for example, in visualizing the progress of image generation models at different model checkpoints. Visualizing the model progress on the same evaluation data makes Tensorboard qualitative evaluation easier.

This change is backwards compatible. If the `reshuffle_each_iteration` argument is not specified, the default value of `None` will be used.

PiperOrigin-RevId: 661355447
1 parent b64aa29 commit f52f958

1 file changed (+9, -1 lines changed)

clu/deterministic_data.py

Lines changed: 9 additions & 1 deletion
@@ -372,6 +372,7 @@ def create_dataset(dataset_builder: DatasetBuilder,
                    num_epochs: Optional[int] = None,
                    shuffle: bool = True,
                    shuffle_buffer_size: int = 10_000,
+                   reshuffle_each_iteration: Optional[bool] = None,
                    prefetch_size: int = 4,
                    pad_up_to_batches: Optional[Union[int, str]] = None,
                    cardinality: Optional[int] = None,
@@ -402,6 +403,9 @@ def create_dataset(dataset_builder: DatasetBuilder,
       forever.
     shuffle: Whether to shuffle the dataset (both on file and example level).
     shuffle_buffer_size: Number of examples in the shuffle buffer.
+    reshuffle_each_iteration: A boolean, which if true indicates that the
+      dataset should be pseudorandomly reshuffled each time it is iterated over.
+      (Defaults to `True`.)
     prefetch_size: The number of elements in the final dataset to prefetch in
       the background. This should be a small (say <10) positive integer or
       tf.data.experimental.AUTOTUNE.
@@ -453,7 +457,11 @@ def create_dataset(dataset_builder: DatasetBuilder,
     ds = ds.cache()
 
   if shuffle:
-    ds = ds.shuffle(shuffle_buffer_size, seed=rngs.pop()[0])
+    ds = ds.shuffle(
+        shuffle_buffer_size,
+        seed=rngs.pop()[0],
+        reshuffle_each_iteration=reshuffle_each_iteration,
+    )
   ds = ds.repeat(num_epochs)
 
   if preprocess_fn is not None:
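
The behaviour this flag controls can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the `tf.data` implementation: `ToyShuffledDataset` is a hypothetical class invented for illustration, it shuffles the whole dataset rather than using a bounded shuffle buffer, and it treats `None` like `True`, following the docstring added in this commit.

```python
import random
from typing import Iterator, Optional, Sequence


class ToyShuffledDataset:
  """Hypothetical stand-in for a seeded, shuffled dataset (not tf.data)."""

  def __init__(self, examples: Sequence[int], seed: int,
               reshuffle_each_iteration: Optional[bool] = None):
    self._examples = list(examples)
    self._seed = seed
    # Assumption: None behaves like True, as the new docstring states.
    self._reshuffle = (
        True if reshuffle_each_iteration is None else reshuffle_each_iteration)
    # Shared generator whose state advances across passes when reshuffling.
    self._rng = random.Random(seed)

  def __iter__(self) -> Iterator[int]:
    # Without reshuffling, restart from the seed so every pass over the
    # dataset yields the examples in the same order.
    rng = self._rng if self._reshuffle else random.Random(self._seed)
    order = list(self._examples)
    rng.shuffle(order)
    return iter(order)


# Evaluation-style dataset: same shuffled order on every pass.
eval_ds = ToyShuffledDataset(range(20), seed=42, reshuffle_each_iteration=False)
# Training-style dataset: a fresh order on every pass (the default).
train_ds = ToyShuffledDataset(range(20), seed=42)

eval_pass_1, eval_pass_2 = list(eval_ds), list(eval_ds)
train_pass_1, train_pass_2 = list(train_ds), list(train_ds)
```

In this sketch both evaluation passes see identical orderings, which is the property the commit message relies on for comparing checkpoints on the same evaluation batches. In the real pipeline the shuffle is followed by `ds.repeat(num_epochs)`, so with `reshuffle_each_iteration=False` each repeated epoch should likewise reuse the same order.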
