@@ -243,6 +243,7 @@ def run_moe_a2a_dispatch_single_rank(
243243 # Create MoeAlltoAll manager
244244 max_num_tokens = max (all_num_tokens )
245245
246+ MoeAlltoAll ._WORKSPACE = None
246247 moe_a2a = MoeAlltoAll (
247248 mapping ,
248249 max_num_tokens ,
@@ -685,6 +686,7 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
685686 torch .cuda .synchronize ()
686687
687688 # Initialize MoeAlltoAll
689+ MoeAlltoAll ._WORKSPACE = None
688690 moe_a2a = MoeAlltoAll (
689691 mapping = mapping ,
690692 max_num_tokens = max_num_tokens ,
@@ -714,26 +716,27 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
714716 moe_output .zero_ ()
715717
716718 # Process each rank's tokens with local experts
717- print (
718- f"hidden_states_recv.shape: { hidden_states_recv .shape } , token_selected_experts_recv.shape: { token_selected_experts_recv .shape } , token_final_scales_recv.shape: { token_final_scales_recv .shape } , rank_experts.shape: { rank_experts .shape } , moe_output.shape: { moe_output .shape } "
719+ moe_output .copy_ (
720+ fake_moe (
721+ hidden_states_recv .view (
722+ ep_size * max_num_tokens , hidden_states_recv .shape [- 1 ]
723+ ),
724+ token_selected_experts_recv .view (
725+ ep_size * max_num_tokens , token_selected_experts_recv .shape [- 1 ]
726+ ),
727+ token_final_scales_recv .view (
728+ ep_size * max_num_tokens , token_final_scales_recv .shape [- 1 ]
729+ ),
730+ rank_experts , # experts for current rank
731+ is_ep = True ,
732+ ep_rank = rank ,
733+ num_experts_per_rank = num_experts_per_rank ,
734+ ).view (ep_size , max_num_tokens , hidden_size )
719735 )
720- moe_output [rank ] = fake_moe (
721- hidden_states_recv .view (ep_size * max_num_tokens , hidden_states_recv .shape [- 1 ]),
722- token_selected_experts_recv .view (
723- ep_size * max_num_tokens , token_selected_experts_recv .shape [- 1 ]
724- ),
725- token_final_scales_recv .view (
726- ep_size * max_num_tokens , token_final_scales_recv .shape [- 1 ]
727- ),
728- rank_experts , # experts for current rank
729- is_ep = True ,
730- ep_rank = rank ,
731- num_experts_per_rank = num_experts_per_rank ,
732- ).view (ep_size , max_num_tokens , hidden_states_recv .shape [- 1 ])
733736
734737 # Combine
735738 combined_output = moe_a2a .combine (
736- payload = moe_output . view ( ep_size , max_num_tokens , hidden_size ) ,
739+ payload = moe_output ,
737740 runtime_max_tokens_per_rank = max_num_tokens ,
738741 payload_in_workspace = True ,
739742 )
0 commit comments