@@ -166,9 +166,9 @@ def run_test(
     # This is the spmd mesh to be used for tracing
     mesh = world_mesh[("dp_mod_ep", "ep")]
 
-    global_batch_size = 32 * dp_degree
     # Batch size that will be supplied to the schedule and will be broken down into microbatches
-    local_batch_size = global_batch_size // dp_degree
+    local_batch_size = 32
+    # global_batch_size = local_batch_size * dp_degree
     n_microbatches = 16
     # Batch size with which the spmd graphs will actually be executed
     microbatch_size = local_batch_size // n_microbatches
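
Note on the arithmetic in this hunk: the local (per-rank) batch size is now the fixed quantity, and the global batch size is only implied by the comment. A minimal sketch of the bookkeeping, using a made-up dp_degree rather than anything from the test:

# Illustrative values only; dp_degree here is hypothetical.
dp_degree = 2
local_batch_size = 32                                 # fixed per data-parallel rank
global_batch_size = local_batch_size * dp_degree      # 64, now only implied by the comment
n_microbatches = 16
microbatch_size = local_batch_size // n_microbatches  # 2 samples per traced graph execution
assert local_batch_size % n_microbatches == 0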
@@ -472,10 +472,6 @@ def last_stage_inp_with_loss_fn():
 
     world_size = torch.distributed.get_world_size()
     num_world_stages = world_size * len(stage_mods)
-    if rng_seed is not None:
-        NumericsLogger(logs_dir).log_pp_model_weights(
-            model, stage_mods, num_world_stages, ranks=[0, 4]
-        )
 
     stages = []
     # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -500,6 +496,7 @@ def last_stage_inp_with_loss_fn():
             group=world_mesh.get_group("pp"),
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
@@ -511,9 +508,32 @@ def last_stage_inp_with_loss_fn():
         backward_requires_autograd=False,
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
+
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, ranks=[0, 4]
+        )
+        torch.manual_seed(rng_seed)
+
+    def last_stage_forward_hook(
+        stage: GraphPipelineStage, action: str, output: torch.Tensor
+    ):
+        if not stage.is_last or rng_seed is None:
+            return
+
+        rank = torch.distributed.get_rank()
+        if rank == 4:
+            numerics_logger.log_diff(
+                output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
+            )
+
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
+        FORWARD,
+        functools.partial(
+            stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
+        ),
    )
    schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
    schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
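
The FORWARD override above threads an extra callback into the registered implementation via functools.partial. A self-contained toy of the same pattern, where ToyStage, toy_stage_forward, and log_last_stage are hypothetical stand-ins rather than the repo's GraphPipelineStage or stage_forward:

import functools
import torch

# Hypothetical stand-ins, only to show how partial binds an extra hook
# onto the callable the schedule will invoke.
class ToyStage:
    def __init__(self, is_last: bool):
        self.is_last = is_last

    def forward_one_chunk(self, microbatch: torch.Tensor) -> torch.Tensor:
        return microbatch * 2  # placeholder compute

def toy_stage_forward(stage, microbatch, *, numerics_logs=None, forward_hook=None):
    output = stage.forward_one_chunk(microbatch)
    if forward_hook is not None:
        forward_hook(stage, microbatch, output)  # observe the output, do not alter it
    return output

def log_last_stage(stage, microbatch, output):
    if stage.is_last:
        print(f"last-stage fwd out norm: {output.norm().item():.6f}")

# The schedule only sees a single callable, so extra keyword-only hooks
# are bound in ahead of time with functools.partial.
forward_impl = functools.partial(
    toy_stage_forward, numerics_logs=None, forward_hook=log_last_stage
)
forward_impl(ToyStage(is_last=True), torch.ones(4))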
@@ -542,6 +562,10 @@ def last_stage_inp_with_loss_fn():
     )
     if pp_rank == 0:
         x = runtime_input_fn_first_stage()
+        if rng_seed:
+            numerics_logger.log_diff(
+                x.to(torch.float32), prefix="full batch input"
+            )
         graph_pp_runner.step(
             x, target=target, losses=losses, return_outputs=False
         )
@@ -556,6 +580,8 @@ def last_stage_inp_with_loss_fn():
         payload_fn=lambda: f"losses: {losses}",
     )
 
+    numerics_logger.log_pp_grads(model, stage_mods, num_world_stages, ranks=[0, 4])
+
     print("All good!")
 
     if torch.distributed.is_initialized():
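
Taken together, the NumericsLogger calls bracket the run: model weights before the schedule executes, per-microbatch forward outputs on the last stage (rank 4), the full-batch input on pp_rank 0, and the gradients after stepping. As a rough, hypothetical stand-in for what one of these tensor summaries could look like (the real NumericsLogger lives in this repo and may log something entirely different):

import hashlib
import torch

def summarize_tensor(t: torch.Tensor, prefix: str = "") -> str:
    # Hypothetical helper: a norm plus a short content hash is enough to
    # diff two runs rank-by-rank without dumping full tensors.
    data = t.detach().to(torch.float32).cpu()
    digest = hashlib.sha256(data.numpy().tobytes()).hexdigest()[:12]
    return f"{prefix} norm={data.norm().item():.6e} sha={digest}"

print(summarize_tensor(torch.randn(8, 8), prefix="mb0 fwd out"))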