🚀 Feature
Hi LitData maintainers. Is there any plan to make the StreamingDataset wrappers, such as ParallelStreamingDataset, able to handle StreamingDatasets that return more than one sample per iteration? This would be useful for performing aggregation operations on the fly, such as sequence packing. More details below.
Motivation
If we want to perform an aggregation operation on the fly, such as packing samples from a StreamingDataset, a single next() call would return multiple samples. However, certain wrapper iterators, such as _ParallelDatasetIterator, increment the number of samples yielded by only 1 at each iteration:
litData/src/litdata/streaming/parallel.py, line 365 in 36431bd:

    self._num_samples_yielded[i] = 1 if _reset else self._num_samples_yielded[i] + 1
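To make the motivation concrete, here is a minimal, self-contained sketch of a sequence-packing iterator (hypothetical code, not part of LitData): one call to next() may consume several samples from the underlying dataset, so the wrapper has to track how many were consumed — exactly the count a parallel wrapper would need to stay in sync.

```python
from typing import Iterator, List


class PackingIterator:
    """Packs token lists from an underlying iterator into fixed-size blocks.

    One __next__() call may consume several underlying samples; the total
    is tracked in ``num_consumed``. Illustrative sketch only: leftover
    tokens beyond ``block_size`` are simply dropped here.
    """

    def __init__(self, samples: Iterator[List[int]], block_size: int):
        self.samples = iter(samples)
        self.block_size = block_size
        self.num_consumed = 0  # samples drawn from the underlying dataset

    def __iter__(self) -> "PackingIterator":
        return self

    def __next__(self) -> List[int]:
        buffer: List[int] = []
        while len(buffer) < self.block_size:
            # Propagates StopIteration when the underlying dataset is exhausted.
            buffer.extend(next(self.samples))
            self.num_consumed += 1
        return buffer[: self.block_size]
```

With block_size=4 and samples [1,2], [3,4,5], the first next() consumes two underlying samples to produce one packed block, so a wrapper counting "+1 per iteration" would undercount by one.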
Pitch
Just as _ParallelDatasetIterator already returns to callers of next(), under a reserved key, a list containing the number of samples yielded for each dataset:
litData/src/litdata/streaming/parallel.py, line 372 in 36431bd:

    __NUM_SAMPLES_YIELDED_KEY__: self._num_samples_yielded,
it would be useful to let a StreamingDataset pass the number of samples it yielded back to the _ParallelDatasetIterator through a similar reserved key, by replacing
litData/src/litdata/streaming/parallel.py, line 365 in 36431bd:

    self._num_samples_yielded[i] = 1 if _reset else self._num_samples_yielded[i] + 1
with something like:

    n = sample.get(__NUM_SAMPLES_YIELDED_KEY__, 1)
    self._num_samples_yielded[i] = n if _reset else self._num_samples_yielded[i] + n

where a dataset that aggregates several samples into one reports the count under the reserved key, and plain samples default to 1.
Alternatives
The aggregation can also be done at dataset optimization time, at the cost of rebuilding the dataset each time we want to change the aggregation.
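For comparison, the offline variant can be sketched as a plain generator that packs the whole input ahead of time (hypothetical helper, not LitData code; in practice it would run inside the dataset optimization job). Changing block_size then means re-running the entire job, which is the drawback noted above.

```python
from typing import Iterable, Iterator, List


def pack_offline(samples: Iterable[List[int]], block_size: int) -> Iterator[List[int]]:
    """Pack token lists into fixed-size blocks ahead of training.

    Same packing logic as an on-the-fly wrapper, but applied once; the
    result would be serialized, baking the aggregation into the dataset.
    Trailing tokens that do not fill a block are dropped in this sketch.
    """
    buffer: List[int] = []
    for sample in samples:
        buffer.extend(sample)
        while len(buffer) >= block_size:
            yield buffer[:block_size]
            buffer = buffer[block_size:]
```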
Thank you in advance.