
Commit a8d46ea

Enable Loss Fn in Graph PP (#247)
* Enable Loss Fn in Graph PP [ghstack-poisoned]
* Update on "Enable Loss Fn in Graph PP" [ghstack-poisoned]
* Update on "Enable Loss Fn in Graph PP" [ghstack-poisoned]
* Update on "Enable Loss Fn in Graph PP" [ghstack-poisoned]
1 parent 2895806 commit a8d46ea

7 files changed: +370 lines added, -132 lines removed

autoparallel/_testing/models/dsv3.py

Lines changed: 6 additions & 0 deletions
@@ -1584,6 +1584,12 @@ def forward(
         return output


+def dsv3_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.cross_entropy(
+        pred.flatten(0, 1).float(), labels.flatten(0, 1)
+    )
+
+
 ########################
 # Pipeline stuff start #
 ########################
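
For context, the new loss flattens the batch and sequence dimensions so that cross_entropy sees [batch*seq, vocab] logits against [batch*seq] integer targets. A minimal standalone sketch; the shapes below are illustrative and not taken from the commit:

import torch

def dsv3_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # flatten(0, 1) merges batch and sequence dims; .float() upcasts logits
    # (e.g. from bf16) before the softmax inside cross_entropy
    return torch.nn.functional.cross_entropy(
        pred.flatten(0, 1).float(), labels.flatten(0, 1)
    )

pred = torch.randn(2, 8, 128)            # [batch, seq, vocab] logits (illustrative)
labels = torch.randint(0, 128, (2, 8))   # [batch, seq] token ids (illustrative)
loss = dsv3_loss_fn(pred, labels)        # 0-dim scalar loss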

autoparallel/api.py

Lines changed: 67 additions & 13 deletions
@@ -8,7 +8,7 @@
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
@@ -159,17 +159,35 @@ def enable_local_map_wrapping():
     yield


-def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
+def _export(
+    model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any, ...]
+) -> torch.fx.GraphModule:
     """
-    Thin wrapper around graph capture output that restores the
-    original calling convention and attribute fqn. TODO:
-    1) Use bytecode for calling convention instead of pytree for more
-    seamless UX.
+    Capture a model graph via Dynamo and restore parameter/buffer metadata.
+
+    We need both `model` and `model_wrapper` because:
+    - `model_wrapper` is the actual callable that gets traced by Dynamo. It may wrap
+      the model with additional logic (e.g., adding a loss function on top of the model's
+      forward pass, or preparing inputs in a specific way).
+    - `model` is the original nn.Module needed to restore the correct fully-qualified
+      names (FQNs) for parameters and buffers in the traced graph. Without this, the
+      captured graph would lose the original parameter naming structure.
+
+    Args:
+        model: Original nn.Module with parameter/buffer metadata to restore
+        model_wrapper: Callable to trace (may wrap model with additional logic)
+        inputs: Input tensors for tracing
+
+    Returns:
+        GraphModule with restored parameter FQNs and calling convention
+
+    TODO:
+    1) Use bytecode for calling convention instead of pytree for more seamless UX
     2) Attach guards
-    3) Be more careful about tensor constants names.
+    3) Be more careful about tensor constants names
     """
     with torch._dynamo.config.patch(install_free_tensors=True):
-        gm = _dynamo_graph_capture_for_export(model)(*inputs)
+        gm = _dynamo_graph_capture_for_export(model_wrapper)(*inputs)
     _restore_state_dict(model, gm)
     return gm

@@ -193,6 +211,7 @@ def __init__(
         ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
         reshard_after_forward: bool = True,
         dynamic: bool = False,
+        loss_fn: Optional[Callable] = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -224,6 +243,7 @@ def __init__(
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
+        self.loss_fn = loss_fn

         if dynamic:
             self.fake_mode.shape_env = ShapeEnv()
@@ -291,20 +311,54 @@ def build_model_graph(self):
         decomp_table = _get_decomp_table()

         with self.fake_mode:
-            inputs = self.input_fn()
-            if not isinstance(inputs, tuple):
-                inputs = (inputs,)
+            raw_inputs = self.input_fn()
+
+            # Define model wrapper based on whether loss_fn is provided
+            model_wrapper: Callable
+            # Parse inputs based on whether loss_fn is provided
+            # If loss_fn is not None: expected format ((inp1, inp2,...), target)
+            # If loss_fn is None: expected format (inp1, inp2, ...)
+            if self.loss_fn is not None:
+                if isinstance(raw_inputs, tuple) and len(raw_inputs) == 2:
+                    model_inputs, target = raw_inputs
+                    # Normalize inputs to always be a tuple
+                    if not isinstance(model_inputs, tuple):
+                        model_inputs = (model_inputs,)
+                    formatted_inputs = (model_inputs, target)
+
+                    def model_with_loss(model_inputs, target) -> Any:
+                        output = self.model(*model_inputs)
+                        loss = self.loss_fn(output, target)  # type: ignore[misc]
+                        return loss
+
+                    model_wrapper = model_with_loss
+                else:
+                    raise ValueError(
+                        "When loss_fn is provided, input_fn must return (inputs, target) "
+                        "where inputs can be a single tensor or tuple of tensors"
+                    )
+            else:
+                # No loss function, inputs are just model inputs
+                formatted_inputs = (
+                    raw_inputs if isinstance(raw_inputs, tuple) else (raw_inputs,)
+                )

+                def model_wo_loss(*model_inputs) -> Any:
+                    output = self.model(*model_inputs)
+                    return output
+
+                model_wrapper = model_wo_loss

         with set_dtype_cast(
             True
         ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-            torch_ir_with_fqn = _export(self.model, inputs)
+            torch_ir_with_fqn = _export(self.model, model_wrapper, formatted_inputs)
             # TODO Cna't use fake mode here because it clashes with the user level
             # fake mode. Ideally dynamo should reuse the user level fake mode.
             self.joint_with_descriptors = aot_export_joint_with_descriptors(
                 self.stack,
                 torch_ir_with_fqn,
-                inputs,
+                formatted_inputs,
                 decompositions=decomp_table,
             )
             gm = self.joint_with_descriptors.graph_module
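
To make the new contract concrete, below is a hypothetical input_fn satisfying the ((inp1, inp2, ...), target) format that build_model_graph now expects when loss_fn is set. The class name AutoParallel and its remaining constructor arguments are assumptions not shown in this hunk, so the construction itself is left as a comment:

import torch

def my_loss_fn(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.cross_entropy(
        pred.flatten(0, 1).float(), target.flatten(0, 1)
    )

def input_fn():
    tokens = torch.randint(0, 128, (2, 8))   # model input (illustrative shapes)
    labels = torch.randint(0, 128, (2, 8))   # target handed to loss_fn
    return ((tokens,), labels)               # ((inp1, inp2, ...), target)

# Hypothetical construction; the enclosing class and its other arguments
# are not part of this diff:
# autop = AutoParallel(model, input_fn, mesh, loss_fn=my_loss_fn)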

autoparallel/apply_sharding.py

Lines changed: 12 additions & 7 deletions
@@ -208,13 +208,18 @@ def call_function(self, target, args, kwargs):

         # apply sharding to constructor functions as well
         if target in TENSOR_FACTORY_OPS:
-            val = list(new_args[0])
-            spec = self.sharding_placement[node].output_specs
-            for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
-                if placement.is_shard():
-                    # TODO: fix uneven cases ?
-                    val[placement.dim] //= mesh_size
-            new_args[0] = tuple(val)
+            # scalar_tensor has a scalar as first arg, not a shape
+            if target == torch.ops.aten.scalar_tensor.default:
+                # scalar tensors can't be sharded, so no transformation needed
+                pass
+            else:
+                val = list(new_args[0])
+                spec = self.sharding_placement[node].output_specs
+                for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
+                    if placement.is_shard():
+                        # TODO: fix uneven cases ?
+                        val[placement.dim] //= mesh_size
+                new_args[0] = tuple(val)

         # use DTensor machinery to ensure the view ops are valid
         # otherwise we would end-up forcing global shapes on local tensors
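
The new branch skips the size rewrite for scalar_tensor because its first argument is a scalar value rather than a shape and its output is 0-dimensional. A standalone sketch of the "local tensor size" conversion applied to the shape-taking factory ops, assuming a recent PyTorch where torch.distributed.tensor exposes Shard/Replicate; the mesh and placements below are illustrative:

import torch
from torch.distributed.tensor import Replicate, Shard

# A shape-taking factory op: the first argument is a global size that
# ApplySharding rewrites to the per-rank (local) size.
global_size = [8, 4]                   # e.g. aten.zeros.default([8, 4])
mesh_shape = (4, 2)                    # illustrative 2-D mesh
placements = (Shard(0), Replicate())   # shard dim 0 over the first mesh dim

local_size = list(global_size)
for mesh_size, placement in zip(mesh_shape, placements):
    if placement.is_shard():
        local_size[placement.dim] //= mesh_size
print(local_size)                      # [2, 4]

# scalar_tensor takes a value, not a shape, and produces a 0-dim tensor,
# so there is nothing to divide:
t0 = torch.ops.aten.scalar_tensor.default(3.0)
print(t0.shape)                        # torch.Size([])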

autoparallel/graph_pp_runner.py

Lines changed: 34 additions & 13 deletions
@@ -184,6 +184,20 @@ def _run_reduce_grad_module(
     return sharded_grads


+def _accumulate_stage_grads(
+    unsharded_grads: list[Union[torch.Tensor, None]],
+    grads_to_accumulate: list[Union[torch.Tensor, None]],
+) -> None:
+    assert len(unsharded_grads) == len(grads_to_accumulate)
+    assert not all(grad is None for grad in grads_to_accumulate), "All grads are None"
+    for unsharded_grad, grad_to_accumulate in zip(unsharded_grads, grads_to_accumulate):
+        if grad_to_accumulate is not None:
+            if unsharded_grad is None:
+                unsharded_grad = grad_to_accumulate
+            else:
+                unsharded_grad += grad_to_accumulate
+
+
 def _run_forward_microbatch(
     stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
 ) -> tuple[Any, Any]:
@@ -216,17 +230,9 @@ def _run_backward_microbatch(
     )

     unsharded_grads = backward_stage.state["unsharded_grads"]
-    grads_to_accumulate = param_buffer_grads[
-        : len(backward_stage.state["sharded_params"])
-    ]
-    assert len(unsharded_grads) == len(grads_to_accumulate)
-    assert not all(grad is None for grad in grads_to_accumulate), "All grads are None"
-    for unsharded_grad, grad_to_accumulate in zip(unsharded_grads, grads_to_accumulate):
-        if grad_to_accumulate is not None:
-            if unsharded_grad is None:
-                unsharded_grad = grad_to_accumulate
-            else:
-                unsharded_grad += grad_to_accumulate
+    grads_to_accumulate = param_buffer_grads[: backward_stage.graph_meta.num_params]
+    _accumulate_stage_grads(unsharded_grads, grads_to_accumulate)
+
     return input_grads

@@ -275,6 +281,11 @@ def stage_forward(
     # Receive activations for this chunk
    # Activations only come in args form
    composite_args = stage._retrieve_recv_activations(mb_index)
+    if stage.is_last and ctx.target_mbs is not None:
+        assert isinstance(
+            composite_args, tuple
+        ), f"Expected composite args to be a tuple but got {type(composite_args)}"
+        composite_args = composite_args + (ctx.target_mbs[mb_index],)  # type: ignore[index]

    # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
    logger.debug(
@@ -292,6 +303,8 @@
    # Output chunks is only used for the last stage since we only merge the output of the last stage
    if stage.is_last:
        stage.output_chunks.append(output)
+        if ctx.target_mbs is not None:
+            ctx.schedule_ref._internal_losses.append(output)

    stage.fwd_cache[mb_index] = (
        output_tuple,  # stage_output
@@ -360,7 +373,7 @@ def stage_full_backward(
        # HACK till we have loss function, we populate the tangents here manually
        bwd_kwargs = {
            "stage_output": loss,
-            "tangents": [torch.randn_like(stage_output[0])],
+            "tangents": [torch.ones_like(stage_output[0])],
            "saved_intermediates": saved_intermediates,
        }
    else:
@@ -525,7 +538,9 @@ def _accumulate_stage_grads_and_clear_states(
        stage.state.clear()

    def step(self, *args, **kwargs) -> None:
-
+        has_targets_and_loss = (
+            "losses" in kwargs and "targets" in kwargs if kwargs else False
+        )
        for stage in self.schedule._stages:
            assert isinstance(stage, GraphPipelineStage)
            self._populate_stage_states(stage)
@@ -535,3 +550,9 @@ def step(self, *args, **kwargs) -> None:
        for stage in self.schedule._stages:
            assert isinstance(stage, GraphPipelineStage)
            self._accumulate_stage_grads_and_clear_states(stage)
+
+        if has_targets_and_loss:
+            losses = kwargs["losses"]
+            assert len(self.schedule._internal_losses) == self.schedule._n_microbatches
+            losses.extend(self.schedule._internal_losses)
+            self.schedule._internal_losses.clear()
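
The step() change means a caller can pass both targets= and losses= keyword arguments, and the schedule's per-microbatch losses are moved into the caller-provided list at the end of the step. A standalone toy reproduction of just that bookkeeping; the runner and schedule objects themselves are not modeled here, and demo_step is a hypothetical stand-in, not the runner's API:

import torch

def demo_step(internal_losses: list[torch.Tensor], **kwargs) -> None:
    # Mirrors the check added above: only collect when both kwargs are present
    has_targets_and_loss = (
        "losses" in kwargs and "targets" in kwargs if kwargs else False
    )
    if has_targets_and_loss:
        losses = kwargs["losses"]
        losses.extend(internal_losses)
        internal_losses.clear()

per_microbatch = [torch.tensor(0.5), torch.tensor(0.7)]  # illustrative losses
collected: list[torch.Tensor] = []
demo_step(per_microbatch, targets=torch.zeros(2), losses=collected)
print(collected)        # the two microbatch losses, now owned by the caller
print(per_microbatch)   # [] -- internal buffer cleared for the next step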

autoparallel/propagation_rules.py

Lines changed: 44 additions & 6 deletions
@@ -363,11 +363,8 @@ def randperm_rule(mesh, specs):
     return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])


-# We do a few special things for factory ops
-# - use the factory rule below
-# - fake that they have input schemas so the solver doesn't freak out
-# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
-TENSOR_FACTORY_OPS = [
+# Factory ops that take a shape as the first argument
+_SHAPE_FACTORY_OPS = [
     torch.ops.aten.zeros.default,
     torch.ops.aten.ones.default,
     torch.ops.aten.full.default,
@@ -376,8 +373,49 @@ def randperm_rule(mesh, specs):
     torch.ops.aten.randn.default,
 ]

+# We do a few special things for factory ops
+# - use the factory rule below
+# - fake that they have input schemas so the solver doesn't freak out
+# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
+TENSOR_FACTORY_OPS = _SHAPE_FACTORY_OPS + [
+    torch.ops.aten.scalar_tensor.default,  # Special case: creates 0-dim tensor
+]
+
+
+@register_opschema_rule(torch.ops.aten.scalar_tensor.default)
+def scalar_tensor_rule(mesh, op_schema: OpSchema) -> OpStrategy:
+    """
+    Rule for aten.scalar_tensor which creates a scalar (0-dimensional) tensor.
+    Unlike other factory ops, this doesn't take a shape parameter.
+
+    Schema: scalar_tensor(Scalar s, *, ScalarType? dtype=None, ...) -> Tensor
+    """
+    # scalar_tensor creates a 0-dimensional tensor
+    shape = ()
+    stride = ()
+    dtype = torch.get_default_dtype()
+
+    # Check if dtype is specified in kwargs or args
+    if len(op_schema.args_schema) >= 2 and op_schema.args_schema[1] is not None:
+        dtype = op_schema.args_schema[1]  # type: ignore[assignment]
+
+    tensor_meta = TensorMeta(shape, stride, dtype)  # type: ignore[arg-type]
+
+    # For a scalar (0-dim) tensor, we can only replicate across all mesh dimensions
+    placement = (Replicate(),) * mesh.ndim
+    output_specs = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
+
+    # Similar to factory_rule, we add a dummy input_specs for solver compatibility
+    strategy = OpSpec(
+        output_specs=output_specs,
+        input_specs=[output_specs],
+        redistribute_cost=[[0.0]],
+    )
+
+    return OpStrategy([strategy])
+

-@register_opschema_rule(TENSOR_FACTORY_OPS)
+@register_opschema_rule(_SHAPE_FACTORY_OPS)
 def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
     """
     This is an auto-parallel specific util that won't be upstreamed becuase of a UX mismatch.
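
The new rule commits scalar_tensor's 0-dimensional output to replication on every mesh dimension, since there is no dimension to shard. A minimal sketch of that placement tuple, with the mesh dimensionality assumed for illustration and Replicate imported from the public torch.distributed.tensor namespace:

import torch
from torch.distributed.tensor import Replicate

t = torch.tensor(1.0)                    # a 0-dim tensor, like scalar_tensor's output
mesh_ndim = 2                            # illustrative 2-D device mesh
placement = (Replicate(),) * mesh_ndim   # the placement scalar_tensor_rule builds its spec from
print(t.dim(), placement)                # 0 (Replicate(), Replicate())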
