Commit 6bd37a6

Author: Sanket Jayant Purandare (committed)

Enabling real PP run on 4 GPUs
1 parent 9aebf3b commit 6bd37a6

File tree

3 files changed: +194 −180 lines changed


autoparallel/_testing/models/dsv3.py

Lines changed: 31 additions & 20 deletions
@@ -272,9 +272,9 @@ def wrapper(
 
     experts_per_ep_rank = w1.shape[0]
     num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-    assert (
-        num_ep_ranks == 64
-    ), f"{num_ep_ranks}, {experts_per_ep_rank}, num_tokens_per_expert.shape: {num_tokens_per_expert.shape}, x={x.ndim}, w={w1.shape}"
+    # assert (
+    #     num_ep_ranks == 64
+    # ), f"{num_ep_ranks}, {experts_per_ep_rank}, num_tokens_per_expert.shape: {num_tokens_per_expert.shape}, x={x.ndim}, w={w1.shape}"
 
     # Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
     # by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
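The check removed here hardwired the expert-parallel world size to 64 ranks, which cannot hold on the 4-GPU run this commit targets, so it is commented out rather than generalized. If a guard is still wanted, a hypothetical mesh-agnostic sketch using only names already in scope would be:

    # Hypothetical replacement: checks divisibility instead of a fixed rank count,
    # so it holds on any EP world size (including the 4-GPU run).
    assert num_tokens_per_expert.shape[0] % experts_per_ep_rank == 0, (
        f"num_tokens_per_expert ({num_tokens_per_expert.shape[0]}) is not a "
        f"multiple of experts_per_ep_rank ({experts_per_ep_rank})"
    )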
@@ -680,13 +680,12 @@ def local_mapped_region(
     experts_w3: torch.Tensor,
     experts_w2: torch.Tensor,
     out: torch.Tensor,
+    top_k: int,
+    num_experts: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     axis_name = "ep"
     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
 
-    top_k = 6
-    num_experts = 64
-
     dim = x.shape[-1]
 
     # num_tokens_per_expert = torch.ops.autoparallel.batched_histc(
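Hardcoding `top_k = 6` and `num_experts = 64` inside the mapped region pinned it to one model configuration; making them parameters lets the same region serve smaller debug configs as well (for instance a `num_experts=8, top_k=3` MoE setup like the one in the example script below). The call-site hunk further down supplies `router.top_k` and `router.num_experts`.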
@@ -850,18 +849,18 @@ def _(
 
 
 def _moe_forward(
-    x,
-    router_gate_weight,
-    expert_bias,
-    experts_w1,
-    experts_w3,
-    experts_w2,
-    shared_w1,
-    shared_w3,
-    shared_w2,
-    router,  # None
-    reorderer,  # None
-    mesh,  # None
+    x: torch.Tensor,
+    router_gate_weight: torch.Tensor,
+    expert_bias: Optional[torch.Tensor],
+    experts_w1: torch.Tensor,
+    experts_w3: torch.Tensor,
+    experts_w2: torch.Tensor,
+    shared_w1: torch.Tensor,
+    shared_w3: torch.Tensor,
+    shared_w2: torch.Tensor,
+    router: TokenChoiceTopKRouter,  # None
+    reorderer: TokenReorderer,  # None
+    mesh: Optional[DeviceMesh],  # None
 ):
     # x: 64, 2048, 256
     bs, slen, dim = x.shape
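The new signature references `Optional` and `DeviceMesh` (the router and reorderer classes appear to live in this module already). A sketch of the imports the annotations imply, assuming they are not already present in dsv3.py:

    # Likely imports for the new annotations (assumption: not already imported).
    from typing import Optional

    from torch.distributed.device_mesh import DeviceMesh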
@@ -926,6 +925,8 @@ def _moe_forward(
         (Replicate(), Shard(0)),
         (Replicate(), Shard(0)),
         (Shard(0), Shard(0)),
+        None,
+        None,
     )
 
     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
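In `local_map`-style wrappers, the `in_placements` tuple pairs positionally with the mapped function's arguments, and a `None` entry marks an argument that is not a DTensor and should pass through untouched; the two `None`s added here correspond to the new `top_k` and `num_experts` ints. A minimal sketch of the pattern, assuming the API shape of `torch.distributed.tensor.experimental.local_map` and an already-initialized 1-D mesh named `mesh` (toy function and placements, not the ones in this file):

    from torch.distributed.tensor import Shard
    from torch.distributed.tensor.experimental import local_map

    def region(t, scale: int):  # toy mapped region: one tensor, one plain int
        return t * scale

    wrapped = local_map(
        region,
        out_placements=(Shard(0),),
        in_placements=(
            (Shard(0),),  # placement for `t`
            None,         # `scale` is a plain int, so no placement
        ),
        device_mesh=mesh,  # assumes `mesh` built via init_device_mesh
    )
    # wrapped(dt, 2) then runs `region` on each rank's local shard of `dt`.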
@@ -937,7 +938,17 @@ def _moe_forward(
         redistribute_inputs=True,
         in_grad_placements=None,
         device_mesh=mesh,
-    )(selected_experts_indices, top_scores, x, experts_w1, experts_w3, experts_w2, out)
+    )(
+        selected_experts_indices,
+        top_scores,
+        x,
+        experts_w1,
+        experts_w3,
+        experts_w2,
+        out,
+        router.top_k,
+        router.num_experts,
+    )
     # assert False, f"there: {out.shape}, {num_tokens_per_expert.shape}"
 
     ######################################################
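After this reshaping the positional pairing is easy to audit: `out`, `router.top_k`, and `router.num_experts` are the last three arguments, lining up with the `(Shard(0), Shard(0))` entry and the two trailing `None`s added to `in_placements` above.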
@@ -1046,7 +1057,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             shared_w2,
             self.router,  # None
             self.reorderer,  # None
-            self.mesh,
+            self.mesh,  # None
         )
 
         # HOPs don't support buffer mutations, keep this outside

examples/example_ds3_local_map.py

Lines changed: 0 additions & 42 deletions
@@ -36,50 +36,8 @@
     ),
 )
 
-
-seq_len = 1024
-
-config = DeepSeekV3ModelArgs(
-    vocab_size=2048,
-    max_seq_len=seq_len,
-    dim=256,
-    inter_dim=1024,
-    moe_inter_dim=256,
-    n_layers=1,  # 6,
-    n_dense_layers=0,  # 1,
-    n_heads=16,
-    moe_args=MoEArgs(
-        num_experts=8,
-        num_shared_experts=2,
-        top_k=3,
-        score_func="softmax",
-        route_norm=False,
-        score_before_experts=False,
-        mesh=mesh,
-    ),
-    q_lora_rank=0,
-    kv_lora_rank=512,
-    qk_nope_head_dim=128,
-    qk_rope_head_dim=64,
-    v_head_dim=128,
-    mscale=0.70,
-)
-
 device = torch.device("cuda")
 
-if False:
-    model = DeepSeekV3Model(config)
-    model.to(device)
-
-    global_batch_size = 2
-
-    x = torch.randint(
-        0,
-        config.vocab_size,
-        (global_batch_size, seq_len),
-        device=device,
-    )
-    o = model(x)
 
 bs = 4 * mesh.shape[0] * mesh.shape[1]
 seq_len = 1024
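The deletion leaves a single configuration path: the stale duplicate `DeepSeekV3ModelArgs` block and the dead `if False:` smoke test are gone, and batch size is derived from the mesh instead. As a quick check, on the commit's 4-GPU run any 2-D factorization of the mesh gives `mesh.shape[0] * mesh.shape[1] == 4`, so:

    # bs is mesh-size-dependent: with 4 GPUs in a (2, 2) or (4, 1) mesh,
    # 4 * mesh.shape[0] * mesh.shape[1] == 4 * 4 == 16.
    bs = 4 * mesh.shape[0] * mesh.shape[1]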
