27 changes: 20 additions & 7 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -216,7 +216,12 @@ def _reset_data_iter(self) -> None:
self._cur_batch = None

def _connect(self, dataloader_iter: Iterator[In]) -> None:
cur_batch = next(dataloader_iter)
"""
Connects the data iterator to the pipeline when the pipeline first starts.
It also fetches the first batch from the data iterator and copies it to the GPU.
The batch is stored in self._cur_batch.
"""
cur_batch = self._next_batch(dataloader_iter)
self._cur_batch = cur_batch
if cur_batch is not None:
if self._inplace_copy_batch_to_gpu:
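For orientation, below is a minimal, self-contained sketch of the connect/prefetch pattern this method implements. It is a simplified assumption-laden example, not torchrec's TrainPipelineBase: the attribute names _memcpy_stream, _cur_batch, and _copy_batch_to_gpu mirror the ones visible in the diff, while everything else is illustrative.

```python
# Hedged sketch of the connect/prefetch pattern; not the torchrec implementation.
from typing import Iterator, Optional

import torch


class TinyPipeline:
    def __init__(self, device: torch.device) -> None:
        self._device = device
        # Side stream used only for host-to-device copies (CUDA only).
        self._memcpy_stream: Optional[torch.cuda.Stream] = (
            torch.cuda.Stream(device=device) if device.type == "cuda" else None
        )
        self._cur_batch: Optional[torch.Tensor] = None
        self._connected = False

    def _copy_batch_to_gpu(self, batch: torch.Tensor) -> torch.Tensor:
        # Issue the copy on the side stream so it can overlap with compute.
        if self._memcpy_stream is None:
            return batch.to(self._device)
        with torch.cuda.stream(self._memcpy_stream):
            return batch.to(self._device, non_blocking=True)

    def _connect(self, dataloader_iter: Iterator[torch.Tensor]) -> None:
        # Prefetch the first batch and start its device copy before training begins.
        cur_batch = next(dataloader_iter, None)
        if cur_batch is not None:
            cur_batch = self._copy_batch_to_gpu(cur_batch)
        self._cur_batch = cur_batch
        self._connected = True
```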
@@ -234,18 +239,22 @@ def _connect(self, dataloader_iter: Iterator[In]) -> None:
self._connected = True

def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
with record_function("## next_batch ##"):
with record_function("## next batch from dataloader (host) ##"):
try:
next_batch = next(dataloader_iter)
except StopIteration:
self._data_iter_stopped = True
return None

return next_batch

def _wait_for_batch(self, cur_batch: In) -> None:
with record_function("## wait_for_batch ##"):
_wait_for_batch(cur_batch, self._memcpy_stream)
_wait_for_batch(
cur_batch,
self._memcpy_stream,
# no need to record stream when using in-place copy
record_stream=not self._inplace_copy_batch_to_gpu,
)

def _backward(self, losses: torch.Tensor) -> None:
with record_function("## backward ##"):
@@ -272,10 +281,12 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
if self._data_iter_stopped:
raise StopIteration()

# Fetch next batch, if depleted, raise at start of next progress
next_batch = self._next_batch(dataloader_iter)
# get the current batch from previous operation
cur_batch = self._cur_batch

# Fetch next batch from dataloader (host); raise at start of next progress if depleted
next_batch = self._next_batch(dataloader_iter)

# for an exhaustive data iter, some ranks will deplete their data first,
# but we still need to progress the train pipeline for the other ranks;
# cur_batch could be None
@@ -292,11 +303,13 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
with record_function("## forward ##"):
(losses, output) = self._model(cur_batch)

# move on to the next batch after the forward pass (so the current batch can be freed)
self._cur_batch = cur_batch = next_batch

if self._model.training:
self._backward(losses)

# Copy the next batch to GPU
self._cur_batch = cur_batch = next_batch
if cur_batch is not None:
self._copy_batch_to_gpu(cur_batch)
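Taken together, the reordering above fetches the next batch on the host before the forward pass, advances the cursor immediately after the forward so the finished batch can be freed before (or during) backward, and only then starts the device copy of the next batch. Below is a hedged, standalone schematic of that ordering; it is not the torchrec method, and model, backward, and copy_to_gpu are stand-ins for the pipeline's _model, _backward, and _copy_batch_to_gpu.

```python
# Hedged schematic of the new step ordering in progress(); all callables are
# passed in as stand-ins, so nothing here is torchrec's actual API.
def progress_step(state, dataloader_iter, model, backward, copy_to_gpu):
    cur_batch = state["cur_batch"]             # copied to GPU during the prior step
    next_batch = next(dataloader_iter, None)   # host-side fetch; None when depleted

    losses = output = None
    if cur_batch is not None:
        losses, output = model(cur_batch)      # forward on the already-copied batch

    # Advance the cursor right after forward: dropping the reference to the old
    # batch here lets its memory be reclaimed before/with the backward pass.
    state["cur_batch"] = cur_batch = next_batch

    if losses is not None:
        backward(losses)

    if cur_batch is not None:
        # Start the device copy for the next step.
        state["cur_batch"] = copy_to_gpu(cur_batch)
    return output
```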

24 changes: 17 additions & 7 deletions torchrec/distributed/train_pipeline/utils.py
@@ -97,7 +97,9 @@ def _to_device(
)


def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:
def _wait_for_batch(
batch: In, stream: Optional[torch.Stream], record_stream: bool = True
) -> None:
"""
As mentioned in
https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, PyTorch
@@ -113,13 +115,21 @@ def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:
if stream is None:
return

# the batch is loaded/processed in the given stream, but will be used in the current stream
device = stream.device
torch.get_device_module(device).current_stream().wait_stream(stream)
cur_stream = torch.get_device_module(device).current_stream()
assert isinstance(
batch, (torch.Tensor, Multistreamable)
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(cur_stream)
curr_stream = torch.get_device_module(device).current_stream()

# current stream needs to wait for the given stream to complete
curr_stream.wait_stream(stream)

# record_stream is needed when the batch is created (allocated) in the given stream
# but used by another stream (e.g., the current stream); however, when the batch is
# created in the current stream (in-place copy), we don't need to call record_stream
if record_stream:
assert isinstance(
batch, (torch.Tensor, Multistreamable)
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(curr_stream)


def _wait_for_events(
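To make the record_stream reasoning above concrete: record_stream tells the caching allocator that a tensor allocated on one stream is still being read on another stream, so its memory is not reclaimed and reused too early; when the destination buffer is owned by the current stream and the side stream only copies into it in place, the wait_stream ordering alone is sufficient. A hedged, self-contained illustration follows (it requires a CUDA device, and all names are illustrative rather than torchrec's).

```python
# Hedged illustration of when record_stream is and is not needed.
import torch


def wait_for_batch(
    batch: torch.Tensor, copy_stream: torch.cuda.Stream, record_stream: bool = True
) -> None:
    curr_stream = torch.cuda.current_stream(batch.device)
    # Order the current stream after the copy stream.
    curr_stream.wait_stream(copy_stream)
    if record_stream:
        # Tell the caching allocator the tensor is also used on curr_stream, so its
        # memory is not freed/reused while curr_stream may still be reading it.
        batch.record_stream(curr_stream)


device = torch.device("cuda")
copy_stream = torch.cuda.Stream(device=device)
host_batch = torch.randn(1024, pin_memory=True)

# Case 1: the GPU tensor is allocated on the copy stream -> record_stream is needed.
with torch.cuda.stream(copy_stream):
    gpu_batch = host_batch.to(device, non_blocking=True)
wait_for_batch(gpu_batch, copy_stream, record_stream=True)

# Case 2: in-place copy into a buffer allocated on the current stream -> the
# wait_stream above is enough, so record_stream can be skipped.
gpu_buffer = torch.empty_like(host_batch, device=device)
with torch.cuda.stream(copy_stream):
    gpu_buffer.copy_(host_batch, non_blocking=True)
wait_for_batch(gpu_buffer, copy_stream, record_stream=False)
```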