From afa15a037dfddb03973a630f3b2bf86241f0a678 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Tue, 7 Oct 2025 08:39:56 +0000
Subject: [PATCH 1/2] Remove DTensor accumulateGrad overhead

---
 autoparallel/api.py | 39 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 37 insertions(+), 2 deletions(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index ce087738..79d24710 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -21,7 +21,7 @@
 from torch._logging import trace_structured
 from torch._subclasses import FakeTensorMode
 from torch.distributed.fsdp import MixedPrecisionPolicy
-from torch.distributed.tensor import DeviceMesh
+from torch.distributed.tensor import DeviceMesh, DTensor
 from torch.export._unlift import _assign_attr
 from torch.export.unflatten import _AttrKind
 
@@ -457,16 +457,51 @@ def apply_placement(self, sharding_placement=None):
             self.joint_with_descriptors
         )
 
+        param_mappings = {}
+
+        def fwd_hook(model, args):
+            nonlocal param_mappings
+            param_mappings.clear()
+            for module in model.modules():
+                for name, p in module.named_parameters(recurse=False):
+                    if not isinstance(p, DTensor):
+                        continue
+                    p_new = torch.nn.Parameter(p._local_tensor)
+                    param_mappings[p_new] = p
+                    module._parameters[name] = p_new
+
+        def bwd_hook(model, grad_input, grad_output):
+            nonlocal param_mappings
+            for module in model.modules():
+                for name, p in module.named_parameters(recurse=False):
+                    if p not in param_mappings:
+                        continue
+                    orig_p = param_mappings[p]
+                    grad = p.grad
+                    if grad is not None:
+                        grad = DTensor(
+                            grad.detach(),
+                            orig_p._spec,
+                            requires_grad=grad.requires_grad,
+                        )
+                    orig_p.grad = grad
+                    module._parameters[name] = orig_p
+
         # TODO: this probably belongs in the AOTAutograd API
         # TODO: pytree handling
         class AutoParallelModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_forward_pre_hook(fwd_hook, prepend=True)
+                self.register_full_backward_hook(bwd_hook)
+
             def forward(self, *args):
                 # NB: don't close over the parameters/buffers, as the user may
                 # reassign the module!
                 # TODO: It's this to just exactly match
                 # prepare_aot_module_simplified, this seems like an API gap
                 params = [
-                    v.to_local()
+                    v
                     for k, v in
                     # TODO: this is very slow
                     itertools.chain(

From 939feeebdd93ff15f85bff49ad4ec0bb61df3e61 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Tue, 7 Oct 2025 13:36:20 +0000
Subject: [PATCH 2/2] Bugfix for buffers

---
 autoparallel/api.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index 79d24710..8f79dc87 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-import itertools
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
@@ -500,16 +499,14 @@ def forward(self, *args):
                 # NB: don't close over the parameters/buffers, as the user may
                 # reassign the module!
                 # TODO: It's this to just exactly match
                 # prepare_aot_module_simplified, this seems like an API gap
-                params = [
-                    v
-                    for k, v in
-                    # TODO: this is very slow
-                    itertools.chain(
-                        dict(self.named_parameters(remove_duplicate=False)).items(),
-                        dict(self.named_buffers(remove_duplicate=False)).items(),
-                    )
-                ]
-                boxed_args = [*params, *args]
+                params = tuple(
+                    v for k, v in self.named_parameters(remove_duplicate=False)
+                )
+                # buffers aren't impacted by the fwd hook
+                buffers = tuple(
+                    v.to_local() for k, v in self.named_buffers(remove_duplicate=False)
+                )
+                boxed_args = [*params, *buffers, *args]
                 del params
                 # NB: don't do self.parallel_model_fn work around Dynamo bug
                 out = parallel_model_fn(boxed_args)
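
Note on the mechanism (illustrative sketch, not part of either commit): the change swaps every DTensor parameter for a plain nn.Parameter wrapping its local shard in a forward pre-hook, lets autograd's AccumulateGrad write into that plain tensor during backward, then re-wraps the accumulated gradient as a DTensor with the original spec and restores the original parameter in a full backward hook. The toy sketch below shows the same swap/restore hook pattern in isolation, with plain tensors standing in for DTensor and its local shard; the SwapDemo module and its single nn.Linear layer are invented for the example.

import torch
import torch.nn as nn


class SwapDemo(nn.Module):
    """Toy version of the hook trick: a plain nn.Parameter stands in for the
    DTensor local shard, and copying .grad back stands in for re-wrapping it
    into a DTensor with the original sharding spec."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self._mapping = {}  # swapped parameter -> original parameter
        self.register_forward_pre_hook(self._swap_in, prepend=True)
        self.register_full_backward_hook(self._swap_out)

    @staticmethod
    def _swap_in(model, args):
        # Before forward: replace every parameter with a fresh plain Parameter
        # sharing the same storage, and remember the original.
        model._mapping.clear()
        for module in model.modules():
            for name, p in list(module.named_parameters(recurse=False)):
                p_new = nn.Parameter(p.detach())
                model._mapping[p_new] = p
                module._parameters[name] = p_new

    @staticmethod
    def _swap_out(model, grad_input, grad_output):
        # After backward: the swapped-in parameters have their .grad populated
        # (their AccumulateGrad nodes are scheduled with top priority), so move
        # the gradient back onto the original parameter and restore it.
        for module in model.modules():
            for name, p in list(module.named_parameters(recurse=False)):
                orig = model._mapping.get(p)
                if orig is None:
                    continue
                orig.grad = p.grad
                module._parameters[name] = orig

    def forward(self, x):
        return self.lin(x)


m = SwapDemo()
orig_weight = m.lin.weight
x = torch.randn(2, 4, requires_grad=True)  # input requires grad so the hook gets a grad_input
m(x).sum().backward()
assert m.lin.weight is orig_weight      # original parameter restored after backward
assert m.lin.weight.grad is not None    # gradient landed back on the original

Because the swapped-in parameter shares storage with the original (the local shard in the real patch), the forward pass sees the same values; only the autograd bookkeeping around gradient accumulation changes.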