Commit 8220b9e

Update tests with fake_moe properly

1 parent 25074e6 commit 8220b9e

File tree

2 files changed: +23 -25 lines

flashinfer/comm/trtllm_moe_a2a.py (2 additions, 2 deletions)

@@ -529,7 +529,7 @@ def get_combine_payload_tensor_in_workspace(
             dtype: Data type for the tensor
 
         Returns:
-            tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor
+            tensor: [ep_size, max_tokens, hidden_size] workspace-backed tensor
         """
         if self._state.phase != "dispatched":
             raise RuntimeError(
@@ -539,7 +539,7 @@ def get_combine_payload_tensor_in_workspace(
         element_size = torch.tensor([], dtype=dtype).element_size()
         return moe_a2a_wrap_payload_tensor_in_workspace(
             self.workspace[self.ep_rank, :],
-            [self.ep_size * runtime_max_tokens_per_rank],
+            [self.ep_size, runtime_max_tokens_per_rank],
             self._state.combine_payload_offset,
             self._state.combine_payload_offset
             + self.ep_size * runtime_max_tokens_per_rank * hidden_size * element_size,
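The shape change above gives the combine payload an explicit per-source-rank leading dimension instead of one flattened token axis, without touching the underlying byte range. A minimal sketch of why the two layouts alias the same workspace bytes, assuming a flat per-rank byte buffer and a zero combine_payload_offset (the buffer, sizes, and view calls below are illustrative stand-ins, not the flashinfer helper itself):

import torch

# Illustrative sizes; the real values come from the A2A runtime state.
ep_size, max_tokens, hidden_size = 4, 8, 16
dtype = torch.bfloat16
element_size = torch.tensor([], dtype=dtype).element_size()

# Stand-in for self.workspace[self.ep_rank, :]: a flat per-rank byte buffer.
workspace = torch.zeros(1 << 20, dtype=torch.uint8)
combine_payload_offset = 0  # assumed zero here; the real offset is tracked in _state
num_bytes = ep_size * max_tokens * hidden_size * element_size

region = workspace[combine_payload_offset : combine_payload_offset + num_bytes]

# Old layout: one flat token axis of length ep_size * max_tokens.
flat_view = region.view(dtype).view(ep_size * max_tokens, hidden_size)
# New layout: the same bytes, with a per-source-rank leading dimension.
ranked_view = region.view(dtype).view(ep_size, max_tokens, hidden_size)

# Same memory, different indexing: rank r's block starts where flat row
# r * max_tokens starts.
assert ranked_view[2].data_ptr() == flat_view[2 * max_tokens].data_ptr()

Because only the indexing changes, the end-offset arithmetic passed to moe_a2a_wrap_payload_tensor_in_workspace (ep_size * runtime_max_tokens_per_rank * hidden_size * element_size) is left untouched by this commit.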

tests/comm/test_mnnvl_a2a.py (21 additions, 23 deletions)

@@ -659,11 +659,6 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
     hidden_states = payloads[0]
     token_final_scales = payloads[2]
 
-    # Create experts for this rank
-    experts = create_experts(
-        num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16
-    )
-
     # Compute reference (single-GPU MoE)
     all_experts = torch.cat(
         [
@@ -675,14 +670,16 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
         dim=0,
     )
 
+    rank_experts = create_experts(
+        num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16
+    )
+
     reference_output = fake_moe(
         hidden_states,
         token_selected_experts,
         token_final_scales,
         all_experts,
-        is_ep=True,
-        ep_rank=rank,
-        num_experts_per_rank=num_experts_per_rank,
+        is_ep=False,
     )
 
     torch.cuda.synchronize()
@@ -717,21 +714,22 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k):
     moe_output.zero_()
 
     # Process each rank's tokens with local experts
-    for source_rank in range(ep_size):
-        source_num_tokens = all_num_tokens[source_rank]
-        for token_idx in range(source_num_tokens):
-            for k in range(top_k):
-                expert_id = token_selected_experts_recv[
-                    source_rank, token_idx, k
-                ].item()
-                local_expert_id = expert_id - rank * num_experts_per_rank
-
-                if 0 <= local_expert_id < num_experts_per_rank:
-                    token_hidden = hidden_states_recv[source_rank, token_idx]
-                    scale = token_final_scales_recv[source_rank, token_idx, k]
-                    expert_out = token_hidden @ experts[local_expert_id]
-                    output_idx = source_rank * max_num_tokens + token_idx
-                    moe_output[output_idx] += expert_out * scale
+    print(
+        f"hidden_states_recv.shape: {hidden_states_recv.shape}, token_selected_experts_recv.shape: {token_selected_experts_recv.shape}, token_final_scales_recv.shape: {token_final_scales_recv.shape}, rank_experts.shape: {rank_experts.shape}, moe_output.shape: {moe_output.shape}"
+    )
+    moe_output[rank] = fake_moe(
+        hidden_states_recv.view(ep_size * max_num_tokens, hidden_states_recv.shape[-1]),
+        token_selected_experts_recv.view(
+            ep_size * max_num_tokens, token_selected_experts_recv.shape[-1]
+        ),
+        token_final_scales_recv.view(
+            ep_size * max_num_tokens, token_final_scales_recv.shape[-1]
+        ),
+        rank_experts,  # experts for current rank
+        is_ep=True,
+        ep_rank=rank,
+        num_experts_per_rank=num_experts_per_rank,
+    ).view(ep_size, max_num_tokens, hidden_states_recv.shape[-1])
 
     # Combine
     combined_output = moe_a2a.combine(
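The test change swaps the hand-written triple loop for a single call to the fake_moe helper with is_ep=True, so only experts owned by the current rank contribute and every other selection is skipped, mirroring what the deleted loop did token by token. fake_moe's actual implementation lives in the test utilities and is not shown in this diff; the sketch below (fake_moe_ep_sketch is a hypothetical name) is an assumed reconstruction of that EP behavior, derived from the deleted loop and the call arguments visible here:

import torch

def fake_moe_ep_sketch(
    hidden_states,           # [num_tokens, hidden_size]
    token_selected_experts,  # [num_tokens, top_k], global expert ids
    token_final_scales,      # [num_tokens, top_k]
    experts,                 # [num_local_experts, hidden_size, hidden_size]
    is_ep=False,
    ep_rank=0,
    num_experts_per_rank=None,
):
    # Assumed semantics only; the real fake_moe in the test utilities may differ.
    out = torch.zeros_like(hidden_states)
    num_tokens, top_k = token_selected_experts.shape
    for t in range(num_tokens):
        for k in range(top_k):
            expert_id = int(token_selected_experts[t, k])
            if is_ep:
                # In EP mode only experts owned by ep_rank contribute.
                local_id = expert_id - ep_rank * num_experts_per_rank
                if not (0 <= local_id < num_experts_per_rank):
                    continue
            else:
                local_id = expert_id
            contribution = (hidden_states[t] @ experts[local_id]) * token_final_scales[t, k]
            out[t] += contribution.to(out.dtype)
    return out

With is_ep=False the local-ownership check is skipped and every selected expert is applied, which matches how the single-GPU reference is computed from all_experts earlier in the test.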
