
Commit 36903f2

Grad scaling parity between both pp and non-pp
1 parent: 7c45448

File tree: 2 files changed (+7 −5 lines)


autoparallel/graph_pp_runner.py

Lines changed: 6 additions & 5 deletions
@@ -190,12 +190,12 @@ def _accumulate_stage_grads(
 ) -> None:
     assert len(unsharded_grads) == len(grads_to_accumulate)
     assert not all(grad is None for grad in grads_to_accumulate), "All grads are None"
-    for unsharded_grad, grad_to_accumulate in zip(unsharded_grads, grads_to_accumulate):
-        if grad_to_accumulate is not None:
-            if unsharded_grad is None:
-                unsharded_grad = grad_to_accumulate
+    for i in range(len(unsharded_grads)):
+        if grads_to_accumulate[i] is not None:
+            if unsharded_grads[i] is None:
+                unsharded_grads[i] = grads_to_accumulate[i]
             else:
-                unsharded_grad += grad_to_accumulate
+                unsharded_grads[i] += grads_to_accumulate[i]


 def _run_forward_microbatch(
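
The switch from iterating with zip() to indexing is what makes accumulation actually land in unsharded_grads: rebinding the loop variable (unsharded_grad = grad_to_accumulate) never writes back into the list, so accumulator slots that start as None were silently left unfilled. A minimal sketch of the difference, using plain Python floats rather than the runner's tensors, with hypothetical helper names:

def accumulate_rebinding(acc, incoming):
    # Buggy pattern: `a = g` only rebinds the local name `a`;
    # a None slot in `acc` is never replaced.
    for a, g in zip(acc, incoming):
        if g is not None:
            if a is None:
                a = g       # lost: `acc` still holds None
            else:
                a += g      # only works for objects that mutate in place

def accumulate_indexed(acc, incoming):
    # Fixed pattern: indexing writes the result back into the list.
    for i in range(len(acc)):
        if incoming[i] is not None:
            if acc[i] is None:
                acc[i] = incoming[i]
            else:
                acc[i] += incoming[i]

acc1, acc2 = [None, 1.0], [None, 1.0]
grads = [2.0, 3.0]
accumulate_rebinding(acc1, grads)
accumulate_indexed(acc2, grads)
print(acc1)  # [None, 1.0] -- the None slot was never filled
print(acc2)  # [2.0, 4.0]  -- both slots updated

With tensors, the += branch already mutated the stored tensor in place, so only the None-slot case was affected; the indexed form handles both branches uniformly.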
@@ -374,6 +374,7 @@ def stage_full_backward(
     # next stage
     # TODO(sanketpurandare)
     # HACK till we have loss function, we populate the tangents here manually
+    assert len(stage_output) == 1
     bwd_kwargs = {
         "stage_output": loss,
         "tangents": [torch.ones_like(stage_output[0])],

examples/example_ds3_pp.py

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ def build_pipeline_schedule(
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
         backward_requires_autograd=backward_requires_autograd,
+        scale_grads=False,
     )
     logger.info(
         f"Using pipeline schedule {pipeline_parallel_schedule} "
