
ParallelStreamingDataset does not support StreamingDataset returning more than one sample per iteration #801

@albertoveneri

Description


🚀 Feature

Hi litdata maintainers. Are there any plans to make the StreamingDataset wrappers, such as ParallelStreamingDataset, able to handle StreamingDatasets that return more than one sample per iteration? This would be useful when performing aggregation operations on the fly, such as sequence packing. More details below.

Motivation

If we want to perform an aggregation operation on the fly, such as packing samples from a StreamingDataset, a single `next()` call may account for more than one underlying sample. However, certain wrapper iterators, such as the _ParallelDatasetIterator, increment the number of samples yielded by only 1 at each iteration:

```python
self._num_samples_yielded[i] = 1 if _reset else self._num_samples_yielded[i] + 1
```
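To make the mismatch concrete, here is a minimal, hypothetical sketch (the class and key names are illustrative, not litdata API) of a packing dataset whose each yielded item consumes a variable number of underlying samples and reports that count under a reserved key:

```python
# Hypothetical sketch: a wrapper that packs token lists from an inner iterable
# into fixed-size blocks. Each packed item records, under an assumed reserved
# key, how many raw samples it consumed since the previous yield -- exactly the
# information a wrapper iterator would need to keep its counters accurate.

NUM_SAMPLES_YIELDED_KEY = "__NUM_SAMPLES_YIELDED_KEY__"  # assumed reserved key


class PackingDataset:
    """Packs samples (lists of tokens) from `inner` into blocks of `block_size`."""

    def __init__(self, inner, block_size):
        self.inner = inner
        self.block_size = block_size

    def __iter__(self):
        buffer, consumed = [], 0
        for sample in self.inner:
            buffer.extend(sample)
            consumed += 1
            # One raw sample may complete zero, one, or several blocks.
            while len(buffer) >= self.block_size:
                block, buffer = buffer[: self.block_size], buffer[self.block_size :]
                yield {"tokens": block, NUM_SAMPLES_YIELDED_KEY: consumed}
                consumed = 0


samples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
packed = list(PackingDataset(samples, block_size=4))
```

Here the first packed block consumes two raw samples, so a wrapper that always adds 1 per iteration would undercount and, for example, resume from the wrong position after a checkpoint.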

Pitch

Just as _ParallelDatasetIterator returns to callers of next() a list containing the number of samples yielded for each dataset:

```python
__NUM_SAMPLES_YIELDED_KEY__: self._num_samples_yielded,
```

it would be useful to use a similar reserved key to let a StreamingDataset pass its own sample count up to the _ParallelDatasetIterator, by replacing

```python
self._num_samples_yielded[i] = 1 if _reset else self._num_samples_yielded[i] + 1
```

with something like:

```python
count = sample.get("__NUM_SAMPLES_YIELDED_KEY__", 1)
self._num_samples_yielded[i] = count if _reset else self._num_samples_yielded[i] + count
```
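As a self-contained illustration of the proposed update (the helper function and its name are hypothetical; the real logic lives inside the iterator), the counter advance could behave like this:

```python
# Hypothetical sketch of the proposed counter update. Variable names mirror
# the issue; this is not the actual litdata implementation.

NUM_SAMPLES_YIELDED_KEY = "__NUM_SAMPLES_YIELDED_KEY__"  # assumed reserved key


def update_num_samples_yielded(num_samples_yielded, i, sample, reset=False):
    """Advance the per-dataset counter by however many underlying samples
    the yielded item represents, defaulting to 1 for plain samples."""
    step = sample.get(NUM_SAMPLES_YIELDED_KEY, 1) if isinstance(sample, dict) else 1
    num_samples_yielded[i] = step if reset else num_samples_yielded[i] + step
    return num_samples_yielded


counters = [0, 0]
update_num_samples_yielded(counters, 0, {"x": 1})                           # plain sample: +1
update_num_samples_yielded(counters, 0, {NUM_SAMPLES_YIELDED_KEY: 3})       # packed sample: +3
```

Because the key defaults to 1 when absent, existing datasets that yield one sample per iteration would keep their current behavior unchanged.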

Alternatives

The aggregation can also be done during the optimization step, at the cost of rebuilding the dataset each time we want to change the aggregation.

Thank you in advance.

Labels: enhancement (New feature or request)