DreamStream is a Python package for making PyTorch models work efficiently in online settings such as speech recognition.
DreamStream introduces a new method called `patch_module` which augments an `nn.Module` with the ability to process `StreamTensor`s. This ability is added to the module in the form of a new processing mode called `online` that is orthogonal to the existing `train` and `eval` modes.
```python
import dreamstream as ds

model = MyModel()
ds.patch_module(model)
```

DreamStream introduces one primary data structure called a `StreamTensor`, which is a subclass of `torch.Tensor`.
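As a rough illustration, a `StreamTensor` wraps chunk data together with its stream metadata. The keyword names below (`ids`, `sos`, `eos`, `lengths`) are assumptions based on the metadata described in this document, not a confirmed signature:

```python
import torch
import dreamstream as ds

chunk = torch.randn(1, 80, 40)  # hypothetical (batch, features, chunk_length) chunk of one file
x = ds.stream_tensor(
    chunk,
    ids=["file-0"],   # stream/file identifiers used to key internal state buffers
    sos=[True],       # whether this chunk is the start of its stream
    eos=[False],      # whether this chunk is the end of its stream
    lengths=[40],     # valid (unpadded) length of each sequence in the chunk
)
```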
| Module mode | Tensor | StreamTensor |
|---|---|---|
| Offline - Train | ✅ Standard behaviour | ❌ Not supported |
| Offline - Eval | ✅ Standard behaviour | ❌ Not supported |
| Online - Train | ❌ Not supported | ✅ DreamStream behaviour |
| Online - Eval | ❌ Not supported | ✅ DreamStream behaviour |
DreamStream behaviour
If a module has been patched with `patch_module` and then set to `online`, DreamStream augments the standard PyTorch behaviour to support online processing. This means that any call to `.forward()` now expects all `torch.Tensor` inputs to be replaced by `StreamTensor` inputs. The patched module, and its child modules, will then keep buffers of the module's internal state for each `StreamTensor` input, keyed by the `StreamTensor`'s `StreamMetadata.ids`.
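To make the per-id buffering concrete, here is a minimal sketch of streaming inference. The `model.online()` call and the `ds.stream_tensor` arguments are assumed names for illustration; the exact API is not pinned down here:

```python
import torch
import dreamstream as ds

model = MyModel()
ds.patch_module(model)
model.eval()
model.online()  # assumed mode switch; the module must be online to accept StreamTensors

with torch.no_grad():
    for chunk, meta in audio_chunk_stream:   # hypothetical source of chunks from several files
        x = ds.stream_tensor(chunk, **meta)  # meta is assumed to carry ids/sos/eos/lengths
        y = model(x)                         # internal state is buffered per id across calls
```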
We support online processing using StreamTensors both for the train and eval modes.
Online - Eval

If the module is online and in `eval` mode, any call to forward ...

Online - Train

If the module is online and in `train` mode, any call to forward ...

We support `train` mode in the following variations:
1. Chunk-wise backpropagation with aligned targets: If `patch_module` was called with `train_variant="chunk-wise-aligned"`, a full forward-backward pass is performed on each chunk, but with the forward pass conditioned on the detached (`.detach()`) state of the previous chunk forwarded on this `id`. This requires that each chunk of a larger file has targets, i.e. a single file-level target is not supported. This mode has constant memory complexity in the number of chunks and therefore enables training on large files. However, it does not backpropagate gradients through the entire file, only within each chunk (see the sketch after this list).
2. Full backpropagation with aligned targets: If `patch_module` was called with `train_variant="full-file-aligned"`, a forward pass is performed on each chunk, conditioned on the (non-detached) state of the previous chunk forwarded on this `id`. For each chunk we compute the loss, which requires chunk-aligned targets as in 1, and once an entire file has been processed, we perform a backward pass on the total loss. This mode only provides memory savings over naively forward-backward passing the entire file at once if the model has superlinear memory complexity in the sequence length, and is therefore not recommended for models with linear memory complexity such as RNNs and CNNs.
3. Full backpropagation with unaligned targets: If `patch_module` was called with `train_variant="full-file-unaligned"`, the behaviour is like 2, but a single target per file is accepted (chunk-aligned targets are not required). Instead, the outputs needed to compute the file-level loss are accumulated across the chunk-wise forward calls. Once a file has ended, the total loss is computed and a backward pass is performed. As for 2, this mode is not recommended for models with linear memory complexity such as RNNs and CNNs.
4. Standard training: If the `StreamTensor` passed to `.forward()` represents the entire file (i.e. has all `sos` and all `eos` set to `True`), the online `train` mode behaves like the standard PyTorch offline `train` mode, as one would expect. This is the behaviour regardless of the `train_variant` given to `patch_module`.
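Below is a hedged sketch of what a chunk-wise-aligned training loop (variant 1) might look like. The data loader, the loss choice, and the `model.online()` mode switch are illustrative assumptions rather than a confirmed API:

```python
import torch
import dreamstream as ds

model = MyModel()
ds.patch_module(model, train_variant="chunk-wise-aligned")
model.train()
model.online()  # assumed mode switch

optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()  # assumes frame-aligned targets for every chunk

for chunk, target in chunked_loader:  # hypothetical loader yielding StreamTensor chunks with targets
    optimizer.zero_grad()
    output = model(chunk)             # conditioned on the detached state of the previous chunk per id
    loss = criterion(output, target)
    loss.backward()                   # gradients flow only within this chunk
    optimizer.step()
```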
DreamStream supports TorchScript and ONNX export of patched modules. To compile a model with DreamStream behaviour, the scripting, tracing and exporting must be done on a module in online mode. If the module is in offline mode, it is exported as a standard PyTorch model.
Exporting the patched model with `dynamic_axes=False` passed to `torch.onnx.export` exports the model such that it works for a fixed chunk size. This is the most common case, since models are usually served for streaming with a constant chunk size. Alternatively, if `dynamic_axes=True`, the model is exported such that it works for any chunk size. This is useful if the model is to be used for streaming with variable chunk sizes, but comes at the cost of a performance penalty.
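For reference, a sketch of the two export styles described above. Note that in `torch.onnx.export` itself the dynamic axes are given as a mapping rather than a boolean; the input names and chunk shape below are assumptions:

```python
import torch

# "model" is a module patched with ds.patch_module, as above.
model.online()  # assumed mode switch; export must happen in online mode
model.eval()
chunk = torch.randn(1, 80, 40)  # example chunk; in practice this would be a StreamTensor

# Fixed chunk size: the exported graph only accepts chunks of length 40.
torch.onnx.export(model, (chunk,), "model_fixed.onnx",
                  input_names=["chunk"], output_names=["output"])

# Variable chunk size: mark the chunk-length dimension as dynamic (with a performance penalty).
torch.onnx.export(model, (chunk,), "model_dynamic.onnx",
                  input_names=["chunk"], output_names=["output"],
                  dynamic_axes={"chunk": {2: "chunk_length"}, "output": {2: "chunk_length"}})
```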
```python
ds.patch_module
ds.StreamModule   # automatically patches itself after __init__
ds.stream_tensor
ds.StreamTensor
ds.meta
ds.StreamMetadata
ds.ChunkModule    # maybe functionally similar to ds.patch_module
```
```python
import torch.nn as nn
import dreamstream as ds

class MyVerySpecialModel(ds.StreamModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

mm = MyVerySpecialModel()
```

```python
class StreamModule(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def when_init_is_done(self):
        ds.patch_module(self)  # patch the module once the subclass __init__ has finished
```

PyTorch models are typically trained and evaluated on batches of data. However, in online settings such as speech recognition, data is often streamed in one sample at a time. This means that the model must be able to process a single sample at a time, and to do so as soon as it is received. This is not possible with the standard PyTorch API, which requires the entire example to be collected before the model can be evaluated.
- Real-time processing: Return output as soon as the model has processed the input, reducing latency.
- Scalability: Reduce memory requirements for modules with superlinear memory complexity (e.g. transformers) by processing the input in smaller chunks.
- Transcribe a large number of files stored on disk. The files are already fully available, but may be too long to process in full due to memory constraints.
- Transcribe a live audio stream. The audio is not available in full length, but must be processed as it is received. Such a stream could be a live audio stream from a microphone, or a stream of audio chunks over a network connection.
- Lack of compatibility with existing PyTorch inference frameworks (ONNX, TorchScript, etc.)
- Padding of the first sample among a batch of samples, but not others (`torch.roll`).
- Handling samples in a batch that reach zero length at some layer during a forward pass.
- Efficient file loading when processing files in bulk (workers get copies of the dataset, but must not read the same files); see the sketch after this list.
  - (`worker_init_fn`, `torch.utils.data.get_worker_info()`)
  - --> Sequential sampler, dataset, dataloader with a single worker.
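One common way to avoid workers reading the same files is to shard the file list inside an `IterableDataset` using `torch.utils.data.get_worker_info()`. This is a generic PyTorch sketch, not DreamStream's loader; `load_chunks` is a hypothetical chunking function:

```python
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class FileChunkDataset(IterableDataset):
    def __init__(self, files):
        self.files = files  # list of file paths

    def __iter__(self):
        info = get_worker_info()
        # Single-process loading reads everything; with workers, each reads a disjoint slice of the files.
        files = self.files if info is None else self.files[info.id :: info.num_workers]
        for path in files:
            yield from load_chunks(path)  # hypothetical: yields fixed-size chunks of one file

loader = DataLoader(FileChunkDataset(files), batch_size=None, num_workers=4)
```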
- Passing a tensor with sequence lengths, first, last, and id among layers without breaking:
  - Batch object with data, sl, first, last, id attributes: Elementwise operations implemented outside Modules will break.
  - Subclass of Tensor with sl, first, last, id attributes: Elementwise operations implemented outside Modules will work, but usually return new Tensor objects without these attributes.
    - This can work by implementing a custom `__torch_function__` method that returns a new Tensor object with the attributes (see the sketch after this list).
    - We probably still need to deal with:
      - Operations that reorder the batch dimension (e.g. `torch.sort` or indexing): Apply the same reordering to the attributes. This happens in `collate_fn` and before `torch.nn.utils.rnn.pack_padded_sequence`.
      - Operations such as `cat`, `stack`, `vstack`, and `hstack` when used along the batch dim: Concatenate also the attributes along the batch dim. This happens in `collate_fn` but could also happen in models.
      - Operations that reduce the batch dimension (e.g. `torch.sum`): Reduce also the attributes along the batch dim. This happens in losses.
      - Operations that create new dimensions: Adjust `batch_dim` and `stream_dim` accordingly. This could happen anywhere.
- Separate arguments to the forward method: We can change all patched forward methods to take additional arguments, but by default they won't be given (and we have no control).
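A minimal sketch (not part of the DreamStream API) of how a Tensor subclass could propagate such attributes through a custom `__torch_function__`; the class name `MetaTensor` and the attribute handling are illustrative assumptions:

```python
import torch

class MetaTensor(torch.Tensor):
    """Sketch of a Tensor subclass that carries stream metadata through torch operations."""

    def __new__(cls, data, sl=None, first=None, last=None, ids=None):
        tensor = torch.as_tensor(data).as_subclass(cls)
        tensor.sl, tensor.first, tensor.last, tensor.ids = sl, first, last, ids
        return tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        out = super().__torch_function__(func, types, args, kwargs or {})
        # Copy metadata from the first MetaTensor argument onto MetaTensor outputs.
        src = next((a for a in args if isinstance(a, MetaTensor)), None)
        if isinstance(out, MetaTensor) and src is not None:
            out.sl, out.first, out.last, out.ids = src.sl, src.first, src.last, src.ids
        return out

x = MetaTensor(torch.randn(2, 8), sl=[8, 5], first=True, last=False, ids=["a", "b"])
y = x * 2.0  # elementwise op outside a Module: y is a MetaTensor and keeps sl/first/last/ids
```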
- How do we deal with operations that are invalid on one or more StreamTensors?
  - Examples:
    - Adding two StreamTensors with different `id`s.
    - Combining batch or length dimensions with one or more other dimensions into a single dimension using e.g. `torch.reshape`, `torch.flatten` or masked indexing.
  - Options:
    - Fail outright.
    - Fall back to a regular `torch.Tensor`.
    - Fall back to a different tensor subclass that is identical in behaviour to `torch.Tensor` but carries the frozen `StreamMetadata` along.
- How do we deal with:
  - Special tokens concatenated to the input? E.g. "translate" and "language" tokens in Whisper?
  - Learnable tokens concatenated to the input sequence before an MHSA layer?
- Support loading/saving of named tensors by custom `__reduce__` or `__reduce_ex__` (see the sketch after this list).
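Continuing the `MetaTensor` sketch above, a small, self-contained variant showing only the serialization hook; `_rebuild_metatensor` is a hypothetical module-level helper so that pickling can restore the subclass and its attributes:

```python
import torch

def _rebuild_metatensor(data, sl, ids):
    tensor = data.as_subclass(MetaTensor)
    tensor.sl, tensor.ids = sl, ids
    return tensor

class MetaTensor(torch.Tensor):
    def __reduce_ex__(self, protocol):
        # Serialize as a plain tensor plus the metadata; _rebuild_metatensor restores both on load.
        plain = self.as_subclass(torch.Tensor)
        return (_rebuild_metatensor, (plain, getattr(self, "sl", None), getattr(self, "ids", None)))
```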
- Case 1: Forward pass each chunk and compute the loss. Detach the module's stream buffers after each module forward. Backpropagate and update parameters. This is O(1) memory in the number of chunks for convolutions and RNNs, but O(N^2) for Transformers.
- Case 2: Forward pass each chunk and collect the logits, but do not detach the stream buffers. Concatenate the logits and compute the loss. Backpropagate. This backpropagates through the entire stream and is less memory efficient for convolutions and RNNs: O(N) in the number of chunks.
- Patched module can provide estimate of state size (can be used to estimate memory requirements for given input lengths).
- Length-based sampling: When processing files in bulk, first sort files by length to minimize padding.
- Maybe we can support the EmFormer architecture.
Documentation is written with Sphinx and is inspired by PyTorch's documentation.
```bash
conda deactivate
conda create -y -n dreamstream python==3.11
conda activate dreamstream
pip install --upgrade --editable .
pip install -r requirements.txt
```

```python
from collections.abc import Iterable
from typing import List, Union


class Input:
    def __init__(self, data, first, last):
        self.data = data
        self.first = first
        self.last = last


class Batch(Iterable):
    def __init__(self, inputs: List[Input]):
        super().__init__()
        self.inputs = inputs
        self.is_collated = False

    def __iter__(self):
        return iter(self.inputs)

    def append(self, input: Input):
        self.inputs.append(input)
        self.is_collated = False

    def extend(self, inputs: Union[List[Input], "Batch"]):
        self.inputs.extend(inputs)
        self.is_collated = False

    def collate(self):
        # collate_fn is assumed to pad and stack the chunk data into a single tensor.
        self.data = collate_fn([input.data for input in self.inputs])
        self.is_collated = True
        return self


def stream_forward(self, x: Union[Input, Batch]):
    # Sketch of the forward method substituted in by patch_module.
    # Input is assumed to expose is_collated/collate like Batch does.
    if not x.is_collated:
        x.collate()
    if not self.is_streaming:
        return self.original_forward(x.data)
    # ... streaming logic
    x = self.original_forward(x.data)
    # ... more streaming logic
    return x
```