From dfdc6c42c4555bfaed261840655fe85ac078ff66 Mon Sep 17 00:00:00 2001
From: Nipun Gupta
Date: Mon, 17 Nov 2025 19:28:36 -0800
Subject: [PATCH] Enable logging for the plan() function, ShardEstimators and TrainingPipeline class constructors (#3521)

Summary:
This diff enables the static logging functionality to collect data for:

1) plan() - This will allow us to log the inputs and outputs of the
   planner, to help with user-issue debugging.
2) ShardEstimators - This will allow us to log the inputs and outputs
   of the ShardEstimators, which gives us the bandwidth inputs needed
   to verify that the planner is generating the expected values, and
   also helps with debugging OOMs.
3) TrainingPipeline - The class type here indicates which pipeline was
   used by the training job. The training pipeline has implications
   for memory usage and is an important data point to collect when
   investigating OOMs.

Reviewed By: kausv

Differential Revision: D86317910
---
 torchrec/distributed/embeddingbag.py                   | 3 +++
 torchrec/distributed/planner/planners.py               | 2 ++
 torchrec/distributed/planner/shard_estimators.py       | 2 ++
 torchrec/distributed/shard.py                          | 3 +++
 torchrec/distributed/train_pipeline/train_pipelines.py | 3 +++
 torchrec/modules/mc_embedding_modules.py               | 3 +++
 6 files changed, 16 insertions(+)

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 754b9e6fa..2d43db376 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -55,6 +55,7 @@
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
 )
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
@@ -466,6 +467,7 @@ class ShardedEmbeddingBagCollection(
     This is part of the public API to allow for manual data dist pipelining.
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         module: EmbeddingBagCollectionInterface,
@@ -2021,6 +2023,7 @@ class ShardedEmbeddingBag(
     This is part of the public API to allow for manual data dist pipelining.
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         module: nn.EmbeddingBag,
diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 84a21bd12..214015844 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -18,6 +18,7 @@
 from torch import nn
 from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result
 from torchrec.distributed.comm import get_local_size
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.planner.constants import BATCH_SIZE, MAX_SIZE
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
 from torchrec.distributed.planner.partitioners import (
@@ -498,6 +499,7 @@ def collective_plan(
             sharders,
         )
 
+    @_torchrec_method_logger()
     def plan(
         self,
         module: nn.Module,
diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py
index e9fe723e8..3650324fc 100644
--- a/torchrec/distributed/planner/shard_estimators.py
+++ b/torchrec/distributed/planner/shard_estimators.py
@@ -16,6 +16,7 @@
 import torchrec.optim as trec_optim
 from torch import nn
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.planner.constants import (
     BATCHED_COPY_PERF_FACTOR,
     BIGINT_DTYPE,
@@ -955,6 +956,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
         is_inference (bool): If the model is inference model. Default to False.
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         topology: Topology,
diff --git a/torchrec/distributed/shard.py b/torchrec/distributed/shard.py
index 0a27711a7..bc5a811ae 100644
--- a/torchrec/distributed/shard.py
+++ b/torchrec/distributed/shard.py
@@ -15,6 +15,7 @@
 from torch.distributed._composable.contract import contract
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.global_settings import get_propogate_device
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.model_parallel import get_default_sharders
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.sharding_plan import (
@@ -146,6 +147,7 @@ def _shard(
 
 # pyre-ignore
 @contract()
+@_torchrec_method_logger()
 def shard_modules(
     module: nn.Module,
     env: Optional[ShardingEnv] = None,
@@ -194,6 +196,7 @@ def init_weights(m):
     return _shard_modules(module, env, device, plan, sharders, init_params)
 
 
+@_torchrec_method_logger()
 def _shard_modules(  # noqa: C901
     module: nn.Module,
     # TODO: Consolidate to using Dict[str, ShardingEnv]
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 61c430bd8..4977477f3 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -32,6 +32,7 @@
 import torch
 from torch.autograd.profiler import record_function
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.pipeline_context import (
     EmbeddingTrainPipelineContext,
@@ -106,6 +107,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass
 
+    # pyre-ignore [56]
+    @_torchrec_method_logger()
     def __init__(self) -> None:
         # pipeline state such as in foward, in backward etc, used in training recover scenarios
         self._state: PipelineState = PipelineState.IDLE
diff --git a/torchrec/modules/mc_embedding_modules.py b/torchrec/modules/mc_embedding_modules.py
index 129da1e69..04d4e2489 100644
--- a/torchrec/modules/mc_embedding_modules.py
+++ b/torchrec/modules/mc_embedding_modules.py
@@ -12,6 +12,7 @@
 
 import torch
 import torch.nn as nn
 
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.modules.embedding_modules import (
     EmbeddingBagCollection,
@@ -125,6 +126,7 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollectio
 
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_collection: EmbeddingCollection,
@@ -164,6 +166,7 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollec
 
     """
 
+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_bag_collection: EmbeddingBagCollection,
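Note (not part of the patch): every call site above applies the decorator as a
zero-argument factory, `@_torchrec_method_logger()`. The real implementation
lives in torchrec/distributed/logger.py and is not shown in this diff. The
sketch below is only a hedged illustration of that calling pattern: a factory
that returns a decorator which records the wrapped callable's owning class,
inputs, and outputs. The logger name, log level, and handling of the factory's
keyword arguments are assumptions, not the actual torchrec behavior.

    # Hypothetical sketch, NOT the torchrec.distributed.logger implementation.
    # It only mirrors the usage pattern seen in this patch:
    # `@_torchrec_method_logger()` on __init__ methods and on plan().
    import functools
    import logging
    from typing import Any, Callable, TypeVar

    logger: logging.Logger = logging.getLogger("torchrec_method_logger")  # assumed name

    F = TypeVar("F", bound=Callable[..., Any])


    def _torchrec_method_logger(**logger_kwargs: Any) -> Callable[[F], F]:
        """Return a decorator that logs calls to the wrapped function or method."""

        def decorator(func: F) -> F:
            @functools.wraps(func)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                # For methods, args[0] is `self`; its concrete type records which
                # subclass was used (e.g. which TrainPipeline variant a job built).
                # For free functions such as shard_modules this is simply the first
                # argument's type, which is good enough for a sketch.
                owner = type(args[0]).__name__ if args else ""
                logger.debug(
                    "%s (owner=%s) called with args=%r kwargs=%r",
                    func.__qualname__, owner, args, kwargs,
                )
                result = func(*args, **kwargs)
                logger.debug("%s (owner=%s) returned %r", func.__qualname__, owner, result)
                return result

            return wrapper  # type: ignore[return-value]

        return decorator

With this shape, the no-argument call sites in the diff (for example
`@_torchrec_method_logger()` above `def plan(...)`) work unchanged, and any
configuration would flow in through the factory's keyword arguments.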