diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 8bea1ff37..52e72d64f 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -210,7 +210,8 @@ def _fuse_input_dist_splits(context: TrainPipelineContext) -> None: if isinstance(awaitable, KJTSplitsAllToAllMeta): pg = awaitable.pg break - names_per_pg[pg].append(name) + if pg is not None: + names_per_pg[pg].append(name) for pg, names in names_per_pg.items(): context.fused_splits_awaitables.append(