Skip to content

Commit 83789ef

Browse files
author
Sanket Jayant Purandare
committed
Unify code path for different pp_degree specs
1 parent 6bd37a6 commit 83789ef

File tree

1 file changed

+16
-49
lines changed

1 file changed

+16
-49
lines changed

examples/example_ds3_pp.py

Lines changed: 16 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def build_pipeline_schedule(
8989
def run_test(fake_evaluate: bool = True):
9090
if not fake_evaluate:
9191
pp_degree = 2
92-
dp_mod_ep_degree = 1
92+
dp_mod_ep_degree = 2
9393
ep_degree = 2
9494
else:
9595
pp_degree = 4
@@ -123,7 +123,7 @@ def run_test(fake_evaluate: bool = True):
123123
), "run with torchrun --standalone --nproc-per-node 8"
124124
assert (
125125
int(os.getenv("WORLD_SIZE")) == world_size
126-
), "Need at least 4 GPUs for real evaluation"
126+
), "Need at least 8 GPUs for real evaluation"
127127
local_rank = int(os.getenv("LOCAL_RANK"))
128128
device = torch.device(f"cuda:{local_rank}")
129129
torch.distributed.init_process_group(backend="nccl")
@@ -280,52 +280,19 @@ def shape_inference_output_fn_last_stage():
280280
requires_grad=True,
281281
)
282282

283-
if fake_evaluate:
284-
# Step 1. Construct the logical pipeline stages
285-
with torch.device("meta"):
286-
stage0 = DeepSeekV3Stage0(embed, layers[0], config)
287-
stage1 = DeepSeekV3StageI(layers[1], config)
288-
stage2 = DeepSeekV3StageI(layers[2], config)
289-
stage3 = DeepSeekV3StageI(layers[3], config)
290-
stage4 = DeepSeekV3StageI(layers[4], config)
291-
stage5 = DeepSeekV3StageI(layers[5], config)
292-
stage6 = DeepSeekV3StageI(layers[6], config)
293-
stage7 = DeepSeekV3StageN(layers[7], norm, output, config)
294-
virtual_pp_stages = [
295-
stage0,
296-
stage1,
297-
stage2,
298-
stage3,
299-
stage4,
300-
stage5,
301-
stage6,
302-
stage7,
303-
]
304-
# Step 2. Assign each logical stage(s) to pp ranks for Interleaved1F1B schedule
305-
pp_rank_to_stage_indices: dict[int, list[int]] = {
306-
0: [0, 4],
307-
1: [1, 5],
308-
2: [2, 6],
309-
3: [3, 7],
310-
}
311-
else:
312-
# Step 1. Construct the logical pipeline stages
313-
with torch.device("meta"):
314-
stage0 = DeepSeekV3Stage0(embed, layers[0], config)
315-
stage1 = DeepSeekV3StageI(layers[1], config)
316-
stage2 = DeepSeekV3StageI(layers[2], config)
317-
stage3 = DeepSeekV3StageN(layers[3], norm, output, config)
318-
virtual_pp_stages = [
319-
stage0,
320-
stage1,
321-
stage2,
322-
stage3,
323-
]
324-
# Step 2. Assign each logical stage(s) to pp ranks for Interleaved1F1B schedule
325-
pp_rank_to_stage_indices: dict[int, list[int]] = {
326-
0: [0, 2],
327-
1: [1, 3],
328-
}
283+
# Step 1. Construct the logical pipeline stages
284+
with torch.device("meta"):
285+
virtual_pp_stages = [DeepSeekV3Stage0(embed, layers[0], config)]
286+
for i in range(1, total_pp_stages - 1):
287+
virtual_pp_stages.append(DeepSeekV3StageI(layers[i], config))
288+
virtual_pp_stages.append(
289+
DeepSeekV3StageN(layers[total_pp_stages - 1], norm, output, config)
290+
)
291+
# Step 2. Assign each logical stage(s) to pp ranks for Interleaved1F1B schedule
292+
pp_rank_to_stage_indices: dict[int, list[int]] = {
293+
rank: [rank + i * pp_degree for i in range(stages_per_rank)]
294+
for rank in range(pp_degree)
295+
}
329296
assert len(pp_rank_to_stage_indices) == pp_degree
330297
for stages in pp_rank_to_stage_indices.values():
331298
assert len(stages) * pp_degree == len(virtual_pp_stages)
@@ -444,7 +411,7 @@ def shape_inference_output_fn_last_stage():
444411
),
445412
output_args=(
446413
shape_inference_output_fn_last_stage()
447-
if pp_stage_idx == 7
414+
if pp_stage_idx == (len(virtual_pp_stages) - 1)
448415
else shape_inference_input_fn_after_first_stage()
449416
),
450417
group=world_mesh.get_group("pp"),

0 commit comments

Comments (0)