@@ -89,7 +89,7 @@ def build_pipeline_schedule(
 def run_test(fake_evaluate: bool = True):
     if not fake_evaluate:
         pp_degree = 2
-        dp_mod_ep_degree = 1
+        dp_mod_ep_degree = 2
         ep_degree = 2
     else:
         pp_degree = 4
@@ -123,7 +123,7 @@ def run_test(fake_evaluate: bool = True):
     ), "run with torchrun --standalone --nproc-per-node 8"
     assert (
         int(os.getenv("WORLD_SIZE")) == world_size
-    ), "Need at least 4 GPUs for real evaluation"
+    ), "Need at least 8 GPUs for real evaluation"
     local_rank = int(os.getenv("LOCAL_RANK"))
     device = torch.device(f"cuda:{local_rank}")
     torch.distributed.init_process_group(backend="nccl")
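
With `dp_mod_ep_degree` raised from 1 to 2, the real-evaluation mesh needs pp_degree * dp_mod_ep_degree * ep_degree = 2 * 2 * 2 = 8 ranks, which is why the assertion message changes from 4 to 8 GPUs and now agrees with the `torchrun --standalone --nproc-per-node 8` launch. A minimal sketch of that arithmetic, assuming the world size is the product of the three degrees (the full computation sits outside this diff):

    # Minimal sketch; assumes world_size = pp * dp_mod_ep * ep, which the
    # surrounding assertions imply but this diff does not show in full.
    pp_degree, dp_mod_ep_degree, ep_degree = 2, 2, 2
    world_size = pp_degree * dp_mod_ep_degree * ep_degree
    assert world_size == 8  # matches "--nproc-per-node 8"
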
@@ -280,52 +280,19 @@ def shape_inference_output_fn_last_stage():
             requires_grad=True,
         )
 
-    if fake_evaluate:
-        # Step 1. Construct the logical pipeline stages
-        with torch.device("meta"):
-            stage0 = DeepSeekV3Stage0(embed, layers[0], config)
-            stage1 = DeepSeekV3StageI(layers[1], config)
-            stage2 = DeepSeekV3StageI(layers[2], config)
-            stage3 = DeepSeekV3StageI(layers[3], config)
-            stage4 = DeepSeekV3StageI(layers[4], config)
-            stage5 = DeepSeekV3StageI(layers[5], config)
-            stage6 = DeepSeekV3StageI(layers[6], config)
-            stage7 = DeepSeekV3StageN(layers[7], norm, output, config)
-        virtual_pp_stages = [
-            stage0,
-            stage1,
-            stage2,
-            stage3,
-            stage4,
-            stage5,
-            stage6,
-            stage7,
-        ]
-        # Step 2. Assign each logical stage(s) to pp ranks for Interleaved1F1B schedule
-        pp_rank_to_stage_indices: dict[int, list[int]] = {
-            0: [0, 4],
-            1: [1, 5],
-            2: [2, 6],
-            3: [3, 7],
-        }
-    else:
-        # Step 1. Construct the logical pipeline stages
-        with torch.device("meta"):
-            stage0 = DeepSeekV3Stage0(embed, layers[0], config)
-            stage1 = DeepSeekV3StageI(layers[1], config)
-            stage2 = DeepSeekV3StageI(layers[2], config)
-            stage3 = DeepSeekV3StageN(layers[3], norm, output, config)
-        virtual_pp_stages = [
-            stage0,
-            stage1,
-            stage2,
-            stage3,
-        ]
-        # Step 2. Assign each logical stage(s) to pp ranks for Interleaved1F1B schedule
-        pp_rank_to_stage_indices: dict[int, list[int]] = {
-            0: [0, 2],
-            1: [1, 3],
-        }
+    # Step 1. Construct the logical pipeline stages
+    with torch.device("meta"):
+        virtual_pp_stages = [DeepSeekV3Stage0(embed, layers[0], config)]
+        for i in range(1, total_pp_stages - 1):
+            virtual_pp_stages.append(DeepSeekV3StageI(layers[i], config))
+        virtual_pp_stages.append(
+            DeepSeekV3StageN(layers[total_pp_stages - 1], norm, output, config)
+        )
+    # Step 2. Assign each logical stage(s) to pp ranks for Interleaved1F1B schedule
+    pp_rank_to_stage_indices: dict[int, list[int]] = {
+        rank: [rank + i * pp_degree for i in range(stages_per_rank)]
+        for rank in range(pp_degree)
+    }
     assert len(pp_rank_to_stage_indices) == pp_degree
     for stages in pp_rank_to_stage_indices.values():
         assert len(stages) * pp_degree == len(virtual_pp_stages)
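
The comprehension reproduces exactly the two hard-coded mappings it replaces. A quick standalone check, assuming `stages_per_rank` equals `len(virtual_pp_stages) // pp_degree` (consistent with the assert just above):

    # Standalone check of the interleaved assignment; stage_map is a
    # hypothetical helper, not part of the patched file.
    def stage_map(pp_degree: int, total_stages: int) -> dict[int, list[int]]:
        stages_per_rank = total_stages // pp_degree
        return {
            rank: [rank + i * pp_degree for i in range(stages_per_rank)]
            for rank in range(pp_degree)
        }

    assert stage_map(4, 8) == {0: [0, 4], 1: [1, 5], 2: [2, 6], 3: [3, 7]}  # fake path
    assert stage_map(2, 4) == {0: [0, 2], 1: [1, 3]}  # real-evaluation path
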
@@ -444,7 +411,7 @@ def shape_inference_output_fn_last_stage():
         ),
         output_args=(
             shape_inference_output_fn_last_stage()
-            if pp_stage_idx == 7
+            if pp_stage_idx == (len(virtual_pp_stages) - 1)
             else shape_inference_input_fn_after_first_stage()
         ),
         group=world_mesh.get_group("pp"),
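
The same generalization lands here: the shape-inference selection previously hard-coded stage index 7, which only held for the 8-stage fake path. Deriving the last index as `len(virtual_pp_stages) - 1` makes the output-shape branch fire at index 3 in the 4-stage real-evaluation layout as well.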