diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 53d654680..a356b77ff 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -216,7 +216,12 @@ def _reset_data_iter(self) -> None:
         self._cur_batch = None
 
     def _connect(self, dataloader_iter: Iterator[In]) -> None:
-        cur_batch = next(dataloader_iter)
+        """
+        Connect the data iterator to the pipeline when the pipeline is first started.
+        It also fetches the first batch from the data iterator and copies it to the GPU.
+        The batch is stored in self._cur_batch.
+        """
+        cur_batch = self._next_batch(dataloader_iter)
         self._cur_batch = cur_batch
         if cur_batch is not None:
             if self._inplace_copy_batch_to_gpu:
@@ -234,18 +239,22 @@ def _connect(self, dataloader_iter: Iterator[In]) -> None:
         self._connected = True
 
     def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
-        with record_function("## next_batch ##"):
+        with record_function("## next batch from dataloader (host) ##"):
             try:
                 next_batch = next(dataloader_iter)
             except StopIteration:
                 self._data_iter_stopped = True
                 return None
-
         return next_batch
 
     def _wait_for_batch(self, cur_batch: In) -> None:
         with record_function("## wait_for_batch ##"):
-            _wait_for_batch(cur_batch, self._memcpy_stream)
+            _wait_for_batch(
+                cur_batch,
+                self._memcpy_stream,
+                # no need to record stream when using in-place copy
+                record_stream=not self._inplace_copy_batch_to_gpu,
+            )
 
     def _backward(self, losses: torch.Tensor) -> None:
         with record_function("## backward ##"):
@@ -272,10 +281,12 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         if self._data_iter_stopped:
             raise StopIteration()
 
-        # Fetch next batch, if depleted, raise at start of next progress
-        next_batch = self._next_batch(dataloader_iter)
+        # get the current batch from the previous iteration
         cur_batch = self._cur_batch
 
+        # Fetch next batch from dataloader (host); raise at start of next progress if depleted
+        next_batch = self._next_batch(dataloader_iter)
+
         # for exhaustive data iter, some ranks will first depletes data,
         # but we still need progress the train pipeline for other ranks;
         # cur_batch could be None
@@ -292,11 +303,13 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         with record_function("## forward ##"):
             (losses, output) = self._model(cur_batch)
 
+        # move to the next batch after the forward pass (so the current batch can be freed)
+        self._cur_batch = cur_batch = next_batch
+
         if self._model.training:
             self._backward(losses)
 
             # Copy the next batch to GPU
-            self._cur_batch = cur_batch = next_batch
             if cur_batch is not None:
                 self._copy_batch_to_gpu(cur_batch)
 
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index ce1a30554..ac0f2ef7f 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -97,7 +97,9 @@ def _to_device(
     )
 
 
-def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:
+def _wait_for_batch(
+    batch: In, stream: Optional[torch.Stream], record_stream: bool = True
+) -> None:
     """
     As mentioned in
     https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, PyTorch
@@ -113,13 +115,21 @@ def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:
     if stream is None:
         return
 
+    # batch is loaded/processed in the given stream, but will be used in the current stream
     device = stream.device
-    torch.get_device_module(device).current_stream().wait_stream(stream)
-    cur_stream = torch.get_device_module(device).current_stream()
-    assert isinstance(
-        batch, (torch.Tensor, Multistreamable)
-    ), f"{type(batch)} must implement Multistreamable interface"
-    batch.record_stream(cur_stream)
+    curr_stream = torch.get_device_module(device).current_stream()
+
+    # current stream needs to wait for the given stream to complete
+    curr_stream.wait_stream(stream)
+
+    # record_stream is needed when the batch is created (allocated) in the given stream
+    # but used by another stream (e.g., the current stream). However, when the batch is
+    # created in the current stream (in-place copy), the record_stream call is not needed.
+    if record_stream:
+        assert isinstance(
+            batch, (torch.Tensor, Multistreamable)
+        ), f"{type(batch)} must implement Multistreamable interface"
+        batch.record_stream(curr_stream)
 
 
 def _wait_for_events(
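
Note (illustration, not part of the patch): the sketch below shows the stream-synchronization pattern that the new record_stream flag controls. wait_stream orders the device copy before any use on the current stream, while record_stream stops the caching allocator from reusing memory that was allocated on the memcpy stream. When the copy is done in place into a buffer the current stream already owns, record_stream can be skipped, which is what record_stream=not self._inplace_copy_batch_to_gpu expresses. The helper name and tensors below are illustrative only, assuming a CUDA device and a pinned CPU tensor as input.

from typing import Optional

import torch


def copy_batch_to_gpu_sketch(
    host_batch: torch.Tensor, inplace_dst: Optional[torch.Tensor] = None
) -> torch.Tensor:
    memcpy_stream = torch.cuda.Stream()  # side stream used for host-to-device copies
    with torch.cuda.stream(memcpy_stream):
        if inplace_dst is not None:
            # in-place copy: the destination was allocated by the current (default) stream
            inplace_dst.copy_(host_batch, non_blocking=True)
            gpu_batch = inplace_dst
        else:
            # fresh allocation happens on memcpy_stream
            gpu_batch = host_batch.to("cuda", non_blocking=True)

    cur_stream = torch.cuda.current_stream()
    # the current stream must not read the batch before the copy finishes
    cur_stream.wait_stream(memcpy_stream)

    if inplace_dst is None:
        # memory was allocated on memcpy_stream but is consumed on cur_stream, so tell
        # the caching allocator not to reuse it until cur_stream is done with it
        gpu_batch.record_stream(cur_stream)
    # in the in-place case the destination already belongs to the current stream,
    # so the record_stream call is unnecessary (this is what record_stream=False skips)
    return gpu_batch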