2 files changed: +7 −5 lines changed

@@ -190,12 +190,12 @@ def _accumulate_stage_grads(
 ) -> None:
     assert len(unsharded_grads) == len(grads_to_accumulate)
     assert not all(grad is None for grad in grads_to_accumulate), "All grads are None"
-    for unsharded_grad, grad_to_accumulate in zip(unsharded_grads, grads_to_accumulate):
-        if grad_to_accumulate is not None:
-            if unsharded_grad is None:
-                unsharded_grad = grad_to_accumulate
+    for i in range(len(unsharded_grads)):
+        if grads_to_accumulate[i] is not None:
+            if unsharded_grads[i] is None:
+                unsharded_grads[i] = grads_to_accumulate[i]
             else:
-                unsharded_grad += grad_to_accumulate
+                unsharded_grads[i] += grads_to_accumulate[i]


 def _run_forward_microbatch(
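A note on the loop change above (not part of the diff): assigning to a variable produced by zip only rebinds a local name, so the None-replacement case never made it back into the list; the in-place += on an existing tensor does mutate it, so only the None slots were silently lost. Indexing writes the update back into the list. A minimal sketch of the pitfall, using made-up tensors:

```python
import torch

# Illustrative only: not code from the PR.
unsharded_grads = [None, torch.ones(2)]
grads_to_accumulate = [torch.full((2,), 3.0), torch.full((2,), 5.0)]

# Old pattern: the assignment rebinds the loop variable; the list keeps None.
for unsharded_grad, grad_to_accumulate in zip(unsharded_grads, grads_to_accumulate):
    if grad_to_accumulate is not None and unsharded_grad is None:
        unsharded_grad = grad_to_accumulate
print(unsharded_grads[0])  # still None

# Fixed pattern: indexing mutates the list in place.
for i in range(len(unsharded_grads)):
    if grads_to_accumulate[i] is not None:
        if unsharded_grads[i] is None:
            unsharded_grads[i] = grads_to_accumulate[i]
        else:
            unsharded_grads[i] += grads_to_accumulate[i]
print(unsharded_grads[0])  # tensor([3., 3.])
```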
@@ -374,6 +374,7 @@ def stage_full_backward(
     # next stage
     # TODO(sanketpurandare)
     # HACK till we have loss function, we populate the tangents here manually
+    assert len(stage_output) == 1
     bwd_kwargs = {
         "stage_output": loss,
         "tangents": [torch.ones_like(stage_output[0])],
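For context (assumed, not stated in the diff): seeding backward with torch.ones_like(output) as the tangent is equivalent to backpropagating from output.sum(), which is why the hack works while there is no real loss function, and the new assert pins the assumption that the stage emits exactly one output tensor. A small sketch under those assumptions:

```python
import torch

# Hypothetical single-output stage; real outputs come from the stage's forward.
x = torch.randn(4, requires_grad=True)
stage_output = (x * 2.0,)
assert len(stage_output) == 1  # same guard the diff adds

# Seeding the backward with ones_like(...) as the tangent...
(grad_a,) = torch.autograd.grad(
    stage_output[0],
    x,
    grad_outputs=torch.ones_like(stage_output[0]),
    retain_graph=True,
)

# ...matches the gradient of stage_output[0].sum() w.r.t. x.
(grad_b,) = torch.autograd.grad(stage_output[0].sum(), x)
torch.testing.assert_close(grad_a, grad_b)
```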
@@ -93,6 +93,7 @@ def build_pipeline_schedule(
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
         backward_requires_autograd=backward_requires_autograd,
+        scale_grads=False,
     )
     logger.info(
         f"Using pipeline schedule {pipeline_parallel_schedule} "
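Assuming build_pipeline_schedule ultimately constructs a torch.distributed.pipelining schedule, scale_grads=True (the default there) divides the accumulated gradients by the number of microbatches; passing scale_grads=False leaves them as a plain sum, so any averaging has to happen elsewhere, for example inside the loss function. A hedged sketch of that alternative (loss_fn and the value of n_microbatches are illustrative, not from the PR):

```python
import torch.nn.functional as F

n_microbatches = 8  # illustrative value

# Hypothetical loss for use with scale_grads=False: fold 1/n_microbatches into
# the per-microbatch loss so the summed gradients still come out averaged.
def loss_fn(output, target):
    return F.cross_entropy(output, target) / n_microbatches
```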