@@ -90,6 +90,7 @@ def build_pipeline_schedule(
     local_batch_size: int,
     pipeline_parallel_degree: int,
     backward_requires_autograd: bool = False,
+    scale_grads: bool = True,
 ) -> _PipelineSchedule:
     """Builds a pipeline schedule for the given configuration and stages."""
     schedule_class = get_schedule_class(pipeline_parallel_schedule)
@@ -115,6 +116,7 @@ def build_pipeline_schedule(
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
         backward_requires_autograd=backward_requires_autograd,
+        scale_grads=scale_grads,
     )
     logger.info(
         f"Using pipeline schedule {pipeline_parallel_schedule} "
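Note on the new flag: in the stock torch.distributed.pipelining schedules, scale_grads controls whether accumulated gradients are divided by the number of microbatches at the end of a step. A rough standalone sketch of that effect (plain gradient accumulation, not this PR's pipeline code):

# Standalone sketch of what gradient scaling means here, not this PR's code.
import torch

def accumulate_grads(model, microbatches, loss_fn, scale_grads=True):
    model.zero_grad()
    for x, y in microbatches:
        loss_fn(model(x), y).backward()         # per-microbatch grads accumulate into .grad
    if scale_grads:
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(len(microbatches))  # mean over microbatches instead of raw sum

The determinism path later in this diff passes scale_grads=rng_seed is None, so the grads it logs are the raw sums.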
@@ -192,9 +194,9 @@ def run_test(
     # This is the spmd mesh to be used for tracing
     mesh = world_mesh[("dp_mod_ep", "ep")]

-    global_batch_size = 32 * dp_degree
     # Batch size that will be supplied to the schedule and will be broken down into microbatches
-    local_batch_size = global_batch_size // dp_degree
+    local_batch_size = 32
+    # global_batch_size = local_batch_size * dp_degree
     n_microbatches = 16
     # Batch size with which the spmd graphs will actually be executed
     microbatch_size = local_batch_size // n_microbatches
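For readability, the arithmetic these lines set up (the dp_degree value below is only an example; the test's real degree comes from the world size):

# Illustrative numbers only; dp_degree = 4 is an assumption, not the test's setting.
dp_degree = 4
local_batch_size = 32                                   # per data-parallel rank, fed to the schedule
n_microbatches = 16
microbatch_size = local_batch_size // n_microbatches    # 2: what each spmd graph executes
global_batch_size = local_batch_size * dp_degree        # 128: implied, no longer computed explicitly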
@@ -379,6 +381,31 @@ def last_stage_inp_with_loss_fn():
     for stages in pp_rank_to_stage_indices.values():
         assert len(stages) * pp_degree == len(virtual_pp_stages)
     stage_indices_current_pp_rank = pp_rank_to_stage_indices[pp_rank]
+    if rng_seed:
+        # Compute the ranks to log from
+        # 1. for fw_outs, log from coord [pp_rank_containing_last_stage, 0, 0]
+        last_stage_idx = total_pp_stages - 1
+        pp_rank_containing_last_stage = None
+        for pp_rank_, stage_indices in pp_rank_to_stage_indices.items():
+            if last_stage_idx in stage_indices:
+                assert pp_rank_containing_last_stage is None
+                pp_rank_containing_last_stage = pp_rank_
+
+        log_fw_out_rank_coordinate = []
+        for mesh_dim_name in world_mesh.mesh_dim_names:
+            if mesh_dim_name == "pp":
+                log_fw_out_rank_coordinate.append(pp_rank_containing_last_stage)
+            else:
+                log_fw_out_rank_coordinate.append(0)
+        should_log_fw_outs = world_mesh.get_coordinate() == log_fw_out_rank_coordinate
+
+        # 2. for weights, log from coords [:, 0, 0]
+        pp_world_size = world_mesh.shape[world_mesh._get_mesh_dim_by_name("pp")]
+        log_weights_rank_coordinates = [(i, 0, 0) for i in range(pp_world_size)]
+        should_log_weights = (
+            tuple(world_mesh.get_coordinate()) in log_weights_rank_coordinates
+        )
+
     stage_mods: dict[int, torch.nn.Module] = {}
     stage_graphs: dict[int, GraphCallables] = {}
     stage_graph_metas: dict[int, GraphMeta] = {}
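The coordinate checks above elect a single logging rank per role. A minimal sketch of the same selection logic using plain lists in place of a DeviceMesh (the mesh dim names, sizes, and coordinates here are assumptions for illustration):

# Illustration only: stand-ins for world_mesh, not the test's real mesh.
mesh_dim_names = ("pp", "dp_mod_ep", "ep")   # assumed 3-D mesh, matching the [pp, 0, 0] comments
my_coordinate = [1, 0, 0]                    # what world_mesh.get_coordinate() might return
pp_rank_containing_last_stage = 1            # pp rank owning the final virtual stage
pp_world_size = 2

# fw_outs: only the rank at [pp_rank_containing_last_stage, 0, 0] logs forward outputs
target = [pp_rank_containing_last_stage if d == "pp" else 0 for d in mesh_dim_names]
should_log_fw_outs = my_coordinate == target

# weights: one logging rank per pp rank, i.e. coordinates [:, 0, 0]
log_weights_rank_coordinates = [(i, 0, 0) for i in range(pp_world_size)]
should_log_weights = tuple(my_coordinate) in log_weights_rank_coordinates
print(should_log_fw_outs, should_log_weights)   # True True for this coordinate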
@@ -504,10 +531,6 @@ def last_stage_inp_with_loss_fn():

     world_size = torch.distributed.get_world_size()
     num_world_stages = world_size * len(stage_mods)
-    if rng_seed is not None:
-        NumericsLogger(logs_dir).log_pp_model_weights(
-            model, stage_mods, num_world_stages, ranks=[0, 4]
-        )

     stages = []
     # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -532,6 +555,7 @@ def last_stage_inp_with_loss_fn():
             group=world_mesh.get_group("pp"),
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
@@ -541,11 +565,37 @@ def last_stage_inp_with_loss_fn():
         local_batch_size=local_batch_size,
         pipeline_parallel_degree=pp_degree,
         backward_requires_autograd=False,
+        scale_grads=rng_seed is None,  # In determinism mode, don't scale grads
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
+
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, should_log=should_log_weights
+        )
+        torch.manual_seed(rng_seed)
+
+    def last_stage_forward_hook(
+        stage: GraphPipelineStage, action: str, output: torch.Tensor
+    ):
+        if not stage.is_last or rng_seed is None:
+            # hook is only for numerics mode
+            return
+
+        if should_log_fw_outs:
+            numerics_logger.log_diff(
+                output,
+                rank=torch.distributed.get_rank(),
+                prefix=f"mb{action.microbatch_index} fwd out",
+            )
+
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
+        FORWARD,
+        functools.partial(
+            stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
+        ),
     )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
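The FORWARD override binds the hook with functools.partial, so the registered callable keeps its original calling convention while the hook rides along. A rough sketch of that binding pattern with hypothetical stand-ins (fake_stage_forward and its signature are made up; the real stage_forward lives elsewhere in this repo and may differ):

# Hypothetical stand-ins showing the partial-binding pattern only.
import functools

def fake_stage_forward(stage, action, *, numerics_logs=None, forward_hook=None):
    output = stage.run(action)               # placeholder for the real forward execution
    if forward_hook is not None:
        forward_hook(stage, action, output)  # hook sees the stage, the action, and the output
    return output

bound = functools.partial(fake_stage_forward, numerics_logs=None, forward_hook=print)
# bound(stage, action) now behaves like the original callable but also invokes the hook.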
@@ -576,6 +626,10 @@ def last_stage_inp_with_loss_fn():
     )
     if pp_rank == 0:
         x = runtime_input_fn_first_stage()
+        if rng_seed:
+            numerics_logger.log_diff(
+                x.to(torch.float32), prefix="full batch input"
+            )
         graph_pp_runner.step(
             x, target=target, losses=losses, return_outputs=False
         )
@@ -590,6 +644,10 @@ def last_stage_inp_with_loss_fn():
             payload_fn=lambda: f"losses: {losses}",
         )

+    numerics_logger.log_pp_grads(
+        model, stage_mods, num_world_stages, should_log=should_log_weights
+    )
+
     print("All good!")

     if torch.distributed.is_initialized():