@@ -659,11 +659,6 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
     hidden_states = payloads[0]
     token_final_scales = payloads[2]

-    # Create experts for this rank
-    experts = create_experts(
-        num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16
-    )
-
     # Compute reference (single-GPU MoE)
     all_experts = torch.cat(
         [
@@ -675,14 +670,16 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
         dim=0,
     )

+    rank_experts = create_experts(
+        num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16
+    )
+
     reference_output = fake_moe(
         hidden_states,
         token_selected_experts,
         token_final_scales,
         all_experts,
-        is_ep=True,
-        ep_rank=rank,
-        num_experts_per_rank=num_experts_per_rank,
+        is_ep=False,
     )

     torch.cuda.synchronize()
@@ -717,21 +714,22 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
     moe_output.zero_()

     # Process each rank's tokens with local experts
-    for source_rank in range(ep_size):
-        source_num_tokens = all_num_tokens[source_rank]
-        for token_idx in range(source_num_tokens):
-            for k in range(top_k):
-                expert_id = token_selected_experts_recv[
-                    source_rank, token_idx, k
-                ].item()
-                local_expert_id = expert_id - rank * num_experts_per_rank
-
-                if 0 <= local_expert_id < num_experts_per_rank:
-                    token_hidden = hidden_states_recv[source_rank, token_idx]
-                    scale = token_final_scales_recv[source_rank, token_idx, k]
-                    expert_out = token_hidden @ experts[local_expert_id]
-                    output_idx = source_rank * max_num_tokens + token_idx
-                    moe_output[output_idx] += expert_out * scale
+    print(
+        f"hidden_states_recv.shape: {hidden_states_recv.shape}, token_selected_experts_recv.shape: {token_selected_experts_recv.shape}, token_final_scales_recv.shape: {token_final_scales_recv.shape}, rank_experts.shape: {rank_experts.shape}, moe_output.shape: {moe_output.shape}"
+    )
+    moe_output[rank] = fake_moe(
+        hidden_states_recv.view(ep_size * max_num_tokens, hidden_states_recv.shape[-1]),
+        token_selected_experts_recv.view(
+            ep_size * max_num_tokens, token_selected_experts_recv.shape[-1]
+        ),
+        token_final_scales_recv.view(
+            ep_size * max_num_tokens, token_final_scales_recv.shape[-1]
+        ),
+        rank_experts,  # experts for current rank
+        is_ep=True,
+        ep_rank=rank,
+        num_experts_per_rank=num_experts_per_rank,
+    ).view(ep_size, max_num_tokens, hidden_states_recv.shape[-1])

     # Combine
     combined_output = moe_a2a.combine(
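
For reference, a minimal sketch of what fake_moe is assumed to compute here. The helper itself is not shown in this hunk, so the body below is inferred from the call sites above and from the per-token loop this commit removes: with is_ep=False every selected expert is applied; with is_ep=True only experts owned by ep_rank contribute, which is why the flattened receive buffers can be passed through it directly.

# Hedged sketch, not the actual test helper: fake_moe's body is assumed from the
# call sites and the removed reference loop (experts[i] is a [hidden, hidden] matrix).
import torch

def fake_moe(
    hidden_states,           # [num_tokens, hidden_size]
    token_selected_experts,  # [num_tokens, top_k] global expert ids
    token_final_scales,      # [num_tokens, top_k]
    experts,                 # [num_experts or num_experts_per_rank, hidden, hidden]
    is_ep=False,
    ep_rank=0,
    num_experts_per_rank=None,
):
    output = torch.zeros_like(hidden_states)
    num_tokens, top_k = token_selected_experts.shape
    for token_idx in range(num_tokens):
        for k in range(top_k):
            expert_id = int(token_selected_experts[token_idx, k])
            if is_ep:
                # Only experts owned by this rank contribute; others are skipped.
                local_id = expert_id - ep_rank * num_experts_per_rank
                if not (0 <= local_id < num_experts_per_rank):
                    continue
            else:
                local_id = expert_id
            scale = token_final_scales[token_idx, k]
            output[token_idx] += (hidden_states[token_idx] @ experts[local_id]) * scale
    return output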