From 1cbe5f6bcd1e3da6481d34dfc2496bdf565242ac Mon Sep 17 00:00:00 2001 From: Justin Yang Date: Fri, 14 Nov 2025 15:02:22 -0800 Subject: [PATCH] Fix None process group appended in _fuse_input_dist_splits Summary: We identified this potential bug while debugging the issue reported in https://fb.workplace.com/groups/755371733754414/permalink/833999072850393/ Fixed a bug in `_fuse_input_dist_splits` where names with no valid process group (pg=None) were being added to `names_per_pg[None]`. This would cause issues downstream when trying to create `FusedKJTListSplitsAwaitable` with a None process group. The issue occurred when: 1. A request is of type `KJTListSplitsAwaitable` 2. None of its awaitables are of type `KJTSplitsAllToAllMeta` 3. This leaves `pg = None` (line 207) 4. The name was still appended to `names_per_pg[None]` (line 213) The fix adds a check to only append names when `pg is not None`, ensuring that only requests with valid process groups are included in the fused operations. Why this matters: - Prevents passing `pg=None` to `FusedKJTListSplitsAwaitable` (line 232) - Ensures only valid distributed operations are fused together - Avoids potential runtime errors or undefined behavior Differential Revision: D87110878 --- torchrec/distributed/train_pipeline/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 8bea1ff37..52e72d64f 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -210,7 +210,8 @@ def _fuse_input_dist_splits(context: TrainPipelineContext) -> None: if isinstance(awaitable, KJTSplitsAllToAllMeta): pg = awaitable.pg break - names_per_pg[pg].append(name) + if pg is not None: + names_per_pg[pg].append(name) for pg, names in names_per_pg.items(): context.fused_splits_awaitables.append(