From 0743ee119c189eaac622d1c9a6e6953dbb964c9b Mon Sep 17 00:00:00 2001
From: Nipun Gupta
Date: Thu, 9 Oct 2025 08:00:19 -0700
Subject: [PATCH] Fix input_dist for when it has to handle KVZCH and non-KVZCH
 tables together for row-wise sharding (#3444)

Summary:
In KVZCH scenarios, the input_dist for row-wise sharding has to handle KVZCH
and non-KVZCH features together, so virtual_table_feature_num_buckets must
also cover non-KVZCH features. In this diff, we default num_buckets to
world_size for non-KVZCH features.

Differential Revision: D84094039
---
 torchrec/distributed/sharding/rw_sharding.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index 136052137..8a1c10296 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -301,6 +301,11 @@ def _get_writable_feature_hash_sizes(self) -> List[int]:
         return feature_hash_sizes
 
     def _get_virtual_table_feature_num_buckets(self) -> List[int]:
+        """
+        Returns the number of buckets for each KVZCH feature in the GroupedEmbeddingConfigs.
+        For a non-KVZCH feature, the list holds world_size at that feature's position,
+        since input_dist processes KVZCH and non-KVZCH features together.
+        """
         feature_num_buckets: List[int] = []
         for group_config in self._grouped_embedding_configs:
             for embedding_table in group_config.embedding_tables:
@@ -312,6 +317,10 @@ def _get_virtual_table_feature_num_buckets(self) -> List[int]:
                         [embedding_table.total_num_buckets]
                         * embedding_table.num_features()
                     )
+                else:
+                    feature_num_buckets.extend(
+                        [self._world_size] * embedding_table.num_features()
+                    )
         return feature_num_buckets
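
For illustration, a minimal, self-contained sketch of the per-feature
bucket-count logic this patch introduces. Table and
virtual_table_feature_num_buckets are hypothetical stand-ins for torchrec's
sharded table config and the patched method; only the attributes the method
reads are modeled, and the exact KVZCH guard condition is an assumption, since
the patch shows only the new else branch.

    from dataclasses import dataclass
    from typing import List, Optional


    # Hypothetical stand-in for a torchrec sharded table config; only the
    # attributes used by the patched method are modeled.
    @dataclass
    class Table:
        feature_names: List[str]
        use_virtual_table: bool = False          # True for KVZCH (virtual) tables
        total_num_buckets: Optional[int] = None  # set only for KVZCH tables

        def num_features(self) -> int:
            return len(self.feature_names)


    def virtual_table_feature_num_buckets(
        tables: List[Table], world_size: int
    ) -> List[int]:
        # Mirrors the patched method: KVZCH features report the table's
        # total_num_buckets; non-KVZCH features default to world_size
        # (the new else branch).
        feature_num_buckets: List[int] = []
        for table in tables:
            if table.use_virtual_table and table.total_num_buckets:
                feature_num_buckets.extend(
                    [table.total_num_buckets] * table.num_features()
                )
            else:
                feature_num_buckets.extend([world_size] * table.num_features())
        return feature_num_buckets


    if __name__ == "__main__":
        tables = [
            Table(["f0", "f1"], use_virtual_table=True, total_num_buckets=80),
            Table(["f2"]),  # non-KVZCH table: falls back to world_size
        ]
        # KVZCH features keep their 80 buckets; f2 defaults to world_size=4,
        # so a mixed KVZCH / non-KVZCH group yields one count per feature.
        assert virtual_table_feature_num_buckets(tables, world_size=4) == [80, 80, 4]

Defaulting to world_size appears consistent with row-wise bucketization, where
a table without explicit buckets is split into one block per rank, so the
mixed feature list stays aligned position-for-position during input_dist.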