
Commit fede35b

Clear MOE workspace before each run
1 parent 8220b9e commit fede35b

1 file changed

tests/comm/test_mnnvl_a2a.py

Lines changed: 19 additions & 16 deletions
@@ -243,6 +243,7 @@ def run_moe_a2a_dispatch_single_rank(
     # Create MoeAlltoAll manager
     max_num_tokens = max(all_num_tokens)
 
+    MoeAlltoAll._WORKSPACE = None
     moe_a2a = MoeAlltoAll(
         mapping,
         max_num_tokens,
@@ -685,6 +686,7 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
     torch.cuda.synchronize()
 
     # Initialize MoeAlltoAll
+    MoeAlltoAll._WORKSPACE = None
     moe_a2a = MoeAlltoAll(
         mapping=mapping,
         max_num_tokens=max_num_tokens,
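Both hunks above reset MoeAlltoAll._WORKSPACE before constructing a new manager. The pattern implied by the diff is a class-level workspace cache that survives across instances in the same process, so a parametrized test that varies sizes between runs would otherwise reuse a buffer allocated for an earlier parametrization. A minimal sketch of that pattern, assuming a lazily allocated class attribute (the _WORKSPACE name comes from this diff; the class body, sizing, and constructor arguments are hypothetical, not the actual MoeAlltoAll implementation):

import torch

class MoeAlltoAllSketch:
    # Class-level cache shared by every instance in the process,
    # mirroring the _WORKSPACE attribute reset in this commit.
    _WORKSPACE = None

    def __init__(self, ep_size, max_num_tokens, hidden_size):
        if MoeAlltoAllSketch._WORKSPACE is None:
            # Allocated once per process; later instances reuse it even
            # when constructed with different sizes (hypothetical sizing).
            device = "cuda" if torch.cuda.is_available() else "cpu"
            MoeAlltoAllSketch._WORKSPACE = torch.empty(
                ep_size, max_num_tokens, hidden_size, device=device
            )
        self.workspace = MoeAlltoAllSketch._WORKSPACE

# Each test run clears the cache first, so the manager allocates a
# workspace matching the current parametrization:
MoeAlltoAllSketch._WORKSPACE = None
moe_a2a = MoeAlltoAllSketch(ep_size=4, max_num_tokens=8, hidden_size=16)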
@@ -714,26 +716,27 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
     moe_output.zero_()
 
     # Process each rank's tokens with local experts
-    print(
-        f"hidden_states_recv.shape: {hidden_states_recv.shape}, token_selected_experts_recv.shape: {token_selected_experts_recv.shape}, token_final_scales_recv.shape: {token_final_scales_recv.shape}, rank_experts.shape: {rank_experts.shape}, moe_output.shape: {moe_output.shape}"
+    moe_output.copy_(
+        fake_moe(
+            hidden_states_recv.view(
+                ep_size * max_num_tokens, hidden_states_recv.shape[-1]
+            ),
+            token_selected_experts_recv.view(
+                ep_size * max_num_tokens, token_selected_experts_recv.shape[-1]
+            ),
+            token_final_scales_recv.view(
+                ep_size * max_num_tokens, token_final_scales_recv.shape[-1]
+            ),
+            rank_experts,  # experts for current rank
+            is_ep=True,
+            ep_rank=rank,
+            num_experts_per_rank=num_experts_per_rank,
+        ).view(ep_size, max_num_tokens, hidden_size)
     )
-    moe_output[rank] = fake_moe(
-        hidden_states_recv.view(ep_size * max_num_tokens, hidden_states_recv.shape[-1]),
-        token_selected_experts_recv.view(
-            ep_size * max_num_tokens, token_selected_experts_recv.shape[-1]
-        ),
-        token_final_scales_recv.view(
-            ep_size * max_num_tokens, token_final_scales_recv.shape[-1]
-        ),
-        rank_experts,  # experts for current rank
-        is_ep=True,
-        ep_rank=rank,
-        num_experts_per_rank=num_experts_per_rank,
-    ).view(ep_size, max_num_tokens, hidden_states_recv.shape[-1])
 
     # Combine
     combined_output = moe_a2a.combine(
-        payload=moe_output.view(ep_size, max_num_tokens, hidden_size),
+        payload=moe_output,
         runtime_max_tokens_per_rank=max_num_tokens,
         payload_in_workspace=True,
     )
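The third hunk drops a leftover debug print and replaces the per-rank slice assignment (moe_output[rank] = ...) with a single copy_ into the full (ep_size, max_num_tokens, hidden_size) buffer, which appears to line up with combine being called with payload_in_workspace=True: the combine step reads the payload for every source rank out of the workspace, not just the current rank's slot. A rough sketch of the shape bookkeeping with fake_moe stubbed out (the stub and the dimensions are illustrative; only the view/copy_ pattern is taken from the diff):

import torch

def fake_moe_stub(hidden_states, experts, scales):
    # Stand-in for the test's fake_moe: shape-preserving on purpose,
    # so only the reshaping around it is demonstrated.
    return hidden_states.clone()

ep_size, max_num_tokens, hidden_size, top_k = 4, 8, 16, 2

moe_output = torch.zeros(ep_size, max_num_tokens, hidden_size)
hidden_states_recv = torch.randn(ep_size, max_num_tokens, hidden_size)
experts_recv = torch.randint(0, 8, (ep_size, max_num_tokens, top_k))
scales_recv = torch.rand(ep_size, max_num_tokens, top_k)

# Flatten (ep_size, max_num_tokens, ...) to (ep_size * max_num_tokens, ...)
# for the per-token expert computation, then restore the 3-D layout that
# the combine step expects to find in the workspace.
moe_output.copy_(
    fake_moe_stub(
        hidden_states_recv.view(ep_size * max_num_tokens, hidden_size),
        experts_recv.view(ep_size * max_num_tokens, top_k),
        scales_recv.view(ep_size * max_num_tokens, top_k),
    ).view(ep_size, max_num_tokens, hidden_size)
)
assert moe_output.shape == (ep_size, max_num_tokens, hidden_size)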
