diff --git a/torchrec/datasets/random.py b/torchrec/datasets/random.py index 9008622e5..80812705e 100644 --- a/torchrec/datasets/random.py +++ b/torchrec/datasets/random.py @@ -212,7 +212,9 @@ def __init__( num_generated_batches=num_generated_batches, min_ids_per_features=min_ids_per_features, ) - self.num_batches: int = cast(int, num_batches if not None else sys.maxsize) + self.num_batches: int = cast( + int, num_batches if num_batches is not None else sys.maxsize + ) def __iter__(self) -> Iterator[Batch]: return itertools.islice(iter(self.batch_generator), self.num_batches)