diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index f63829845b..b8a5b335e3 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -186,6 +186,7 @@ def __init__( use_rowwise_bias_correction: bool = False, # For Adam use optimizer_state_dtypes: dict[str, SparseType] = {}, # noqa: B006 pg: Optional[dist.ProcessGroup] = None, + enable_optimizer_offloading: bool = False, ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__()