From 0382564f1d7917463329f7fee721306e8c798610 Mon Sep 17 00:00:00 2001
From: Faran Ahmad
Date: Tue, 18 Nov 2025 06:36:12 -0800
Subject: [PATCH] Keep the logic for inference tensorpool forward consistent w/ set up before hetero sharding

Summary:
Keep the logic for the inference TensorPool forward consistent with the setup
used before hetero sharding. Using the optional tensor wrapper interferes with
lowering jobs, as the model split boundaries are different when TensorPool and
TBE exist together.

Differential Revision: D87326553
---
 torchrec/distributed/tensor_pool.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/torchrec/distributed/tensor_pool.py b/torchrec/distributed/tensor_pool.py
index 80d0abbd3..dba4476b9 100644
--- a/torchrec/distributed/tensor_pool.py
+++ b/torchrec/distributed/tensor_pool.py
@@ -474,9 +474,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         dist_input, unbucketize_permute, bucket_mapping, bucketized_lengths = (
             self._lookup_ids_dist(ids)
         )
-        unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
-            unbucketize_permute
-        )
+        unbucketize_permute_non_opt = unbucketize_permute
 
         lookup = self._lookup_local(dist_input)
 
@@ -513,12 +511,20 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         )
 
         output = self._lookup_values_dist(lookup_list)
-
-        return index_select_view(
-            output,
-            unbucketize_permute_non_opt.to(device=output.device),
-            self._dim,
-        )
+        # When memory_capacity_per_rank is set, the model split boundary is
+        # different, so handle device movement accordingly.
+        if self._sharding_plan.memory_capacity_per_rank is None:
+            return index_select_view(
+                output,
+                unbucketize_permute_non_opt,
+                self._dim,
+            )
+        else:
+            return index_select_view(
+                output,
+                unbucketize_permute_non_opt.to(device=output.device),
+                self._dim,
+            )
 
     # pyre-ignore
     def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):
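
For reference, below is a minimal, self-contained sketch (not the TorchRec
implementation) of the forward behavior this patch restores. The helper
lookup_forward, its memory_capacity_per_rank argument, and the simplified
index_select_view are hypothetical stand-ins for the sharded TensorPool
forward, the hetero-sharding field on the sharding plan, and the real row
reorder; the point illustrated is that the .to(device=...) move only happens
on the hetero-sharding path, so the pre-hetero-sharding graph keeps its
original split boundary.

from typing import List, Optional

import torch


def index_select_view(
    output: torch.Tensor, permute: torch.Tensor, dim: int
) -> torch.Tensor:
    # Reorder the pooled rows back into the caller's original id order.
    return output[permute].view(-1, dim)


def lookup_forward(
    output: torch.Tensor,
    unbucketize_permute: torch.Tensor,
    dim: int,
    memory_capacity_per_rank: Optional[List[int]] = None,
) -> torch.Tensor:
    if memory_capacity_per_rank is None:
        # Pre-hetero-sharding path: no explicit device move, so a traced or
        # lowered graph keeps the same split boundary as before.
        return index_select_view(output, unbucketize_permute, dim)
    # Hetero-sharding path: the permute tensor may live on a different device,
    # so move it next to the output before indexing.
    return index_select_view(
        output, unbucketize_permute.to(device=output.device), dim
    )


if __name__ == "__main__":
    out = torch.arange(12, dtype=torch.float32).view(4, 3)
    perm = torch.tensor([2, 0, 3, 1])
    print(lookup_forward(out, perm, dim=3))
    print(lookup_forward(out, perm, dim=3, memory_capacity_per_rank=[1 << 30]))

Keeping the non-hetero path free of both the optional unwrap and the device
move is what keeps the model split boundary for TensorPool + TBE models
unchanged during lowering, per the summary above.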