torchrec/distributed/tensor_pool.py (24 changes: 15 additions & 9 deletions)
@@ -474,9 +474,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         dist_input, unbucketize_permute, bucket_mapping, bucketized_lengths = (
             self._lookup_ids_dist(ids)
         )
-        unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
-            unbucketize_permute
-        )
+        unbucketize_permute_non_opt = unbucketize_permute

         lookup = self._lookup_local(dist_input)

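The helper removed above appears to narrow an Optional[Tensor] to a plain Tensor for tracing purposes; the new code now assigns the value through directly. A minimal, hedged sketch of what such a helper typically looks like follows (the name, decorator, and assert are assumptions for illustration, not necessarily TorchRec's exact implementation of _fx_item_unwrap_optional_tensor):

    # Illustrative sketch only; the real _fx_item_unwrap_optional_tensor may differ.
    from typing import Optional

    import torch

    @torch.fx.wrap  # keep this call as a leaf node instead of tracing through it
    def unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor:
        # Narrow Optional[Tensor] to Tensor so downstream ops and type checkers
        # see a non-optional value.
        assert optional is not None, "expected a non-None Tensor"
        return optional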
@@ -513,12 +511,20 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         )

         output = self._lookup_values_dist(lookup_list)
-
-        return index_select_view(
-            output,
-            unbucketize_permute_non_opt.to(device=output.device),
-            self._dim,
-        )
+        # When memory_capacity_per_rank is added then boundary split for the
+        # model is different. Handling device movement accordingly
+        if self._sharding_plan.memory_capacity_per_rank is None:
+            return index_select_view(
+                output,
+                unbucketize_permute_non_opt,
+                self._dim,
+            )
+        else:
+            return index_select_view(
+                output,
+                unbucketize_permute_non_opt.to(device=output.device),
+                self._dim,
+            )

     # pyre-ignore
     def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):
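The new branch keys off sharding_plan.memory_capacity_per_rank: with capacity-aware (uneven) sharding the boundary split differs, so, per the comment in the diff, the unbucketize permutation may not live on the same device as the looked-up output and must be moved before indexing. A standalone, hedged sketch of that control flow (gather_rows, lookup_output, and the List[int] type for memory_capacity_per_rank are illustrative assumptions, not TorchRec's API):

    # Hedged sketch of the added branch; not TorchRec's implementation.
    from typing import List, Optional

    import torch

    def gather_rows(output: torch.Tensor, permute: torch.Tensor, dim: int) -> torch.Tensor:
        # Stand-in for index_select_view: gather rows by index and view as (-1, dim).
        return output[permute].view(-1, dim)

    def lookup_output(
        output: torch.Tensor,
        unbucketize_permute: torch.Tensor,
        dim: int,
        memory_capacity_per_rank: Optional[List[int]] = None,
    ) -> torch.Tensor:
        if memory_capacity_per_rank is None:
            # Even sharding: the permutation is assumed to already be on the
            # output's device, so no explicit move (matches the diff).
            return gather_rows(output, unbucketize_permute, dim)
        # Capacity-aware sharding: the permutation may sit on another device,
        # so move it to the output's device before indexing.
        return gather_rows(output, unbucketize_permute.to(device=output.device), dim)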