@@ -272,9 +272,9 @@ def wrapper(
 
     experts_per_ep_rank = w1.shape[0]
     num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-    assert (
-        num_ep_ranks == 64
-    ), f"{num_ep_ranks}, {experts_per_ep_rank}, num_tokens_per_expert.shape: {num_tokens_per_expert.shape}, x={x.ndim}, w={w1.shape}"
+    # assert (
+    #     num_ep_ranks == 64
+    # ), f"{num_ep_ranks}, {experts_per_ep_rank}, num_tokens_per_expert.shape: {num_tokens_per_expert.shape}, x={x.ndim}, w={w1.shape}"
 
     # Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
     # by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
@@ -680,13 +680,12 @@ def local_mapped_region(
     experts_w3: torch.Tensor,
     experts_w2: torch.Tensor,
     out: torch.Tensor,
+    top_k: int,
+    num_experts: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     axis_name = "ep"
     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
 
-    top_k = 6
-    num_experts = 64
-
     dim = x.shape[-1]
 
     # num_tokens_per_expert = torch.ops.autoparallel.batched_histc(
@@ -850,18 +849,18 @@ def _(
 
 
 def _moe_forward(
-    x,
-    router_gate_weight,
-    expert_bias,
-    experts_w1,
-    experts_w3,
-    experts_w2,
-    shared_w1,
-    shared_w3,
-    shared_w2,
-    router,  # None
-    reorderer,  # None
-    mesh,  # None
+    x: torch.Tensor,
+    router_gate_weight: torch.Tensor,
+    expert_bias: Optional[torch.Tensor],
+    experts_w1: torch.Tensor,
+    experts_w3: torch.Tensor,
+    experts_w2: torch.Tensor,
+    shared_w1: torch.Tensor,
+    shared_w3: torch.Tensor,
+    shared_w2: torch.Tensor,
+    router: TokenChoiceTopKRouter,  # None
+    reorderer: TokenReorderer,  # None
+    mesh: Optional[DeviceMesh],  # None
 ):
     # x: 64, 2048, 256
     bs, slen, dim = x.shape
@@ -926,6 +925,8 @@ def _moe_forward(
         (Replicate(), Shard(0)),
         (Replicate(), Shard(0)),
         (Shard(0), Shard(0)),
+        None,
+        None,
     )
 
     # assert False, f"{x.shape}, {selected_experts_indices.shape}, {top_scores.shape}, {out.shape}"
@@ -937,7 +938,17 @@ def _moe_forward(
         redistribute_inputs=True,
         in_grad_placements=None,
         device_mesh=mesh,
-    )(selected_experts_indices, top_scores, x, experts_w1, experts_w3, experts_w2, out)
+    )(
+        selected_experts_indices,
+        top_scores,
+        x,
+        experts_w1,
+        experts_w3,
+        experts_w2,
+        out,
+        router.top_k,
+        router.num_experts,
+    )
     # assert False, f"there: {out.shape}, {num_tokens_per_expert.shape}"
 
     ######################################################
@@ -1046,7 +1057,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             shared_w2,
             self.router,  # None
             self.reorderer,  # None
-            self.mesh,
+            self.mesh,  # None
         )
 
         # HOPs don't support buffer mutations, keep this outside
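
Note on the local_map change above: the scalar top_k and num_experts are now threaded through as plain Python ints, each paired with a None entry in in_placements so local_map forwards them to the local region unchanged while the tensor arguments are converted to their local shards. Below is a minimal, self-contained sketch of that pattern, assuming the torch.distributed.tensor.experimental.local_map API and a single-rank CPU/gloo mesh so it can run as a plain script; the function scale_rows and its scale argument are made-up names for illustration, not part of this diff.

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor.experimental import local_map

# Single-rank setup purely for illustration (assumption, not from this diff).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))


def scale_rows(x: torch.Tensor, scale: int) -> torch.Tensor:
    # Runs on each rank's local shard; `scale` arrives as an ordinary int.
    return x * scale


wrapped = local_map(
    scale_rows,
    out_placements=[Shard(0)],
    # One entry per positional arg: a placement list for the tensor,
    # None for the non-tensor scalar, which is passed through as-is
    # (same idea as the None, None entries added for top_k / num_experts).
    in_placements=([Shard(0)], None),
    device_mesh=mesh,
)

x = distribute_tensor(torch.ones(4, 2), mesh, placements=[Shard(0)])
y = wrapped(x, 3)  # the scalar is threaded through alongside the DTensor
print(y)

dist.destroy_process_group()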