From c84a38d0d58883f1cd8bb6d3ceb7d8f6ff1f678a Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Mon, 4 Aug 2025 11:07:48 +0000
Subject: [PATCH] Use fake_mode when constructing ShardingOptimizer

This is particularly important for constructor nodes and for the flop
estimation; otherwise they could materialize massive tensors in memory,
leading to OOM. This shows up in DeepSeek.
---
 autoparallel/api.py               | 27 ++++++++++++---------------
 autoparallel/propagation_rules.py |  5 ++++-
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index f59a727d..3ac7fc17 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -191,21 +191,18 @@ def __enter__(self):

         self.build_model_graph()

-        from torch._subclasses.fake_tensor import unset_fake_temporarily
-
-        with unset_fake_temporarily():
-            rescale_grad_comm_cost_for_mp = 1.0
-            if self.mp_policy is not None:
-                param_size = self.mp_policy.param_dtype.itemsize
-                reduce_size = self.mp_policy.reduce_dtype.itemsize
-                if param_size != reduce_size:
-                    rescale_grad_comm_cost_for_mp = reduce_size / param_size
-                    # Tiebreak, favoring performing the comms in the largest
-                    # dtype
-                    rescale_grad_comm_cost_for_mp *= 1.1
-            sharding_optimizer = ShardingOptimizer(
-                self.gm, self.mesh, rescale_grad_comm_cost_for_mp
-            )
+        rescale_grad_comm_cost_for_mp = 1.0
+        if self.mp_policy is not None:
+            param_size = self.mp_policy.param_dtype.itemsize
+            reduce_size = self.mp_policy.reduce_dtype.itemsize
+            if param_size != reduce_size:
+                rescale_grad_comm_cost_for_mp = reduce_size / param_size
+                # Tiebreak, favoring performing the comms in the largest
+                # dtype
+                rescale_grad_comm_cost_for_mp *= 1.1
+        sharding_optimizer = ShardingOptimizer(
+            self.gm, self.mesh, rescale_grad_comm_cost_for_mp
+        )

         # makes sharding of params and gradients the same
         sharding_optimizer.add_grad_param_constraints()
diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py
index abba8964..f445f03e 100644
--- a/autoparallel/propagation_rules.py
+++ b/autoparallel/propagation_rules.py
@@ -700,7 +700,10 @@ def reshape_rule(mesh, op_schema):
 @register_opschema_rule(torch.ops.aten.expand.default)
 def expand_rule(mesh, op_schema_):
     op = torch.ops.aten.expand.default
-    op_schema = copy.deepcopy(op_schema_)
+    from torch._subclasses.fake_tensor import unset_fake_temporarily
+
+    with unset_fake_temporarily():
+        op_schema = copy.deepcopy(op_schema_)
     input_strat = op_schema.args_schema[0]
     orig_shape = input_strat.strategies[0].output_specs.tensor_meta.shape
     dest_shape = op_schema.args_schema[1]
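
For reference, below is a minimal sketch (not part of the patch) of the fake-tensor
behavior this change relies on. It uses FakeTensorMode and unset_fake_temporarily
from torch._subclasses.fake_tensor (the latter is the helper that appears in the
diff); the shapes and variable names are made up for illustration and do not come
from autoparallel.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

    with FakeTensorMode():
        # Under fake mode, constructor ops return FakeTensors that carry shape and
        # dtype metadata but allocate no real storage, so work like constructor
        # nodes or flop estimation can run over huge tensors without risking OOM.
        huge = torch.empty(65536, 65536)   # ~16 GiB if it were materialized
        assert type(huge).__name__ == "FakeTensor"

        with unset_fake_temporarily():
            # Fake mode is suspended inside this block, so tensors are real again;
            # this mirrors how expand_rule now escapes fake mode only around the
            # copy.deepcopy of the op schema.
            real = torch.zeros(4)
            assert type(real).__name__ == "Tensor"

The first constructor costs essentially no memory while the second one allocates
real storage, which is the distinction the commit message describes for the
massive tensors showing up in DeepSeek.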