From bb5253c365898830aa3a7237b824a33494d6bfec Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Tue, 18 Nov 2025 14:00:43 +1300 Subject: [PATCH 01/25] Initial draft of cursor port --- csrc/nv_internal/cpp/common/envUtils.cpp | 40 +- .../tensorrt_llm/common/envUtils.h | 11 +- .../moeAlltoAllKernels.cu | 847 ++++++++++++++++++ .../communicationKernels/moeAlltoAllKernels.h | 181 ++++ .../tensorrt_llm/thop/moeAlltoAllMeta.h | 61 ++ csrc/trtllm_moe_a2a.cu | 396 ++++++++ docs/api/comm.rst | 15 + flashinfer/aot.py | 2 + flashinfer/comm/__init__.py | 12 + flashinfer/comm/trtllm_moe_a2a.py | 428 +++++++++ flashinfer/jit/__init__.py | 1 + flashinfer/jit/comm.py | 29 + tests/comm/test_mnnvl_a2a.py | 366 ++++++++ 13 files changed, 2382 insertions(+), 7 deletions(-) create mode 100644 csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu create mode 100644 csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h create mode 100644 csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h create mode 100644 csrc/trtllm_moe_a2a.cu create mode 100644 flashinfer/comm/trtllm_moe_a2a.py create mode 100644 tests/comm/test_mnnvl_a2a.py diff --git a/csrc/nv_internal/cpp/common/envUtils.cpp b/csrc/nv_internal/cpp/common/envUtils.cpp index e2ee31261c..2f60d0778b 100644 --- a/csrc/nv_internal/cpp/common/envUtils.cpp +++ b/csrc/nv_internal/cpp/common/envUtils.cpp @@ -222,11 +222,6 @@ bool getEnvDisaggLayerwise() { return disaggLayerwise; } -bool getEnvParallelCacheSend() { - static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND"); - return parallelCacheSend; -} - bool getEnvRequestKVCacheConcurrent() { static bool const requestKVCacheConcurrent = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_CONCURRENT"); return requestKVCacheConcurrent; @@ -277,7 +272,7 @@ size_t getEnvAllReduceWorkspaceSize() { return workspaceSize; } -std::string getEnvKVCacheTransferOutputPath() { +std::string const& getEnvKVCacheTimeOutputPath() { static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or(""); return outputPath; } @@ -328,4 +323,37 @@ uint16_t getEnvNixlPort() { bool getEnvDisaggBenchmarkGenOnly() { return getBoolEnv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY"); } +bool getEnvMoeA2AOneBlockPerToken() { + // Default true; return false only if env set to "0" + static std::optional const val = getIntEnv("TLLM_MOE_A2A_ONE_BLOCK_PER_TOKEN"); + if (!val.has_value()) { + return true; + } + return val.value() != 0; +} + +static int sanitizeBlockSize(std::optional const& val) { + // Default 256 when not set or invalid + int block = val.value_or(256); + // Clamp to sane CUDA bounds and warp multiples + if (block <= 0) block = 256; + if (block > 1024) block = 1024; + // Round to nearest multiple of 32 (warp size) + block = (block + 31) / 32 * 32; + if (block == 0) block = 256; + return block; +} + +int getEnvMoeA2ADispatchBlockSize() { + static int const kBlock = sanitizeBlockSize(getIntEnv("TLLM_MOE_A2A_DISPATCH_BLOCK_SIZE")); + return kBlock; +} + +int getEnvMoeA2ACombineBlockSize() { + static int const kBlock = sanitizeBlockSize(getIntEnv("TLLM_MOE_A2A_COMBINE_BLOCK_SIZE")); + return kBlock; +} + +bool getEnvEplbForceGdrcopy() { return getBoolEnv("TRTLLM_EPLB_FORCE_GDRCOPY"); } + } // namespace tensorrt_llm::common diff --git a/csrc/nv_internal/tensorrt_llm/common/envUtils.h b/csrc/nv_internal/tensorrt_llm/common/envUtils.h index 887162e786..cdbdd8c414 100644 --- a/csrc/nv_internal/tensorrt_llm/common/envUtils.h +++ 
b/csrc/nv_internal/tensorrt_llm/common/envUtils.h @@ -64,7 +64,7 @@ bool getEnvDisableKVCacheTransferOverlap(); bool getEnvEnableReceiveKVCacheParallel(); -std::string getEnvKVCacheTransferOutputPath(); +std::string const& getEnvKVCacheTimeOutputPath(); bool getEnvTryZCopyForKVCacheTransfer(); @@ -92,4 +92,13 @@ size_t getEnvKVCacheSendMaxConcurrenceNum(); size_t getEnvMemSizeForKVCacheTransferBuffer(); +// Whether to use one block per token for MoE A2A kernels (default true). +bool getEnvMoeA2AOneBlockPerToken(); + +// TODO: For DEV purpose temporarily. +// Block size (threads per block) for MoE A2A Dispatch kernels (default 256 if unset or invalid) +int getEnvMoeA2ADispatchBlockSize(); +// Block size (threads per block) for MoE A2A Combine kernels (default 256 if unset or invalid) +int getEnvMoeA2ACombineBlockSize(); + } // namespace tensorrt_llm::common diff --git a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu new file mode 100644 index 0000000000..91dccdc33f --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -0,0 +1,847 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include + +#include "flashinfer/exception.h" +#include "flashinfer/utils.cuh" +#include "flashinfer/vec_dtypes.cuh" +#include "tensorrt_llm/common/dataType.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h" + +namespace tensorrt_llm::kernels::mnnvl_throughput { + +#define ENABLE_DEBUG_PRINT 0 +#define DISABLE_SYNC_FOR_PROFILING 0 + +// Helper function for ceiling division +template +__host__ __device__ inline T ceilDiv(T m, T n) { + return (m + n - 1) / n; +} + +// Macros for concise launch-time specialization +#define SWITCH_BOOL(flag, NAME, ...) \ + if (flag) { \ + constexpr bool NAME = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool NAME = false; \ + __VA_ARGS__ \ + } + +#define SWITCH_TOP_K(top_k, TOP_K, ...) \ + switch (top_k) { \ + case 8: { \ + constexpr int TOP_K = 8; \ + __VA_ARGS__; \ + break; \ + } \ + case 4: { \ + constexpr int TOP_K = 4; \ + __VA_ARGS__; \ + break; \ + } \ + case 2: { \ + constexpr int TOP_K = 2; \ + __VA_ARGS__; \ + break; \ + } \ + case 1: { \ + constexpr int TOP_K = 1; \ + __VA_ARGS__; \ + break; \ + } \ + default: { \ + FLASHINFER_CHECK(false, "Unsupported top_k"); \ + } \ + } + +#define SWITCH_DTYPE(dtype, TYPE, ...) 
\ + switch (dtype) { \ + case nvinfer1::DataType::kHALF: { \ + using TYPE = half; \ + __VA_ARGS__; \ + break; \ + } \ + case nvinfer1::DataType::kBF16: { \ + using TYPE = __nv_bfloat16; \ + __VA_ARGS__; \ + break; \ + } \ + case nvinfer1::DataType::kFLOAT: { \ + using TYPE = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: { \ + FLASHINFER_CHECK(false, "Unsupported dtype for moe_a2a_combine"); \ + } \ + } + +#define SWITCH_POLICY(one_block_per_token, POLICY, ...) \ + if (one_block_per_token) { \ + using POLICY = BlockPolicy; \ + __VA_ARGS__ \ + } else { \ + using POLICY = WarpPolicy; \ + __VA_ARGS__ \ + } + +// ============================================================================ +// Helper Functions for Expert-to-Rank Mapping +// ============================================================================ + +__device__ int compute_target_rank_id(int expert_id, int num_experts_per_rank) { + // Compute which rank owns a given expert using contiguous partitioning + // Experts are divided evenly across EP ranks: + // - Rank 0 gets experts [0, num_experts_per_rank) + // - Rank 1 gets experts [num_experts_per_rank, 2*num_experts_per_rank) + // - etc. + // Example: 32 experts, 4 ranks -> 8 experts per rank + // - Rank 0: experts 0-7 + // - Rank 1: experts 8-15 + // - Rank 2: experts 16-23 + // - Rank 3: experts 24-31 + return expert_id / num_experts_per_rank; +} + +// ============================================================================ +// Helper Functions for Vectorized Memory Operations +// ============================================================================ + +struct WarpPolicy { + __device__ static int stride() { return warpSize; } + + __device__ static int offset() { return (threadIdx.x % warpSize); } + + __device__ static int token_idx() { return (blockIdx.x * blockDim.x + threadIdx.x) / warpSize; } + + __device__ static void sync() { __syncwarp(); } +}; + +struct BlockPolicy { + __device__ static int stride() { return blockDim.x; } + + __device__ static int offset() { return threadIdx.x; } + + __device__ static int token_idx() { return blockIdx.x; } + + __device__ static void sync() { __syncthreads(); } +}; + +template +__device__ void vectorized_copy_impl(void* dst, void const* src, int size) { + using flashinfer::vec_t; + + uint8_t* dst_ptr = static_cast(dst); + uint8_t const* src_ptr = static_cast(src); + + int const stride = ThreadingPolicy::stride() * VEC_SIZE; + + for (int offset = ThreadingPolicy::offset() * VEC_SIZE; offset < size; offset += stride) { + vec_t v; + v.load(src_ptr + offset); + v.store(dst_ptr + offset); + } +} + +template +__device__ void vectorized_copy(void* dst, void const* src, int size) { + if (size % 16 == 0) { + vectorized_copy_impl<16, ThreadingPolicy>(dst, src, size); + } else if (size % 8 == 0) { + vectorized_copy_impl<8, ThreadingPolicy>(dst, src, size); + } else if (size % 4 == 0) { + vectorized_copy_impl<4, ThreadingPolicy>(dst, src, size); + } else if (size % 2 == 0) { + vectorized_copy_impl<2, ThreadingPolicy>(dst, src, size); + } else { + vectorized_copy_impl<1, ThreadingPolicy>(dst, src, size); + } +} + +// Vectorized dispatch: load one vec from source and write to up to TOP_K destinations +template +__device__ void vectorized_dispatch_impl(uint8_t const* src_ptr, int bytes_per_token, int rank_id, + int max_tokens_per_rank, int payload_idx, + DispatchKernelPointers const& ptrs, + int const* topk_target_ranks, + int const* topk_send_indices) { + using flashinfer::vec_t; + + // Precompute destination base pointers per k + 
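+  // Receive-buffer addressing (matches the computation below): on each target rank the
+  // per-payload receive buffer is laid out as [source_rank][max_tokens_per_rank][bytes_per_token],
+  // so this rank writes token dst_idx into its own source-rank partition at byte offset
+  //   (rank_id * max_tokens_per_rank + dst_idx) * bytes_per_token.
+  // A negative dst_idx marks a duplicate (target rank already covered for this token) and is skipped.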
uint8_t* dst_base_k[TOP_K]; +#pragma unroll + for (int k = 0; k < TOP_K; ++k) { + int dst_idx_k = topk_send_indices[k]; + int target_rank_k = topk_target_ranks[k]; + if (dst_idx_k < 0) { + dst_base_k[k] = nullptr; + continue; + } + uint8_t* dst_data = static_cast(ptrs.recv_buffers[target_rank_k][payload_idx]); + size_t base_source_rank = + static_cast(rank_id) * static_cast(max_tokens_per_rank) + + static_cast(dst_idx_k); + size_t base_token = base_source_rank * static_cast(bytes_per_token); + dst_base_k[k] = dst_data + base_token; + } + + // TODO: process all payloads. index could be reused. + int const stride = ThreadingPolicy::stride() * VEC_SIZE; + for (int offset = ThreadingPolicy::offset() * VEC_SIZE; offset < bytes_per_token; + offset += stride) { + vec_t v; + v.load(src_ptr + offset); + +#pragma unroll + for (int k = 0; k < TOP_K; ++k) { + uint8_t* dst_base = dst_base_k[k]; + if (dst_base == nullptr) { + continue; + } + v.store(dst_base + offset); + } + } +} + +template +__device__ void vectorized_dispatch(uint8_t const* src_ptr, int bytes_per_token, int rank_id, + int max_tokens_per_rank, int payload_idx, + DispatchKernelPointers const& ptrs, + int const* topk_target_ranks, int const* topk_send_indices) { + if (bytes_per_token % 16 == 0) { + vectorized_dispatch_impl<16, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } else if (bytes_per_token % 8 == 0) { + vectorized_dispatch_impl<8, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } else if (bytes_per_token % 4 == 0) { + vectorized_dispatch_impl<4, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } else if (bytes_per_token % 2 == 0) { + vectorized_dispatch_impl<2, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } else { + vectorized_dispatch_impl<1, TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } +} + +__global__ void moeA2APrepareDispatchKernel(int* send_counters, int* local_token_counter, + int ep_size, uint32_t* flag_val_ptr) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + // Zero send_counters + if (idx < ep_size) { + send_counters[idx] = 0; + } + // Zero local_token_counter and increment flag_val + if (idx == 0) { + *local_token_counter = 0; + // Increment flag_val for this dispatch round + *flag_val_ptr = *flag_val_ptr + 1; + } +} + +// ============================================================================ +// Generic Dispatch Kernel Implementation +// One warp per token design: +// - Each CTA has 256 threads = 8 warps +// - Each warp independently processes one token and all its payloads +// - Better GPU utilization and reduced synchronization overhead +// ============================================================================ + +template +__global__ void moeA2ADispatchKernel( + int32_t const* token_selected_experts, // [local_num_tokens, TOP_K] + const DispatchKernelPointers ptrs, // Struct containing all kernel pointers + int num_payloads, // Number of payloads + int max_tokens_per_rank, // Maximum tokens per rank + int local_num_tokens, int rank_id, int ep_size, int num_experts_per_rank) { + int thread_idx = 
ThreadingPolicy::offset(); + int local_token_idx = ThreadingPolicy::token_idx(); + + if (local_token_idx >= local_num_tokens) { + return; + } + + // Prepare per-policy shared-memory tiles for this token + extern __shared__ int smem[]; + int* smem_topk_target_ranks; + int* smem_topk_send_indices; + int warps_per_block = blockDim.x / warpSize; + if constexpr (std::is_same::value) { + int lane_id = threadIdx.x / warpSize; + smem_topk_target_ranks = smem + lane_id * TOP_K; + smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K; + } else { + smem_topk_target_ranks = smem; + smem_topk_send_indices = smem + TOP_K; + } + + uint64_t already_copied = 0; + for (int k = 0; k < TOP_K; k++) { + int expert_id = token_selected_experts[local_token_idx * TOP_K + k]; + // Use contiguous partitioning to determine target rank + int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank); + + if (already_copied & (1ULL << target_rank)) { + if (thread_idx == 0) { + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1; + ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1; + // Mirror to shared memory immediately + smem_topk_target_ranks[k] = -1; + smem_topk_send_indices[k] = -1; + } + continue; + } + + // Only one thread per warp should increment the counter + int dst_token_idx; + if (thread_idx == 0) { + dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1); + + ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank; + ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx; + // Mirror to shared memory immediately + smem_topk_target_ranks[k] = target_rank; + smem_topk_send_indices[k] = dst_token_idx; + } + already_copied |= 1ULL << target_rank; + } + // Sync before dispatching data + ThreadingPolicy::sync(); + + // Read staged routing once into registers per thread + int topk_target_ranks[TOP_K]; + int topk_send_indices[TOP_K]; +#pragma unroll + for (int k = 0; k < TOP_K; ++k) { + topk_target_ranks[k] = smem_topk_target_ranks[k]; + topk_send_indices[k] = smem_topk_send_indices[k]; + } + + // Perform a single source load and TOP_K fanout per payload + for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++) { + uint8_t const* src_data = static_cast(ptrs.src_data_ptrs[payload_idx]); + int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx]; + uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token; + + vectorized_dispatch(src_ptr, bytes_per_token, rank_id, + max_tokens_per_rank, payload_idx, ptrs, + topk_target_ranks, topk_send_indices); + } + + ThreadingPolicy::sync(); + + bool is_first_warp = threadIdx.x / warpSize == 0; + if (is_first_warp) { + int lane_id = threadIdx.x % warpSize; + + bool is_last_token = false; + if (lane_id == 0) { + int cnt = atomicAdd(ptrs.local_token_counter, 1); + is_last_token = cnt + 1 == local_num_tokens; + } + is_last_token = __shfl_sync(0xffffffff, is_last_token, 0); + + if (is_last_token) { +// Store send_counters to recv_counters +#pragma unroll 1 // No unroll as one iter is typically enough + for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) { + int send_count = ptrs.send_counters[target_rank]; + ptrs.recv_counters[target_rank][rank_id] = send_count; + } + +#if !DISABLE_SYNC_FOR_PROFILING + uint32_t expected_value = *ptrs.flag_val; + + asm volatile("fence.release.sys;"); +#pragma unroll 1 // No unroll as one iter is typically enough + for (int target_rank = lane_id; target_rank < ep_size; target_rank += warpSize) { + uint32_t* flag_addr = 
&ptrs.completion_flags[target_rank][rank_id]; + asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value)); + +#if ENABLE_DEBUG_PRINT + printf("dispatch: +++Rank %d setting completion flag to %d for rank %d\n", rank_id, + expected_value, target_rank); +#endif + } + +#pragma unroll 1 // No unroll + for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + bool flag_set = false; + do { + uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; + uint32_t flag_value; + // Acquire load to ensure visibility of peer's release-store + asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr)); +#if ENABLE_DEBUG_PRINT + printf( + "combine: ---Rank %d received completion flag from rank %d, flag_value: %d, " + "expected_value: " + "%d, address: %p\n", + rank_id, peer_rank, flag_value, expected_value, flag_ptr); +#endif + flag_set = flag_value == expected_value; + } while (!flag_set); + } + // asm volatile("fence.acquire.sys;"); +#endif + } + } +} + +void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params) { + moeA2APrepareDispatchKernel<<<1, params.ep_size, 0, params.stream>>>( + params.send_counters, params.local_token_counter, params.ep_size, params.flag_val); +} + +// ============================================================================ +// Launch Functions +// ============================================================================ + +void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params) { + // Validate parameters + TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); + TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); + TLLM_CHECK(params.local_num_tokens > 0); + TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads); + + // Prepare kernel pointers struct + DispatchKernelPointers kernel_ptrs = {}; + + // Fill source data pointers and payload sizes + for (int i = 0; i < params.num_payloads; i++) { + kernel_ptrs.src_data_ptrs[i] = params.payloads[i].src_data; + kernel_ptrs.payload_bytes_per_token[i] = + params.payloads[i].element_size * params.payloads[i].elements_per_token; + } + + // Fill receive buffer pointers + for (int target_rank = 0; target_rank < params.ep_size; target_rank++) { + kernel_ptrs.recv_counters[target_rank] = params.recv_counters[target_rank]; + for (int payload = 0; payload < params.num_payloads; payload++) { + kernel_ptrs.recv_buffers[target_rank][payload] = params.recv_buffers[target_rank][payload]; + } + } + + // Copy completion flag pointers + for (int i = 0; i < params.ep_size; i++) { + kernel_ptrs.completion_flags[i] = params.completion_flags[i]; + } + kernel_ptrs.flag_val = params.flag_val; + + // Copy communication tracking pointers + kernel_ptrs.send_counters = params.send_counters; + kernel_ptrs.local_token_counter = params.local_token_counter; + kernel_ptrs.topk_target_ranks = params.topk_target_ranks; + kernel_ptrs.topk_send_indices = params.topk_send_indices; + + int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ADispatchBlockSize(); + constexpr int kWarpSize = 32; + int const kWarpsPerBlock = kBlockSize / kWarpSize; + + // Configure kernel launch + if (params.one_block_per_token) { + int grid_size = params.local_num_tokens; + int shared_bytes = 2 * params.top_k * (int)sizeof(int); + SWITCH_TOP_K(params.top_k, TOP_K, + moeA2ADispatchKernel + <<>>( + params.token_selected_experts, kernel_ptrs, params.num_payloads, + params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, + params.ep_size, 
params.num_experts_per_rank)) + } else { + int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock); + int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int)sizeof(int); + SWITCH_TOP_K(params.top_k, TOP_K, + moeA2ADispatchKernel + <<>>( + params.token_selected_experts, kernel_ptrs, params.num_payloads, + params.max_tokens_per_rank, params.local_num_tokens, params.ep_rank, + params.ep_size, params.num_experts_per_rank)) + } +} + +// ============================================================================ +// Combine kernels +// ============================================================================ + +// Accumulate across all valid ranks into registers, then store once per segment +template +__device__ void vectorized_combine_impl(T* dst_typed_base, int size_per_token, int rank_id, + int max_tokens_per_rank, + CombineKernelPointers const& ptrs) { + constexpr int elems_per_vec = VEC_SIZE / sizeof(T); + using flashinfer::vec_t; + + uint8_t* dst_bytes = reinterpret_cast(dst_typed_base); + + int const stride = ThreadingPolicy::stride() * VEC_SIZE; + int const local_token_idx = ThreadingPolicy::token_idx(); + + for (int offset = ThreadingPolicy::offset() * VEC_SIZE; offset < size_per_token; + offset += stride) { + vec_t acc[TOP_K]; + +// Unrolled K accumulation using compact top-k lists +#pragma unroll + for (int k = 0; k < TOP_K; ++k) { + int target_rank = ptrs.topk_target_ranks[local_token_idx * TOP_K + k]; + int dst_idx = ptrs.topk_send_indices[local_token_idx * TOP_K + k]; + if (dst_idx < 0) { + acc[k].fill(0); + continue; + } + + uint8_t const* recv_buffer = static_cast(ptrs.recv_buffers[target_rank][0]); + size_t base_source_rank = + static_cast(rank_id) * static_cast(max_tokens_per_rank) + + static_cast(dst_idx); + size_t base_token = base_source_rank * static_cast(size_per_token); + + // Load directly into the per-k accumulator; reduce across k below + acc[k].load(recv_buffer + base_token + offset); + } + + // Reduce acc[TOP_K] into acc[0] + if constexpr (TOP_K == 8) { + T* a0 = reinterpret_cast(&acc[0]); + T* a1 = reinterpret_cast(&acc[1]); + T* a2 = reinterpret_cast(&acc[2]); + T* a3 = reinterpret_cast(&acc[3]); + T* a4 = reinterpret_cast(&acc[4]); + T* a5 = reinterpret_cast(&acc[5]); + T* a6 = reinterpret_cast(&acc[6]); + T* a7 = reinterpret_cast(&acc[7]); +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a1[j]; + a2[j] += a3[j]; + a4[j] += a5[j]; + a6[j] += a7[j]; + } +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a2[j]; + a4[j] += a6[j]; + } +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a4[j]; + } + } else if constexpr (TOP_K == 4) { + T* a0 = reinterpret_cast(&acc[0]); + T* a1 = reinterpret_cast(&acc[1]); + T* a2 = reinterpret_cast(&acc[2]); + T* a3 = reinterpret_cast(&acc[3]); +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a1[j]; + a2[j] += a3[j]; + } +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a2[j]; + } + } else if constexpr (TOP_K == 2) { + T* a0 = reinterpret_cast(&acc[0]); + T* a1 = reinterpret_cast(&acc[1]); +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += a1[j]; + } + } else if constexpr (TOP_K == 1) { + // nothing to do + } else { + // Generic fallback: accumulate all into acc[0] + T* a0 = reinterpret_cast(&acc[0]); +#pragma unroll + for (int k = 1; k < TOP_K; ++k) { + T* ak = reinterpret_cast(&acc[k]); +#pragma unroll + for (int j = 0; j < elems_per_vec; ++j) { + a0[j] += ak[j]; + } + } + } + + 
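+      // After the pairwise reduction above, acc[0] holds the element-wise sum over all TOP_K
+      // copies for this vector segment (invalid/duplicate entries were zero-filled), so a single
+      // store per segment writes the combined result to the local output.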
acc[0].store(dst_bytes + offset); + } +} + +// Wrapper that selects vector width based on size_per_token alignment +template +__device__ void vectorized_combine(T* dst_typed_base, int size_per_token, int rank_id, + int max_tokens_per_rank, CombineKernelPointers const& ptrs) { + if (size_per_token % 16 == 0) { + vectorized_combine_impl<16, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } else if (size_per_token % 8 == 0) { + vectorized_combine_impl<8, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } else if (size_per_token % 4 == 0) { + vectorized_combine_impl<4, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } else if (size_per_token % 2 == 0) { + vectorized_combine_impl<2, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } else { + vectorized_combine_impl<1, TOP_K, ThreadingPolicy, T>(dst_typed_base, size_per_token, rank_id, + max_tokens_per_rank, ptrs); + } +} + +// Copy payload to recv buffer using vectorized copy; supports warp/block token mapping +template +__global__ void moeA2APrepareCombineKernel(uint8_t* recv_buffer_bytes, uint8_t const* payload_bytes, + int bytes_per_token, int ep_size, + int max_tokens_per_rank, uint32_t* flag_val_ptr, + int const* recv_counters) { + if (blockIdx.x == 0 && threadIdx.x == 0) { + // Increment flag_val for this combine round + *flag_val_ptr = *flag_val_ptr + 1; + } + + if (payload_bytes == nullptr) return; + + int slot_idx = ThreadingPolicy::token_idx(); + + int total_slots = ep_size * max_tokens_per_rank; + if (slot_idx >= total_slots) return; + + // Map global token to (source_rank, token_idx) + int source_rank = slot_idx / max_tokens_per_rank; + int token_idx = slot_idx % max_tokens_per_rank; + + // Skip invalid tokens beyond per-source recv count + if (token_idx >= recv_counters[source_rank]) return; + + // Calculate source and destination pointers for this token + size_t slot_offset = static_cast(slot_idx) * bytes_per_token; + uint8_t* dst_ptr = recv_buffer_bytes + slot_offset; + uint8_t const* src_ptr = payload_bytes + slot_offset; + + // Copy one token's data using vectorized copy with policy + vectorized_copy(dst_ptr, src_ptr, bytes_per_token); +} + +// ============================================================================ +// Generic Combine Kernel Implementation (Templated by data type) +// ============================================================================ + +template +__global__ void moeA2ACombineKernel( + const CombineKernelPointers ptrs, // Combine-specific struct, src_data_ptrs[0] is output + int max_tokens_per_rank, int elements_per_token, int local_num_tokens, int rank_id, + int ep_size) { + int local_token_idx = ThreadingPolicy::token_idx(); + int const size_per_token = elements_per_token * sizeof(T); + + if (local_token_idx >= local_num_tokens) { + return; + } + +#if !DISABLE_SYNC_FOR_PROFILING + // In-kernel readiness synchronization at start of combine: + // - One warp signals readiness to all peers with current flag_val. + // - The first warp of each block waits for all peers' readiness (equality), then __syncthreads. 
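+  // Concretely (mirroring the dispatch epilogue):
+  //   1. Block 0's first warp publishes the local *flag_val to
+  //      completion_flags[peer][rank_id] on every peer via st.relaxed.sys.
+  //   2. The first warp of every block spins on completion_flags[rank_id][peer]
+  //      until all peers have published the same flag value.
+  //   3. fence.acquire.sys followed by __syncthreads() orders the peers' payload
+  //      writes before any thread of the block reads the receive buffers.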
+ bool is_first_warp = threadIdx.x / warpSize == 0; + if (is_first_warp) { + int lane_id = threadIdx.x % warpSize; + uint32_t expected_value = *ptrs.flag_val; + + if (blockIdx.x == 0) { + // asm volatile("fence.release.sys;"); +#pragma unroll 1 // No unroll + for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + uint32_t* flag_addr = &ptrs.completion_flags[peer_rank][rank_id]; + asm volatile("st.relaxed.sys.u32 [%0], %1;" ::"l"(flag_addr), "r"(expected_value)); +#if ENABLE_DEBUG_PRINT + printf("combine: +++Rank %d setting completion flag to %d for rank %d\n", rank_id, + expected_value, peer_rank); +#endif + } + } + +#pragma unroll 1 // No unroll + for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { + bool flag_set = false; + do { + uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; + uint32_t flag_value; + // Acquire load to ensure visibility of peer's release-store + asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr)); +#if ENABLE_DEBUG_PRINT + printf( + "combine: ---Rank %d received completion flag from rank %d, flag_value: %d, " + "expected_value: %d, " + "address: %p\n", + rank_id, peer_rank, flag_value, expected_value, flag_ptr); +#endif + flag_set = flag_value == expected_value; + } while (!flag_set); + } + asm volatile("fence.acquire.sys;"); + } + __syncthreads(); +#endif + + // Get output location for this token (using src_data_ptrs[0] as output) + T* token_output = static_cast(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token; + + // Accumulate across ranks in registers, then store once per segment + vectorized_combine(token_output, size_per_token, rank_id, + max_tokens_per_rank, ptrs); +} + +void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; // 8 warps per block + + // Calculate bytes per token based on dtype + int element_size; + switch (params.dtype) { + case nvinfer1::DataType::kHALF: + element_size = sizeof(half); + break; + case nvinfer1::DataType::kBF16: + element_size = sizeof(__nv_bfloat16); + break; + case nvinfer1::DataType::kFLOAT: + element_size = sizeof(float); + break; + default: + FLASHINFER_CHECK(false, "Unsupported dtype for combine prepare"); + return; + } + + int bytes_per_token = params.elements_per_token * element_size; + int total_slots = + params.prepare_payload == nullptr ? 
1 : params.ep_size * params.max_tokens_per_rank; + int grid_size_warp = ceilDiv(total_slots, kWarpsPerBlock); + int grid_size_block = total_slots; // one block per token + + if (params.one_block_per_token) { + moeA2APrepareCombineKernel<<>>( + static_cast(const_cast(params.recv_buffers[params.ep_rank])), + static_cast(params.prepare_payload), bytes_per_token, params.ep_size, + params.max_tokens_per_rank, params.flag_val, params.recv_counters); + } else { + moeA2APrepareCombineKernel<<>>( + static_cast(const_cast(params.recv_buffers[params.ep_rank])), + static_cast(params.prepare_payload), bytes_per_token, params.ep_size, + params.max_tokens_per_rank, params.flag_val, params.recv_counters); + } +} + +// ============================================================================ +// Combine Launch Function +// ============================================================================ + +void moe_a2a_combine_launch(MoeA2ACombineParams const& params) { + // Validate parameters + TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK); + TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks); + TLLM_CHECK(params.local_num_tokens > 0); + TLLM_CHECK(params.elements_per_token > 0); + + // Configure kernel launch + int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize(); + int const kWarpsPerBlock = kBlockSize / 32; // warpSize + int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock); + int grid_size_block = params.local_num_tokens; + + // Prepare kernel pointers struct for combine + CombineKernelPointers kernel_ptrs = {}; // Zero-initialize + + // Set output data pointer in src_data_ptrs[0] + kernel_ptrs.src_data_ptrs[0] = params.output_data; + + // Fill recv buffer pointers + for (int rank = 0; rank < params.ep_size; rank++) { + kernel_ptrs.recv_buffers[rank][0] = params.recv_buffers[rank]; + } + + // Copy completion flag pointers + for (int i = 0; i < params.ep_size; i++) { + kernel_ptrs.completion_flags[i] = params.completion_flags[i]; + } + kernel_ptrs.flag_val = params.flag_val; + + // Copy communication tracking pointers + kernel_ptrs.topk_target_ranks = params.topk_target_ranks; + kernel_ptrs.topk_send_indices = params.topk_send_indices; + + // Launch appropriate kernel with compact macros + SWITCH_DTYPE(params.dtype, TKernelType, { + SWITCH_POLICY(params.one_block_per_token, Policy, { + SWITCH_TOP_K(params.top_k, TOP_K, { + auto launch = [&](int grid_blocks, int block_threads) { + moeA2ACombineKernel + <<>>( + kernel_ptrs, params.max_tokens_per_rank, params.elements_per_token, + params.local_num_tokens, params.ep_rank, params.ep_size); + }; + int grid = params.one_block_per_token ? 
grid_size_block : grid_size_warp; + int cta = kBlockSize; + launch(grid, cta); + }); + }); + }); +} + +// Kernel to sanitize expert ids for invalid tokens +__global__ void moeA2ASanitizeExpertIdsKernel(int32_t* expert_ids_ptr, + int32_t const* recv_counters_ptr, int ep_size, + int max_tokens_per_rank, int top_k, + int32_t invalid_id) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int total_tokens = ep_size * max_tokens_per_rank; + if (tid >= total_tokens) return; + + int source_rank = tid / max_tokens_per_rank; + int token_idx = tid % max_tokens_per_rank; + + if (token_idx >= recv_counters_ptr[source_rank]) { + int32_t* token_expert_ids = expert_ids_ptr + tid * top_k; + for (int k = 0; k < top_k; ++k) { + token_expert_ids[k] = invalid_id; + } + } +} + +void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, + int32_t invalid_id, int ep_size, int max_tokens_per_rank, + int top_k, cudaStream_t stream) { + constexpr int kBlockSize = 256; + int total_tokens = ep_size * max_tokens_per_rank; + int grid = ceilDiv(total_tokens, kBlockSize); + moeA2ASanitizeExpertIdsKernel<<>>( + expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id); +} + +} // namespace tensorrt_llm::kernels::mnnvl_throughput diff --git a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h new file mode 100644 index 0000000000..0e8dfd9b7c --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include + +namespace tensorrt_llm::kernels::mnnvl_throughput { + +// Configuration constants +static constexpr int kMaxExperts = 256; // Maximum number of experts per rank +static constexpr int kMaxTopK = 8; // Maximum top-k experts per token +static constexpr int kMaxPayloads = 8; // Maximum number of different payload types +static constexpr int kMaxRanks = 64; // Maximum supported EP size + +// Describes a single payload type to be communicated +struct PayloadDescriptor { + void const* src_data; // Source data pointer [local_num_tokens, elements_per_token] + int element_size; // Size of each element in bytes + int elements_per_token; // Number of elements per token (e.g., hidden_size, top_k) +}; + +// Kernel pointers packed into a struct for device access +// Dispatch kernel pointers - const source data +struct DispatchKernelPointers { + // Payload pointers + void const* src_data_ptrs[kMaxPayloads]; // Array of source data pointers + void* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers + int payload_bytes_per_token[kMaxPayloads]; // Bytes per token for each payload + + // Completion flags for synchronization + uint32_t* + completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, + // then source rank has signaled the target rank + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + + // Local aux data pointers + int* send_counters; // [ep_size] How many tokens have been sent to each target rank + int* recv_counters[kMaxRanks]; // How many tokens have been received from each source rank. Each + // rank has [ep_size] counters + int* local_token_counter; // Atomic counter for completed tokens + + // Top-K compact routing info per local token (size: [local_num_tokens, top_k]) + int* topk_target_ranks; // target rank per k, -1 for duplicates + int* topk_send_indices; // dst index per k, -1 for duplicates +}; + +// Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers +struct CombineKernelPointers { + // Payload pointers + void* src_data_ptrs[kMaxPayloads]; // src_data_ptrs[0] is output + void const* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers (const) + + // Completion flags for synchronization + uint32_t* + completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, + // then source rank has signaled the target rank + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + + // Top-K compact routing info per local token (size: [local_num_tokens, top_k]) + int const* topk_target_ranks; // target rank per k, -1 for duplicates + int const* topk_send_indices; // dst index per k, -1 for duplicates +}; + +// Dispatch phase parameters +struct MoeA2ADispatchParams { + bool one_block_per_token; // True: one block per token, False: one warp per token + + // Threading policy + // EP configuration + int ep_size; // Number of EP ranks + int ep_rank; // Current EP rank + int num_experts_per_rank; // Number of experts per rank (num_experts / ep_size) + + // Token configuration + int local_num_tokens; // Number of tokens on this rank + int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to + // runtime_max_tokens_per_rank + int top_k; // Number of experts per token + + // Expert routing information + int32_t const* token_selected_experts; // [local_num_tokens, top_k] + + // Generic payloads + int num_payloads; // 
Number of different payload types + PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors + + // Local aux data + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + int* local_token_counter; // Atomic counter for completed tokens on this rank + int* send_counters; // [ep_size] atomic counters - tracks tokens sent to each target rank + int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, + // top_k]), target rank per k, -1 for duplicates + int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, + // top_k]), dst index per k, -1 for duplicates + + // Distributed aux data and recv buffers + int* recv_counters[kMaxRanks]; // tracks tokens received from each source rank. Each rank has + // [ep_size] counters + uint32_t* + completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, + // then source rank has signaled the target rank + void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload + + // CUDA stream + cudaStream_t stream; +}; + +// Dispatch kernels +void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params); +// Prepare for dispatch: zero send_counters, local_token_counter and increment flag_val +void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params); + +// Combine phase parameters +struct MoeA2ACombineParams { + bool one_block_per_token; // True: one block per token, False: one warp per token + + // EP configuration + int ep_size; // Number of EP ranks + int ep_rank; // Current EP rank + + // Token configuration + int local_num_tokens; // Number of tokens on this rank + int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation TODO: Rename to + // runtime_max_tokens_per_rank + int top_k; // Number of experts per token + + // Prepare-only field: original payload tensor pointer used to stage into workspace + void const* prepare_payload; + + // Output tensor + void* output_data; // Output buffer [local_num_tokens, elements_per_token] + // Payload information + int elements_per_token; // Number of elements per token + nvinfer1::DataType dtype; // Data type for proper summation + + // Local aux data + uint32_t* flag_val; // The value of the flag for this round (stored on the local rank) + int* topk_target_ranks; // Top-K compact routing info per local token (size: [local_num_tokens, + // top_k]), target rank per k, -1 for duplicates + int* topk_send_indices; // Top-K compact routing info per local token (size: [local_num_tokens, + // top_k]), dst index per k, -1 for duplicates + int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target + + // Distributed aux data and recv buffers + uint32_t* + completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, + // then source rank has signaled the target rank + void const* recv_buffers[kMaxRanks]; // Per-rank receive buffers (only for single payload) + + // CUDA stream + cudaStream_t stream; +}; + +// Combine kernels +void moe_a2a_combine_launch(MoeA2ACombineParams const& params); + +void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params); + +// Sanitize expert IDs for invalid tokens +// expert_ids: [ep_size, max_tokens_per_rank, top_k] (int32) +// recv_counters: [ep_size] (int32), number of valid tokens per source +// invalid_id: value to fill for invalid tokens' expert ids +void moe_a2a_sanitize_expert_ids_launch(int32_t* 
expert_ids, int32_t const* recv_counters, + int32_t invalid_id, int ep_size, int max_tokens_per_rank, + int top_k, cudaStream_t stream); + +} // namespace tensorrt_llm::kernels::mnnvl_throughput diff --git a/csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h b/csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h new file mode 100644 index 0000000000..354365c1ac --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/thop/moeAlltoAllMeta.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace torch_ext { +namespace mnnvl_throughput { + +// Enum for indexing into moe_a2a_metainfo tensor +enum MoeA2AMetaInfoIndex : int64_t { + FLAG_VAL_OFFSET_INDEX = 0, + LOCAL_TOKEN_COUNTER_OFFSET_INDEX = 1, + SEND_COUNTERS_OFFSET_INDEX = 2, + RECV_COUNTERS_OFFSET_INDEX = 3, + // Dispatch completion flags offset + DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX = 4, + // Combine completion flags offset + COMBINE_COMPLETION_FLAGS_OFFSET_INDEX = 5, + TOPK_TARGET_RANKS_OFFSET_INDEX = 6, + TOPK_SEND_INDICES_OFFSET_INDEX = 7, + PAYLOAD_DATA_OFFSET_INDEX = 8, + NUM_METAINFO_FIELDS = 9 +}; + +using MoeA2ADataOffsets = std::array; + +inline std::vector> getMoeA2AMetaInfoIndexPairs() { + return { + {"MOE_A2A_FLAG_VAL_OFFSET_INDEX", FLAG_VAL_OFFSET_INDEX}, + {"MOE_A2A_LOCAL_TOKEN_COUNTER_OFFSET_INDEX", LOCAL_TOKEN_COUNTER_OFFSET_INDEX}, + {"MOE_A2A_SEND_COUNTERS_OFFSET_INDEX", SEND_COUNTERS_OFFSET_INDEX}, + {"MOE_A2A_RECV_COUNTERS_OFFSET_INDEX", RECV_COUNTERS_OFFSET_INDEX}, + {"MOE_A2A_DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX", DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX}, + {"MOE_A2A_COMBINE_COMPLETION_FLAGS_OFFSET_INDEX", COMBINE_COMPLETION_FLAGS_OFFSET_INDEX}, + {"MOE_A2A_TOPK_TARGET_RANKS_OFFSET_INDEX", TOPK_TARGET_RANKS_OFFSET_INDEX}, + {"MOE_A2A_TOPK_SEND_INDICES_OFFSET_INDEX", TOPK_SEND_INDICES_OFFSET_INDEX}, + {"MOE_A2A_PAYLOAD_DATA_OFFSET_INDEX", PAYLOAD_DATA_OFFSET_INDEX}, + {"MOE_A2A_NUM_METAINFO_FIELDS", NUM_METAINFO_FIELDS}, + }; +} + +} // namespace mnnvl_throughput +} // namespace torch_ext diff --git a/csrc/trtllm_moe_a2a.cu b/csrc/trtllm_moe_a2a.cu new file mode 100644 index 0000000000..d474214de7 --- /dev/null +++ b/csrc/trtllm_moe_a2a.cu @@ -0,0 +1,396 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include + +#include "flashinfer/utils.cuh" +#include "tensorrt_llm/common/dataType.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h" +#include "tensorrt_llm/thop/moeAlltoAllMeta.h" +#include "tvm_ffi_utils.h" + +// TODO Review + +using tvm::ffi::Array; +using tvm::ffi::Shape; +using tvm::ffi::Tensor; +using tvm::ffi::TensorView; +using tvm::ffi::Tuple; + +namespace { + +namespace tl_throughput = tensorrt_llm::kernels::mnnvl_throughput; +namespace fi_throughput = torch_ext::mnnvl_throughput; + +constexpr size_t kCachelineAlignment = 128; +constexpr size_t kInt32Bytes = sizeof(int32_t); + +inline size_t alignOffset(size_t offset, size_t alignment = kCachelineAlignment) { + return (offset + alignment - 1) & ~(alignment - 1); +} + +fi_throughput::MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) { + fi_throughput::MoeA2ADataOffsets offsets{}; + size_t offset = 0; + + offsets[fi_throughput::FLAG_VAL_OFFSET_INDEX] = offset; + offset += kInt32Bytes; + + offsets[fi_throughput::LOCAL_TOKEN_COUNTER_OFFSET_INDEX] = offset; + offset += kInt32Bytes; + + offsets[fi_throughput::SEND_COUNTERS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * kInt32Bytes; + + offsets[fi_throughput::RECV_COUNTERS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::COMBINE_COMPLETION_FLAGS_OFFSET_INDEX] = offset; + offset += static_cast(epSize) * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::TOPK_TARGET_RANKS_OFFSET_INDEX] = offset; + offset += static_cast(maxNumTokens) * tl_throughput::kMaxTopK * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::TOPK_SEND_INDICES_OFFSET_INDEX] = offset; + offset += static_cast(maxNumTokens) * tl_throughput::kMaxTopK * kInt32Bytes; + + offset = alignOffset(offset); + offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX] = offset; + return offsets; +} + +Tensor moeA2AInitializeOp(TensorView workspace, int64_t epRank, int64_t epSize, + int64_t maxNumTokens) { + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2) << "workspace must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(workspace.size(0), epSize) << "workspace first dim must equal ep_size"; + TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize) << "epRank out of range"; + + auto stream = get_current_stream(); + auto* basePtr = static_cast(workspace.data_ptr()); + auto* rankPtr = basePtr + epRank * workspace.stride(0); + auto result = cudaMemsetAsync(rankPtr, 0, workspace.size(1), stream); + TVM_FFI_ICHECK(result == cudaSuccess) << "cudaMemsetAsync failed"; + + auto offsets = calculateOffsets(static_cast(epSize), static_cast(maxNumTokens)); + Tensor metainfo = alloc_tensor({fi_throughput::NUM_METAINFO_FIELDS}, dl_int64, cpu); + auto* metaPtr = static_cast(metainfo.data_ptr()); + std::copy(offsets.begin(), offsets.end(), metaPtr); + return metainfo; +} + +Tuple, int64_t> moeA2ADispatchOp(TensorView tokenSelectedExperts, + TensorView payloadPtrsTensor, + TensorView payloadElementSizesTensor, + TensorView payloadElementsPerTokenTensor, + TensorView workspace, TensorView metainfo, + int64_t 
runtimeMaxTokensPerRank, int64_t epRank, + int64_t epSize, int64_t topK, int64_t numExperts) { + using tl_throughput::PayloadDescriptor; + CHECK_INPUT(tokenSelectedExperts); + CHECK_INPUT_TYPE(tokenSelectedExperts, dl_int32); + TVM_FFI_ICHECK_EQ(tokenSelectedExperts.ndim(), 2) << "token_selected_experts must be 2D"; + TVM_FFI_ICHECK_EQ(tokenSelectedExperts.size(1), topK) << "token_selected_experts shape mismatch"; + + CHECK_INPUT_TYPE(payloadPtrsTensor, dl_int64); + CHECK_INPUT_TYPE(payloadElementSizesTensor, dl_int32); + CHECK_INPUT_TYPE(payloadElementsPerTokenTensor, dl_int32); + TVM_FFI_ICHECK_EQ(payloadPtrsTensor.ndim(), 1); + TVM_FFI_ICHECK_EQ(payloadElementSizesTensor.ndim(), 1); + TVM_FFI_ICHECK_EQ(payloadElementsPerTokenTensor.ndim(), 1); + + int numPayloads = static_cast(payloadPtrsTensor.size(0)); + TVM_FFI_ICHECK(numPayloads > 0) << "At least one payload is required"; + TVM_FFI_ICHECK(numPayloads <= tl_throughput::kMaxPayloads) << "Too many payloads"; + TVM_FFI_ICHECK_EQ(payloadElementSizesTensor.size(0), numPayloads); + TVM_FFI_ICHECK_EQ(payloadElementsPerTokenTensor.size(0), numPayloads); + + CHECK_CPU(metainfo); + CHECK_INPUT_TYPE(metainfo, dl_int64); + TVM_FFI_ICHECK_EQ(metainfo.ndim(), 1); + TVM_FFI_ICHECK_EQ(metainfo.size(0), fi_throughput::NUM_METAINFO_FIELDS); + auto const* offsetsPtr = static_cast(metainfo.data_ptr()); + fi_throughput::MoeA2ADataOffsets offsets{}; + std::copy(offsetsPtr, offsetsPtr + fi_throughput::NUM_METAINFO_FIELDS, offsets.begin()); + + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); + TVM_FFI_ICHECK_EQ(workspace.size(0), epSize); + TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize); + TVM_FFI_ICHECK(runtimeMaxTokensPerRank > 0); + TVM_FFI_ICHECK(numExperts >= epSize && numExperts % epSize == 0) + << "num_experts must be divisible by ep_size"; + TVM_FFI_ICHECK(topK > 0 && topK <= tl_throughput::kMaxTopK); + + auto localNumTokens = static_cast(tokenSelectedExperts.size(0)); + TVM_FFI_ICHECK(localNumTokens > 0) << "local_num_tokens must be positive"; + + auto* payloadPtrs = static_cast(payloadPtrsTensor.data_ptr()); + auto* payloadEltSizes = static_cast(payloadElementSizesTensor.data_ptr()); + auto* payloadEltPerToken = static_cast(payloadElementsPerTokenTensor.data_ptr()); + + std::vector payloadDescriptors(numPayloads); + std::vector payloadByteSizes(numPayloads); + int64_t totalBytesNeeded = 0; + for (int i = 0; i < numPayloads; ++i) { + payloadDescriptors[i].src_data = reinterpret_cast(payloadPtrs[i]); + payloadDescriptors[i].element_size = payloadEltSizes[i]; + payloadDescriptors[i].elements_per_token = payloadEltPerToken[i]; + int64_t bytesPerPayload = static_cast(epSize) * runtimeMaxTokensPerRank * + payloadEltPerToken[i] * payloadEltSizes[i]; + payloadByteSizes[i] = bytesPerPayload; + totalBytesNeeded += bytesPerPayload; + } + + auto* workspaceBase = static_cast(workspace.data_ptr()); + auto strideBytes = workspace.stride(0); + auto* rankWorkspacePtr = workspaceBase + epRank * strideBytes; + int64_t sizePerRank = workspace.size(1); + + int64_t requiredSize = offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded; + TVM_FFI_ICHECK(sizePerRank >= requiredSize) << "workspace size per rank insufficient, need " + << requiredSize << " bytes but has " << sizePerRank; + + tl_throughput::MoeA2ADispatchParams params{}; + params.one_block_per_token = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); + params.ep_size = static_cast(epSize); + params.ep_rank = static_cast(epRank); + params.num_experts_per_rank 
= static_cast(numExperts / epSize); + params.local_num_tokens = localNumTokens; + params.max_tokens_per_rank = static_cast(runtimeMaxTokensPerRank); + params.top_k = static_cast(topK); + params.token_selected_experts = static_cast(tokenSelectedExperts.data_ptr()); + params.num_payloads = numPayloads; + std::copy(payloadDescriptors.begin(), payloadDescriptors.end(), params.payloads); + + params.flag_val = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::FLAG_VAL_OFFSET_INDEX]); + params.local_token_counter = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::LOCAL_TOKEN_COUNTER_OFFSET_INDEX]); + params.send_counters = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::SEND_COUNTERS_OFFSET_INDEX]); + params.topk_target_ranks = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::TOPK_TARGET_RANKS_OFFSET_INDEX]); + params.topk_send_indices = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::TOPK_SEND_INDICES_OFFSET_INDEX]); + + for (int targetRank = 0; targetRank < epSize; ++targetRank) { + auto* targetWorkspacePtr = workspaceBase + targetRank * strideBytes; + params.recv_counters[targetRank] = reinterpret_cast( + targetWorkspacePtr + offsets[fi_throughput::RECV_COUNTERS_OFFSET_INDEX]); + params.completion_flags[targetRank] = reinterpret_cast( + targetWorkspacePtr + offsets[fi_throughput::DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX]); + + size_t offset = static_cast(offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX]); + for (int payloadIdx = 0; payloadIdx < numPayloads; ++payloadIdx) { + params.recv_buffers[targetRank][payloadIdx] = targetWorkspacePtr + offset; + offset += payloadByteSizes[payloadIdx]; + } + } + + params.stream = get_current_stream(); + + tl_throughput::moe_a2a_prepare_dispatch_launch(params); + tl_throughput::moe_a2a_dispatch_launch(params); + auto launchErr = cudaGetLastError(); + TVM_FFI_ICHECK(launchErr == cudaSuccess) + << "moe_a2a_dispatch launch failed: " << cudaGetErrorString(launchErr); + + Array recvPtrs; + recvPtrs.reserve(numPayloads); + size_t localOffset = static_cast(offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX]); + for (int payloadIdx = 0; payloadIdx < numPayloads; ++payloadIdx) { + auto* ptr = rankWorkspacePtr + localOffset; + recvPtrs.push_back(reinterpret_cast(ptr)); + localOffset += payloadByteSizes[payloadIdx]; + } + + int64_t combinePayloadOffset = static_cast(alignOffset(localOffset)); + return Tuple(recvPtrs, combinePayloadOffset); +} + +nvinfer1::DataType toNvDataType(DLDataType dtype) { + auto code = encode_dlpack_dtype(dtype); + if (code == float16_code) { + return nvinfer1::DataType::kHALF; + } + if (code == bfloat16_code) { + return nvinfer1::DataType::kBF16; + } + if (code == float32_code) { + return nvinfer1::DataType::kFLOAT; + } + TVM_FFI_LOG_AND_THROW(TypeError) << "Unsupported dtype for MoE combine"; + return nvinfer1::DataType::kFLOAT; +} + +Tensor moeA2ACombineOp(TensorView payload, int64_t localNumTokens, TensorView workspace, + TensorView metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, + int64_t epSize, int64_t topK, int64_t combinePayloadOffset, + bool payloadInWorkspace) { + using tl_throughput::MoeA2ACombineParams; + CHECK_INPUT(payload); + TVM_FFI_ICHECK_EQ(payload.ndim(), 3) + << "payload must be [ep_size, runtime_max_tokens_per_rank, hidden]"; + TVM_FFI_ICHECK_EQ(payload.size(0), epSize); + TVM_FFI_ICHECK_EQ(payload.size(1), runtimeMaxTokensPerRank); + TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize); + TVM_FFI_ICHECK(topK > 0 && topK <= tl_throughput::kMaxTopK); + 
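+  // The combine payload is indexed [source_rank, slot, hidden], i.e. the same layout the
+  // dispatch receive buffers use, which is what lets combine reuse the
+  // (topk_target_ranks, topk_send_indices) pairs recorded during dispatch to gather each
+  // local token's expert outputs.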
TVM_FFI_ICHECK(localNumTokens > 0); + + CHECK_CPU(metainfo); + CHECK_INPUT_TYPE(metainfo, dl_int64); + TVM_FFI_ICHECK_EQ(metainfo.ndim(), 1); + TVM_FFI_ICHECK_EQ(metainfo.size(0), fi_throughput::NUM_METAINFO_FIELDS); + auto const* offsetsPtr = static_cast(metainfo.data_ptr()); + fi_throughput::MoeA2ADataOffsets offsets{}; + std::copy(offsetsPtr, offsetsPtr + fi_throughput::NUM_METAINFO_FIELDS, offsets.begin()); + + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); + TVM_FFI_ICHECK_EQ(workspace.size(0), epSize); + auto* workspaceBase = static_cast(workspace.data_ptr()); + auto strideBytes = workspace.stride(0); + auto* rankWorkspacePtr = workspaceBase + epRank * strideBytes; + int64_t sizePerRank = workspace.size(1); + + int64_t elementsPerToken = payload.size(2); + int64_t payloadBytes = + payload.numel() * + get_element_size(payload); // includes all ranks * runtime_max_tokens_per_rank + TVM_FFI_ICHECK(combinePayloadOffset >= 0 && combinePayloadOffset + payloadBytes <= sizePerRank) + << "workspace insufficient for combine payload region"; + + if (payloadInWorkspace) { + auto* expectedPtr = rankWorkspacePtr + combinePayloadOffset; + TVM_FFI_ICHECK(payload.data_ptr() == expectedPtr) + << "payload_in_workspace is True but tensor pointer mismatch"; + } + + Tensor output = + alloc_tensor({localNumTokens, elementsPerToken}, payload.dtype(), payload.device()); + + MoeA2ACombineParams params{}; + params.one_block_per_token = tensorrt_llm::common::getEnvMoeA2AOneBlockPerToken(); + params.ep_size = static_cast(epSize); + params.ep_rank = static_cast(epRank); + params.local_num_tokens = static_cast(localNumTokens); + params.max_tokens_per_rank = static_cast(runtimeMaxTokensPerRank); + params.top_k = static_cast(topK); + params.prepare_payload = payloadInWorkspace ? 
nullptr : payload.data_ptr(); + params.output_data = output.data_ptr(); + params.elements_per_token = static_cast(elementsPerToken); + params.dtype = toNvDataType(payload.dtype()); + + params.flag_val = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::FLAG_VAL_OFFSET_INDEX]); + params.topk_target_ranks = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::TOPK_TARGET_RANKS_OFFSET_INDEX]); + params.topk_send_indices = reinterpret_cast( + rankWorkspacePtr + offsets[fi_throughput::TOPK_SEND_INDICES_OFFSET_INDEX]); + params.recv_counters = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::RECV_COUNTERS_OFFSET_INDEX]); + + for (int targetRank = 0; targetRank < epSize; ++targetRank) { + auto* targetWorkspacePtr = workspaceBase + targetRank * strideBytes; + params.completion_flags[targetRank] = reinterpret_cast( + targetWorkspacePtr + offsets[fi_throughput::COMBINE_COMPLETION_FLAGS_OFFSET_INDEX]); + params.recv_buffers[targetRank] = targetWorkspacePtr + combinePayloadOffset; + } + params.stream = get_current_stream(); + + tl_throughput::moe_a2a_prepare_combine_launch(params); + tl_throughput::moe_a2a_combine_launch(params); + auto err = cudaGetLastError(); + TVM_FFI_ICHECK(err == cudaSuccess) + << "moe_a2a_combine launch failed: " << cudaGetErrorString(err); + return output; +} + +void moeA2ASanitizeExpertIdsOp(TensorView expertIds, TensorView workspace, TensorView metainfo, + int64_t epRank, int64_t invalidExpertId) { + CHECK_INPUT(expertIds); + CHECK_INPUT_TYPE(expertIds, dl_int32); + TVM_FFI_ICHECK_EQ(expertIds.ndim(), 3); + int64_t epSize = expertIds.size(0); + int64_t runtimeMaxTokensPerRank = expertIds.size(1); + int64_t topK = expertIds.size(2); + + CHECK_CPU(metainfo); + CHECK_INPUT_TYPE(metainfo, dl_int64); + TVM_FFI_ICHECK_EQ(metainfo.ndim(), 1); + TVM_FFI_ICHECK_EQ(metainfo.size(0), fi_throughput::NUM_METAINFO_FIELDS); + auto const* offsetsPtr = static_cast(metainfo.data_ptr()); + fi_throughput::MoeA2ADataOffsets offsets{}; + std::copy(offsetsPtr, offsetsPtr + fi_throughput::NUM_METAINFO_FIELDS, offsets.begin()); + + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); + auto* workspaceBase = static_cast(workspace.data_ptr()); + auto* rankWorkspacePtr = workspaceBase + epRank * workspace.stride(0); + auto* recvCounters = + reinterpret_cast(rankWorkspacePtr + offsets[fi_throughput::RECV_COUNTERS_OFFSET_INDEX]); + + tl_throughput::moe_a2a_sanitize_expert_ids_launch( + static_cast(expertIds.data_ptr()), recvCounters, + static_cast(invalidExpertId), static_cast(epSize), + static_cast(runtimeMaxTokensPerRank), static_cast(topK), get_current_stream()); +} + +int64_t moeA2AGetCombinePayloadPtrOp(TensorView workspace, int64_t epRank, int64_t epSize, + int64_t runtimeMaxTokensPerRank, int64_t combinePayloadOffset, + int64_t elementsPerToken, int64_t elementSizeBytes) { + CHECK_INPUT_TYPE(workspace, dl_uint8); + TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); + TVM_FFI_ICHECK_EQ(workspace.size(0), epSize); + TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize); + TVM_FFI_ICHECK(runtimeMaxTokensPerRank > 0); + TVM_FFI_ICHECK(elementsPerToken > 0); + TVM_FFI_ICHECK(elementSizeBytes > 0); + + int64_t sizePerRank = workspace.size(1); + int64_t bytesNeeded = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSizeBytes; + TVM_FFI_ICHECK(combinePayloadOffset >= 0 && combinePayloadOffset + bytesNeeded <= sizePerRank) + << "combine payload exceeds workspace capacity"; + + auto* basePtr = static_cast(workspace.data_ptr()); + auto* rankPtr = basePtr 
+ epRank * workspace.stride(0); + return reinterpret_cast(rankPtr + combinePayloadOffset); +} + +} // namespace + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_initialize, moeA2AInitializeOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_dispatch, moeA2ADispatchOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_combine, moeA2ACombineOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_sanitize_expert_ids, moeA2ASanitizeExpertIdsOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_get_combine_payload_ptr, moeA2AGetCombinePayloadPtrOp); diff --git a/docs/api/comm.rst b/docs/api/comm.rst index 735f9a4294..9d38285ad0 100644 --- a/docs/api/comm.rst +++ b/docs/api/comm.rst @@ -128,3 +128,18 @@ TensorRT-LLM MNNVL AllReduce trtllm_mnnvl_all_reduce trtllm_mnnvl_fused_allreduce_rmsnorm mpi_barrier + +MNNVL A2A (Throughput Backend) +------------------------------- + +.. currentmodule:: flashinfer.comm + +.. autosummary:: + :toctree: ../generated + + MoeAlltoAll + moe_a2a_initialize + moe_a2a_dispatch + moe_a2a_combine + moe_a2a_sanitize_expert_ids + moe_a2a_get_combine_payload_tensor diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 609e1bcbcf..84bc0ca199 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -512,12 +512,14 @@ def gen_all_modules( from .jit.comm import gen_nvshmem_module from .jit.comm import gen_comm_alltoall_module from .jit.comm import gen_trtllm_mnnvl_comm_module + from .jit.comm import gen_mnnvl_a2a_module jit_specs.append(gen_nvshmem_module()) jit_specs.append(gen_comm_alltoall_module()) if has_sm100: jit_specs.append(gen_trtllm_comm_module()) jit_specs.append(gen_trtllm_mnnvl_comm_module()) + jit_specs.append(gen_mnnvl_a2a_module()) jit_specs.append(gen_vllm_comm_module()) if add_misc: diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index f7ae3754ac..4860187fc5 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -39,4 +39,16 @@ from .vllm_ar import register_buffer as vllm_register_buffer from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers +# MNNVL A2A (Throughput Backend) +from .trtllm_moe_a2a import MoeAlltoAll as MoeAlltoAll +from .trtllm_moe_a2a import moe_a2a_combine as moe_a2a_combine +from .trtllm_moe_a2a import moe_a2a_dispatch as moe_a2a_dispatch +from .trtllm_moe_a2a import ( + moe_a2a_get_combine_payload_tensor as moe_a2a_get_combine_payload_tensor, +) +from .trtllm_moe_a2a import moe_a2a_initialize as moe_a2a_initialize +from .trtllm_moe_a2a import ( + moe_a2a_sanitize_expert_ids as moe_a2a_sanitize_expert_ids, +) + # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/trtllm_moe_a2a.py b/flashinfer/comm/trtllm_moe_a2a.py new file mode 100644 index 0000000000..59c2ae0953 --- /dev/null +++ b/flashinfer/comm/trtllm_moe_a2a.py @@ -0,0 +1,428 @@ +""" +MoE All-to-All Operations (Throughput Backend) + +This module provides the throughput-optimized all-to-all backend for MoE expert parallelism, +supporting multiple payloads per collective operation. 
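+
+A minimal sketch of the intended call flow, based on the wrappers defined in
+this module (the MoeAlltoAll class below packages the same sequence; the
+recv_tensors and expert_output names are illustrative):
+
+    metainfo = moe_a2a_initialize(workspace, ep_rank, ep_size, max_num_tokens)
+    recv_tensors, combine_payload_offset = moe_a2a_dispatch(
+        token_selected_experts, input_payloads, workspace, metainfo,
+        runtime_max_tokens_per_rank, ep_rank, ep_size, top_k, num_experts)
+    # ... run the local experts over recv_tensors to produce expert_output,
+    # shaped [ep_size, runtime_max_tokens_per_rank, hidden] ...
+    output = moe_a2a_combine(
+        expert_output, local_num_tokens, workspace, metainfo,
+        runtime_max_tokens_per_rank, ep_rank, ep_size, top_k,
+        combine_payload_offset)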
+""" + +# TODO Review + +from dataclasses import dataclass +from typing import Optional + +import torch + +from .mnnvl import MnnvlMemory +from .mapping import Mapping +from ..jit.comm import gen_mnnvl_a2a_module + + +@dataclass +class _A2AState: + """Internal state tracking for MoeAlltoAll operations.""" + + phase: str = "idle" # idle | dispatched + local_num_tokens: Optional[int] = None + combine_payload_offset: Optional[int] = None + + +def get_mnnvl_a2a_module(): + """Get or build the MNNVL A2A JIT module.""" + return gen_mnnvl_a2a_module().build_and_load() + + +def moe_a2a_initialize( + workspace: torch.Tensor, + ep_rank: int, + ep_size: int, + max_num_tokens: int, +) -> torch.Tensor: + """ + Initialize MoE A2A workspace. + + Args: + workspace: [ep_size, size_per_rank] uint8 tensor (MNNVL memory) + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + max_num_tokens: Maximum number of tokens supported + + Returns: + metainfo: Tensor containing workspace offsets + """ + return get_mnnvl_a2a_module().moe_a2a_initialize( + workspace, ep_rank, ep_size, max_num_tokens + ) + + +def moe_a2a_dispatch( + token_selected_experts: torch.Tensor, + input_payloads: list[torch.Tensor], + workspace: torch.Tensor, + metainfo: torch.Tensor, + runtime_max_tokens_per_rank: int, + ep_rank: int, + ep_size: int, + top_k: int, + num_experts: int, +): + """ + Dispatch tokens and payloads to expert ranks. + + Args: + token_selected_experts: [local_num_tokens, top_k] int32 tensor + input_payloads: List of [local_num_tokens, *] tensors to dispatch + workspace: [ep_size, size_per_rank] workspace tensor + metainfo: Metadata tensor from initialize + runtime_max_tokens_per_rank: Max tokens per rank in this batch + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + top_k: Number of experts per token + num_experts: Total number of experts + + Returns: + recv_tensors: List of [ep_size, max_tokens, *] tensors + combine_payload_offset: Offset for combine payload region + """ + return get_mnnvl_a2a_module().moe_a2a_dispatch( + token_selected_experts, + input_payloads, + workspace, + metainfo, + runtime_max_tokens_per_rank, + ep_rank, + ep_size, + top_k, + num_experts, + ) + + +def moe_a2a_combine( + payload: torch.Tensor, + local_num_tokens: int, + workspace: torch.Tensor, + metainfo: torch.Tensor, + runtime_max_tokens_per_rank: int, + ep_rank: int, + ep_size: int, + top_k: int, + combine_payload_offset: int, + payload_in_workspace: bool = False, +) -> torch.Tensor: + """ + Combine expert outputs back to originating tokens. 
+ + Args: + payload: [ep_size, max_tokens, elements_per_token] tensor + local_num_tokens: Number of tokens on this rank + workspace: [ep_size, size_per_rank] workspace tensor + metainfo: Metadata tensor from initialize + runtime_max_tokens_per_rank: Max tokens per rank in this batch + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + top_k: Number of experts per token + combine_payload_offset: Offset from dispatch + payload_in_workspace: If True, payload is workspace-backed + + Returns: + output: [local_num_tokens, elements_per_token] tensor + """ + return get_mnnvl_a2a_module().moe_a2a_combine( + payload, + local_num_tokens, + workspace, + metainfo, + runtime_max_tokens_per_rank, + ep_rank, + ep_size, + top_k, + combine_payload_offset, + payload_in_workspace, + ) + + +def moe_a2a_sanitize_expert_ids( + expert_ids: torch.Tensor, + workspace: torch.Tensor, + metainfo: torch.Tensor, + ep_rank: int, + invalid_expert_id: int, +) -> None: + """ + Sanitize expert IDs for invalid tokens. + + Args: + expert_ids: [ep_size, max_tokens, top_k] int32 tensor (modified in-place) + workspace: [ep_size, size_per_rank] workspace tensor + metainfo: Metadata tensor from initialize + ep_rank: Current expert parallel rank + invalid_expert_id: Value to fill for invalid tokens + """ + get_mnnvl_a2a_module().moe_a2a_sanitize_expert_ids( + expert_ids, workspace, metainfo, ep_rank, invalid_expert_id + ) + + +def moe_a2a_get_combine_payload_tensor( + workspace: torch.Tensor, + ep_rank: int, + ep_size: int, + runtime_max_tokens_per_rank: int, + combine_payload_offset: int, + dtype: torch.dtype, + hidden_size: int, +) -> torch.Tensor: + """ + Get combine payload tensor backed by workspace (zero-copy). + + Args: + workspace: [ep_size, size_per_rank] workspace tensor + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + runtime_max_tokens_per_rank: Max tokens per rank in this batch + combine_payload_offset: Offset from dispatch + dtype: Data type for the tensor + hidden_size: Hidden dimension size + + Returns: + tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor + """ + return get_mnnvl_a2a_module().moe_a2a_get_combine_payload_tensor( + workspace, + ep_rank, + ep_size, + runtime_max_tokens_per_rank, + combine_payload_offset, + dtype, + hidden_size, + ) + + +class MoeAlltoAll: + """ + Manages MoE All-to-All operations with proper workspace allocation and synchronization. + + This class provides the throughput-optimized backend that supports multiple payloads + per collective operation, explicit dispatch/combine phases, and workspace-backed tensors. + + Example: + >>> moe_a2a = MoeAlltoAll(mapping, max_num_tokens=2048, top_k=2, num_experts=8) + >>> recv = moe_a2a.dispatch(experts, [hidden, ids, scales], batch_size) + >>> output = moe_a2a.combine(processed, batch_size) + """ + + # Single shared workspace across the process + _WORKSPACE: Optional[dict] = None + + def __init__( + self, + mapping: Mapping, + max_num_tokens: int, + top_k: int, + num_experts: int, + workspace_size_per_rank: int = 512 * 1024 * 1024, + ): + """ + Initialize MoeAlltoAll with workspace allocation. 
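+
+        The underlying MNNVL workspace is allocated once per process and cached
+        in MoeAlltoAll._WORKSPACE; later instances reuse it and assert that the
+        workspace size, max_num_tokens, and EP rank/size match.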
+ + Args: + mapping: Mapping object containing rank information + max_num_tokens: Maximum number of tokens supported + top_k: Number of experts per token + num_experts: Total number of experts + workspace_size_per_rank: Size of workspace per rank in bytes (default: 512MB) + """ + # Initialize MNNVL memory system + MnnvlMemory.initialize() + + self.workspace_size_per_rank = workspace_size_per_rank + self.max_num_tokens = max_num_tokens + self.ep_size = mapping.tp_size + self.ep_rank = mapping.tp_rank + self.top_k = top_k + self.num_experts = num_experts + + if not isinstance(self.top_k, int) or self.top_k <= 0: + raise ValueError("top_k must be a positive int") + if not isinstance(self.num_experts, int) or self.num_experts <= 0: + raise ValueError("num_experts must be a positive int") + + # Allocate or reuse workspace + if self._WORKSPACE is None: + mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank) + workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) + metainfo = moe_a2a_initialize( + workspace, + self.ep_rank, + self.ep_size, + self.max_num_tokens, + ) + MoeAlltoAll._WORKSPACE = { + "workspace_size_per_rank": workspace_size_per_rank, + "max_num_tokens": self.max_num_tokens, + "ep_rank": self.ep_rank, + "ep_size": self.ep_size, + "mnnvl_mem": mnnvl_mem, + "workspace": workspace, + "metainfo": metainfo, + } + else: + # Validate workspace compatibility + assert ( + self._WORKSPACE["workspace_size_per_rank"] == workspace_size_per_rank + ), "Workspace size mismatch" + assert self._WORKSPACE["max_num_tokens"] == self.max_num_tokens, ( + "Max tokens mismatch" + ) + assert self._WORKSPACE["ep_rank"] == self.ep_rank, "EP rank mismatch" + assert self._WORKSPACE["ep_size"] == self.ep_size, "EP size mismatch" + + self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] + self.workspace = self._WORKSPACE["workspace"] + self.metainfo = self._WORKSPACE["metainfo"] + self._state = _A2AState() + + def dispatch( + self, + token_selected_experts: torch.Tensor, + input_payloads: list[torch.Tensor], + runtime_max_tokens_per_rank: int, + invalid_token_expert_id: Optional[int] = None, + expert_id_payload_index: Optional[int] = None, + ) -> list[torch.Tensor]: + """ + Perform MoE all-to-all dispatch operation. 
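+
+        dispatch() and combine() must alternate: the internal state tracking
+        asserts if dispatch() is called again before the matching combine().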
+ + Args: + token_selected_experts: [local_num_tokens, top_k] expert indices + input_payloads: List of [local_num_tokens, *] tensors to dispatch + runtime_max_tokens_per_rank: Max tokens per rank in this batch + invalid_token_expert_id: If set, sanitize invalid tokens to this ID + expert_id_payload_index: Index of expert IDs in input_payloads (required if invalid_token_expert_id is set) + + Returns: + recv_tensors: List of [ep_size, max_tokens, *] tensors + """ + assert self._state.phase == "idle", "dispatch called twice without combine" + assert runtime_max_tokens_per_rank <= self.max_num_tokens, ( + "runtime_max_tokens_per_rank exceeds max_num_tokens" + ) + + recv_tensors, combine_payload_offset = moe_a2a_dispatch( + token_selected_experts, + input_payloads, + self.workspace, + self.metainfo, + runtime_max_tokens_per_rank, + self.ep_rank, + self.ep_size, + self.top_k, + self.num_experts, + ) + + # Update state + self._state.local_num_tokens = token_selected_experts.size(0) + self._state.combine_payload_offset = combine_payload_offset + self._state.phase = "dispatched" + + # Sanitize invalid tokens if requested + if invalid_token_expert_id is not None: + assert expert_id_payload_index is not None, ( + "expert_id_payload_index required when invalid_token_expert_id is set" + ) + recv_expert_ids = recv_tensors[expert_id_payload_index] + moe_a2a_sanitize_expert_ids( + recv_expert_ids, + self.workspace, + self.metainfo, + self.ep_rank, + invalid_token_expert_id, + ) + + return recv_tensors + + def combine( + self, + payload: torch.Tensor, + runtime_max_tokens_per_rank: int, + payload_in_workspace: bool = False, + ) -> torch.Tensor: + """ + Perform MoE all-to-all combine operation. + + Args: + payload: [ep_size, max_tokens, elements_per_token] tensor + runtime_max_tokens_per_rank: Max tokens per rank in this batch + payload_in_workspace: If True, payload is workspace-backed (skip staging) + + Returns: + output: [local_num_tokens, elements_per_token] tensor + """ + assert self._state.phase == "dispatched", ( + "combine called before successful dispatch" + ) + assert runtime_max_tokens_per_rank <= self.max_num_tokens, ( + "runtime_max_tokens_per_rank exceeds max_num_tokens" + ) + + output = moe_a2a_combine( + payload, + self._state.local_num_tokens, + self.workspace, + self.metainfo, + runtime_max_tokens_per_rank, + self.ep_rank, + self.ep_size, + self.top_k, + self._state.combine_payload_offset, + payload_in_workspace, + ) + + # Reset state for next round + self._state = _A2AState() + + return output + + def get_combine_payload_tensor_in_workspace( + self, + runtime_max_tokens_per_rank: int, + hidden_size: int, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Get combine payload tensor backed by workspace (zero-copy). + + This tensor can be written to directly by expert processing, avoiding + a staging copy in the combine operation. 
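+
+        A sketch of the intended pattern (mirroring tests/comm/test_mnnvl_a2a.py;
+        moe_output is an illustrative name):
+
+            moe_output = moe_a2a.get_combine_payload_tensor_in_workspace(
+                runtime_max_tokens_per_rank, hidden_size, dtype=torch.bfloat16)
+            # ... accumulate expert results into moe_output ...
+            output = moe_a2a.combine(
+                moe_output.view(ep_size, runtime_max_tokens_per_rank, hidden_size),
+                runtime_max_tokens_per_rank, payload_in_workspace=True)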
+ + Args: + runtime_max_tokens_per_rank: Max tokens per rank in this batch + hidden_size: Hidden dimension size + dtype: Data type for the tensor + + Returns: + tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor + """ + if self._state.phase != "dispatched": + raise RuntimeError( + "get_combine_payload_tensor_in_workspace called before successful dispatch" + ) + + return moe_a2a_get_combine_payload_tensor( + self.workspace, + self.ep_rank, + self.ep_size, + runtime_max_tokens_per_rank, + self._state.combine_payload_offset, + dtype, + hidden_size, + ) + + +__all__ = [ + "MoeAlltoAll", + "moe_a2a_initialize", + "moe_a2a_dispatch", + "moe_a2a_combine", + "moe_a2a_sanitize_expert_ids", + "moe_a2a_get_combine_payload_tensor", +] diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 0bacf2d28b..bd1934ff62 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -77,6 +77,7 @@ from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module from .comm import gen_vllm_comm_module as gen_vllm_comm_module from .comm import gen_nvshmem_module as gen_nvshmem_module +from .comm import gen_mnnvl_a2a_module as gen_mnnvl_a2a_module from .dsv3_optimizations import ( gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module, ) diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py index 27661b1fe2..4c350ddf22 100644 --- a/flashinfer/jit/comm.py +++ b/flashinfer/jit/comm.py @@ -78,3 +78,32 @@ def gen_vllm_comm_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / "vllm_custom_all_reduce.cu", ], ) + + +def gen_mnnvl_a2a_module() -> JitSpec: + return gen_jit_spec( + "mnnvl_a2a", + [ + jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_a2a.cu", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "tensorrt_llm" + / "kernels" + / "communicationKernels" + / "moeAlltoAllKernels.cu", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "cpp" + / "common" + / "envUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal" + / "cpp" + / "common" + / "tllmException.cpp", + ], + extra_include_paths=[ + str(jit_env.FLASHINFER_CSRC_DIR / "nv_internal"), + str(jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include"), + ], + ) diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py new file mode 100644 index 0000000000..aa6a675b71 --- /dev/null +++ b/tests/comm/test_mnnvl_a2a.py @@ -0,0 +1,366 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
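+
+# NOTE: These tests are MPI-driven: each rank binds one GPU, and a test skips
+# unless the MPI world size equals the parametrized ep_size, e.g.
+#   mpirun -n 2 python -m pytest tests/comm/test_mnnvl_a2a.py -v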
+ +import traceback + +import pytest +import torch +from mpi4py import MPI + +from flashinfer.comm import MoeAlltoAll +from flashinfer.comm.mapping import Mapping + + +@pytest.fixture(autouse=True) +def setup_test(): + torch.manual_seed(0x1234) + + +def compute_target_rank_id(expert_id, num_experts_per_rank): + """Compute the rank that owns a given expert using contiguous partitioning.""" + return expert_id // num_experts_per_rank + + +def generate_token_selected_experts( + local_num_tokens: int, ep_size: int, num_experts_per_rank: int, top_k: int +) -> torch.Tensor: + """Generate global expert IDs tensor.""" + return torch.randint( + 0, + ep_size * num_experts_per_rank, + (local_num_tokens, top_k), + dtype=torch.int32, + device="cuda", + ) + + +def create_experts( + num_experts_per_rank, hidden_size, ep_rank, device, dtype=torch.bfloat16 +): + """ + Create a 3D tensor of expert weights for a given rank. + + Returns: + experts: Tensor of shape [num_experts_per_rank, hidden_size, hidden_size] + """ + experts = torch.empty( + (num_experts_per_rank, hidden_size, hidden_size), dtype=dtype, device=device + ) + for i in range(num_experts_per_rank): + torch.manual_seed(ep_rank * 1000 + i) + torch.nn.init.xavier_uniform_(experts[i]) + return experts + + +def fake_moe( + hidden_states, + token_selected_experts, + token_final_scales, + experts, + is_ep=False, + ep_rank=None, + num_experts_per_rank=None, +): + """ + Emulate MoE computation. + + Returns: + processed_states: [num_tokens, hidden_size] + """ + num_tokens, _ = hidden_states.shape + _, top_k = token_selected_experts.shape + + if is_ep: + assert ep_rank is not None and num_experts_per_rank is not None + + processed_states = torch.zeros_like(hidden_states) + + for token_idx in range(num_tokens): + for k in range(top_k): + expert_id = token_selected_experts[token_idx, k].item() + if is_ep: + if not ( + expert_id >= ep_rank * num_experts_per_rank + and expert_id < (ep_rank + 1) * num_experts_per_rank + ): + continue + local_expert_id = expert_id - ep_rank * num_experts_per_rank + expert = experts[local_expert_id] + else: + expert = experts[expert_id] + + scale = token_final_scales[token_idx, k] + processed_states[token_idx] += hidden_states[token_idx] @ expert * scale + + return processed_states + + +def make_bfloat16_payloads( + local_num_tokens: int, + hidden_size: int, + top_k: int, + rank: int, + token_selected_experts: torch.Tensor, +) -> tuple[list, int]: + """Create bfloat16 test payloads.""" + payloads = [] + + # Payload 0: Hidden states + hidden_states = torch.randn( + local_num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda" + ) + hidden_states += rank # Add rank offset for verification + payloads.append(hidden_states) + + # Payload 1: token_selected_experts + payloads.append(token_selected_experts) + + # Payload 2: token_final_scales + token_final_scales = torch.rand( + local_num_tokens, top_k, dtype=torch.bfloat16, device="cuda" + ) + payloads.append(token_final_scales) + + return payloads, 1 # expert_id_payload_index = 1 + + +def run_moe_a2a_dispatch_single_rank( + ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size +): + """Test MoE A2A dispatch on a single rank.""" + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + + if world_size != ep_size: + pytest.skip(f"Test requires exactly {ep_size} ranks") + + torch.cuda.set_device(rank) + + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=world_size, + tp_size=world_size, + pp_size=1, + cp_size=1, + ) + + 
local_num_tokens = all_num_tokens[rank] + max_num_tokens = max(all_num_tokens) + + # Generate inputs + token_selected_experts = generate_token_selected_experts( + local_num_tokens, ep_size, num_experts_per_rank, top_k + ) + + payloads, expert_id_payload_index = make_bfloat16_payloads( + local_num_tokens, hidden_size, top_k, rank, token_selected_experts + ) + + # Initialize MoeAlltoAll + moe_a2a = MoeAlltoAll( + mapping=mapping, + max_num_tokens=max_num_tokens, + top_k=top_k, + num_experts=ep_size * num_experts_per_rank, + workspace_size_per_rank=512 * 1024 * 1024, + ) + + # Dispatch + recv_tensors = moe_a2a.dispatch( + token_selected_experts=token_selected_experts, + input_payloads=payloads, + runtime_max_tokens_per_rank=max_num_tokens, + ) + + # Verify shapes + assert len(recv_tensors) == len(payloads) + for i, recv_tensor in enumerate(recv_tensors): + assert recv_tensor.shape[0] == ep_size + assert recv_tensor.shape[1] == max_num_tokens + assert recv_tensor.shape[2] == payloads[i].shape[1] + + print(f"[Rank {rank}] Dispatch test passed") + + +def run_moe_a2a_dispatch_moe_combine_single_rank( + ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size +): + """Test full MoE A2A dispatch + expert processing + combine cycle.""" + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + world_size = comm.Get_size() + + if world_size != ep_size: + pytest.skip(f"Test requires exactly {ep_size} ranks") + + torch.cuda.set_device(rank) + + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=world_size, + tp_size=world_size, + pp_size=1, + cp_size=1, + ) + + local_num_tokens = all_num_tokens[rank] + max_num_tokens = max(all_num_tokens) + + # Generate inputs + token_selected_experts = generate_token_selected_experts( + local_num_tokens, ep_size, num_experts_per_rank, top_k + ) + + payloads, expert_id_payload_index = make_bfloat16_payloads( + local_num_tokens, hidden_size, top_k, rank, token_selected_experts + ) + + hidden_states = payloads[0] + token_final_scales = payloads[2] + + # Create experts for this rank + experts = create_experts( + num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16 + ) + + # Compute reference (single-GPU MoE) + all_experts = torch.cat( + [ + create_experts( + num_experts_per_rank, hidden_size, r, "cuda", dtype=torch.bfloat16 + ) + for r in range(ep_size) + ], + dim=0, + ) + reference_output = fake_moe( + hidden_states, + token_selected_experts, + token_final_scales, + all_experts, + is_ep=False, + ) + + # Initialize MoeAlltoAll + moe_a2a = MoeAlltoAll( + mapping=mapping, + max_num_tokens=max_num_tokens, + top_k=top_k, + num_experts=ep_size * num_experts_per_rank, + workspace_size_per_rank=512 * 1024 * 1024, + ) + + # Dispatch + recv_tensors = moe_a2a.dispatch( + token_selected_experts=token_selected_experts, + input_payloads=payloads, + runtime_max_tokens_per_rank=max_num_tokens, + ) + + # Unpack received tensors + hidden_states_recv = recv_tensors[0] # [ep_size, max_tokens, hidden_size] + token_selected_experts_recv = recv_tensors[1] # [ep_size, max_tokens, top_k] + token_final_scales_recv = recv_tensors[2] # [ep_size, max_tokens, top_k] + + # Get workspace-backed tensor for output + moe_output = moe_a2a.get_combine_payload_tensor_in_workspace( + runtime_max_tokens_per_rank=max_num_tokens, + hidden_size=hidden_size, + dtype=torch.bfloat16, + ) + moe_output.zero_() + + # Process each rank's tokens with local experts + for source_rank in range(ep_size): + source_num_tokens = all_num_tokens[source_rank] + for token_idx in 
range(source_num_tokens): + for k in range(top_k): + expert_id = token_selected_experts_recv[ + source_rank, token_idx, k + ].item() + local_expert_id = expert_id - rank * num_experts_per_rank + + if 0 <= local_expert_id < num_experts_per_rank: + token_hidden = hidden_states_recv[source_rank, token_idx] + scale = token_final_scales_recv[source_rank, token_idx, k] + expert_out = token_hidden @ experts[local_expert_id] + output_idx = source_rank * max_num_tokens + token_idx + moe_output[output_idx] += expert_out * scale + + # Combine + combined_output = moe_a2a.combine( + payload=moe_output.view(ep_size, max_num_tokens, hidden_size), + runtime_max_tokens_per_rank=max_num_tokens, + payload_in_workspace=True, + ) + + # Verify against reference + torch.testing.assert_close(combined_output, reference_output, rtol=1e-2, atol=1e-2) + + print(f"[Rank {rank}] Full cycle test passed") + + +@pytest.mark.parametrize("ep_size", [2, 4]) +@pytest.mark.parametrize("all_num_tokens", [[64, 64], [32, 48, 64, 80]]) +@pytest.mark.parametrize("top_k", [2, 4]) +@pytest.mark.parametrize("num_experts_per_rank", [2, 4]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +def test_moe_a2a_dispatch( + ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size +): + """Test MoE A2A dispatch operation.""" + if len(all_num_tokens) != ep_size: + pytest.skip( + f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" + ) + + try: + run_moe_a2a_dispatch_single_rank( + ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size + ) + except Exception as e: + traceback.print_exc() + raise e + + +@pytest.mark.parametrize("ep_size", [2, 4]) +@pytest.mark.parametrize("all_num_tokens", [[64, 64], [32, 48, 64, 80]]) +@pytest.mark.parametrize("top_k", [2, 4]) +@pytest.mark.parametrize("num_experts_per_rank", [2, 4]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +def test_moe_a2a_dispatch_moe_combine( + ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size +): + """Test full MoE A2A dispatch + expert processing + combine cycle.""" + if len(all_num_tokens) != ep_size: + pytest.skip( + f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" + ) + + try: + run_moe_a2a_dispatch_moe_combine_single_rank( + ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size + ) + except Exception as e: + traceback.print_exc() + raise e + + +if __name__ == "__main__": + # Run with: mpirun -n 2 python -m pytest tests/comm/test_mnnvl_a2a.py -v + pytest.main([__file__, "-v", "-s"]) From 881bfc6f31deeef96bee415ddecb7ccf4bef62c5 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 19 Nov 2025 13:20:52 +1300 Subject: [PATCH 02/25] Properly cache jit compiled module --- csrc/trtllm_moe_a2a.cu | 81 ++++----- flashinfer/comm/trtllm_moe_a2a.py | 268 +++++++++++++++++++++--------- 2 files changed, 234 insertions(+), 115 deletions(-) diff --git a/csrc/trtllm_moe_a2a.cu b/csrc/trtllm_moe_a2a.cu index d474214de7..b980f74020 100644 --- a/csrc/trtllm_moe_a2a.cu +++ b/csrc/trtllm_moe_a2a.cu @@ -108,31 +108,34 @@ Tensor moeA2AInitializeOp(TensorView workspace, int64_t epRank, int64_t epSize, return metainfo; } -Tuple, int64_t> moeA2ADispatchOp(TensorView tokenSelectedExperts, - TensorView payloadPtrsTensor, - TensorView payloadElementSizesTensor, - TensorView payloadElementsPerTokenTensor, - TensorView workspace, TensorView metainfo, - int64_t runtimeMaxTokensPerRank, int64_t epRank, - int64_t epSize, int64_t topK, int64_t numExperts) { +Tuple, 
Array, int64_t> moeA2ADispatchOp( + TensorView tokenSelectedExperts, Array inputPayloads, TensorView workspace, + TensorView metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, + int64_t topK, int64_t numExperts) { using tl_throughput::PayloadDescriptor; + fflush(stdout); + CHECK_INPUT(tokenSelectedExperts); CHECK_INPUT_TYPE(tokenSelectedExperts, dl_int32); TVM_FFI_ICHECK_EQ(tokenSelectedExperts.ndim(), 2) << "token_selected_experts must be 2D"; TVM_FFI_ICHECK_EQ(tokenSelectedExperts.size(1), topK) << "token_selected_experts shape mismatch"; - CHECK_INPUT_TYPE(payloadPtrsTensor, dl_int64); - CHECK_INPUT_TYPE(payloadElementSizesTensor, dl_int32); - CHECK_INPUT_TYPE(payloadElementsPerTokenTensor, dl_int32); - TVM_FFI_ICHECK_EQ(payloadPtrsTensor.ndim(), 1); - TVM_FFI_ICHECK_EQ(payloadElementSizesTensor.ndim(), 1); - TVM_FFI_ICHECK_EQ(payloadElementsPerTokenTensor.ndim(), 1); - - int numPayloads = static_cast(payloadPtrsTensor.size(0)); + int numPayloads = static_cast(inputPayloads.size()); TVM_FFI_ICHECK(numPayloads > 0) << "At least one payload is required"; - TVM_FFI_ICHECK(numPayloads <= tl_throughput::kMaxPayloads) << "Too many payloads"; - TVM_FFI_ICHECK_EQ(payloadElementSizesTensor.size(0), numPayloads); - TVM_FFI_ICHECK_EQ(payloadElementsPerTokenTensor.size(0), numPayloads); + TVM_FFI_ICHECK(numPayloads <= tl_throughput::kMaxPayloads) + << "Too many payloads: " << numPayloads << " > " << tl_throughput::kMaxPayloads; + + auto localNumTokens = static_cast(tokenSelectedExperts.size(0)); + TVM_FFI_ICHECK(localNumTokens > 0) << "local_num_tokens must be positive"; + + // Validate all payloads and calculate sizes + for (int i = 0; i < numPayloads; ++i) { + auto const& payload = inputPayloads[i]; + CHECK_INPUT(payload); + TVM_FFI_ICHECK_EQ(payload.ndim(), 2) << "payload " << i << " must be 2D"; + TVM_FFI_ICHECK_EQ(payload.size(0), localNumTokens) + << "payload " << i << " first dimension must match local_num_tokens"; + } CHECK_CPU(metainfo); CHECK_INPUT_TYPE(metainfo, dl_int64); @@ -151,29 +154,30 @@ Tuple, int64_t> moeA2ADispatchOp(TensorView tokenSelectedExperts, << "num_experts must be divisible by ep_size"; TVM_FFI_ICHECK(topK > 0 && topK <= tl_throughput::kMaxTopK); - auto localNumTokens = static_cast(tokenSelectedExperts.size(0)); - TVM_FFI_ICHECK(localNumTokens > 0) << "local_num_tokens must be positive"; - - auto* payloadPtrs = static_cast(payloadPtrsTensor.data_ptr()); - auto* payloadEltSizes = static_cast(payloadElementSizesTensor.data_ptr()); - auto* payloadEltPerToken = static_cast(payloadElementsPerTokenTensor.data_ptr()); - + // Calculate payload descriptors and sizes from input tensors std::vector payloadDescriptors(numPayloads); std::vector payloadByteSizes(numPayloads); int64_t totalBytesNeeded = 0; + for (int i = 0; i < numPayloads; ++i) { - payloadDescriptors[i].src_data = reinterpret_cast(payloadPtrs[i]); - payloadDescriptors[i].element_size = payloadEltSizes[i]; - payloadDescriptors[i].elements_per_token = payloadEltPerToken[i]; - int64_t bytesPerPayload = static_cast(epSize) * runtimeMaxTokensPerRank * - payloadEltPerToken[i] * payloadEltSizes[i]; + auto const& payload = inputPayloads[i]; + int elementsPerToken = static_cast(payload.size(1)); + int elementSize = static_cast(get_element_size(payload)); + + payloadDescriptors[i].src_data = payload.data_ptr(); + payloadDescriptors[i].element_size = elementSize; + payloadDescriptors[i].elements_per_token = elementsPerToken; + + int64_t bytesPerPayload = + static_cast(epSize) * runtimeMaxTokensPerRank * 
elementsPerToken * elementSize; payloadByteSizes[i] = bytesPerPayload; totalBytesNeeded += bytesPerPayload; } auto* workspaceBase = static_cast(workspace.data_ptr()); auto strideBytes = workspace.stride(0); - auto* rankWorkspacePtr = workspaceBase + epRank * strideBytes; + size_t rankWorkspaceOffset = epRank * strideBytes; + auto* rankWorkspacePtr = workspaceBase + rankWorkspaceOffset; int64_t sizePerRank = workspace.size(1); int64_t requiredSize = offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX] + totalBytesNeeded; @@ -225,17 +229,18 @@ Tuple, int64_t> moeA2ADispatchOp(TensorView tokenSelectedExperts, TVM_FFI_ICHECK(launchErr == cudaSuccess) << "moe_a2a_dispatch launch failed: " << cudaGetErrorString(launchErr); - Array recvPtrs; - recvPtrs.reserve(numPayloads); + Array recvOffsets; + Array recvByteSizes; + recvOffsets.reserve(numPayloads); size_t localOffset = static_cast(offsets[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX]); - for (int payloadIdx = 0; payloadIdx < numPayloads; ++payloadIdx) { - auto* ptr = rankWorkspacePtr + localOffset; - recvPtrs.push_back(reinterpret_cast(ptr)); - localOffset += payloadByteSizes[payloadIdx]; + for (auto payloadByteSize : payloadByteSizes) { + recvOffsets.push_back(rankWorkspaceOffset + localOffset); + recvByteSizes.push_back(payloadByteSize); + localOffset += payloadByteSize; } int64_t combinePayloadOffset = static_cast(alignOffset(localOffset)); - return Tuple(recvPtrs, combinePayloadOffset); + return Tuple(recvOffsets, recvByteSizes, combinePayloadOffset); } nvinfer1::DataType toNvDataType(DLDataType dtype) { diff --git a/flashinfer/comm/trtllm_moe_a2a.py b/flashinfer/comm/trtllm_moe_a2a.py index 59c2ae0953..cdb72e09e8 100644 --- a/flashinfer/comm/trtllm_moe_a2a.py +++ b/flashinfer/comm/trtllm_moe_a2a.py @@ -8,13 +8,16 @@ # TODO Review from dataclasses import dataclass +from types import SimpleNamespace from typing import Optional import torch +import functools from .mnnvl import MnnvlMemory from .mapping import Mapping from ..jit.comm import gen_mnnvl_a2a_module +from ..utils import register_custom_op @dataclass @@ -26,9 +29,193 @@ class _A2AState: combine_payload_offset: Optional[int] = None +@functools.cache def get_mnnvl_a2a_module(): """Get or build the MNNVL A2A JIT module.""" - return gen_mnnvl_a2a_module().build_and_load() + module = gen_mnnvl_a2a_module().build_and_load() + + @register_custom_op( + "flashinfer::moe_a2a_initialize", + mutates_args=[], + ) + def moe_a2a_initialize( + workspace: torch.Tensor, + ep_rank: int, + ep_size: int, + max_num_tokens: int, + ): + return module.moe_a2a_initialize(workspace, ep_rank, ep_size, max_num_tokens) + + @register_custom_op( + "flashinfer::moe_a2a_dispatch", + mutates_args=[], + ) + def moe_a2a_dispatch( + token_selected_experts: torch.Tensor, + input_payloads: list[torch.Tensor], + workspace: torch.Tensor, + metainfo: torch.Tensor, + runtime_max_tokens_per_rank: int, + ep_rank: int, + ep_size: int, + top_k: int, + num_experts: int, + ): + """ + Dispatch tokens and payloads to expert ranks. 
+ + Args: + token_selected_experts: [local_num_tokens, top_k] int32 tensor + input_payloads: List of [local_num_tokens, *] tensors to dispatch + workspace: [ep_size, size_per_rank] workspace tensor + metainfo: Metadata tensor from initialize + runtime_max_tokens_per_rank: Max tokens per rank in this batch + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + top_k: Number of experts per token + num_experts: Total number of experts + + Returns: + recv_tensors: List of [ep_size, max_tokens, *] tensors + combine_payload_offset: Offset for combine payload region + """ + print( + f"moe_a2a_dispatch: token_selected_experts.shape={token_selected_experts.shape}, input_payloads={input_payloads}, workspace.shape={workspace.shape}, metainfo.shape={metainfo.shape}, runtime_max_tokens_per_rank={runtime_max_tokens_per_rank}, ep_rank={ep_rank}, ep_size={ep_size}, top_k={top_k}, num_experts={num_experts}" + ) + recv_offsets, recv_sizes, combine_payload_offset = module.moe_a2a_dispatch( + token_selected_experts, + input_payloads, + workspace, + metainfo, + runtime_max_tokens_per_rank, + ep_rank, + ep_size, + top_k, + num_experts, + ) + print( + f"moe_a2a_dispatch: recv_offsets={recv_offsets}, recv_sizes={recv_sizes}, combine_payload_offset={combine_payload_offset}" + ) + workspace_base = workspace.flatten().view(dtype=torch.uint8) + output_payloads = [] + for input_payload, offset, size in zip( + input_payloads, recv_offsets, recv_sizes, strict=True + ): + output_payload = ( + workspace_base[offset : offset + size] + .view([ep_size, runtime_max_tokens_per_rank, -1]) + .view(dtype=input_payload.dtype) + ) + output_payloads.append(output_payload) + + return output_payloads, combine_payload_offset + + @register_custom_op( + "flashinfer::moe_a2a_combine", + mutates_args=[], + ) + def moe_a2a_combine( + payload: torch.Tensor, + local_num_tokens: int, + workspace: torch.Tensor, + metainfo: torch.Tensor, + runtime_max_tokens_per_rank: int, + ep_rank: int, + ep_size: int, + top_k: int, + combine_payload_offset: int, + payload_in_workspace: bool = False, + ) -> torch.Tensor: + """ + Combine expert outputs back to originating tokens. 
+ + Args: + payload: [ep_size, max_tokens, elements_per_token] tensor + local_num_tokens: Number of tokens on this rank + workspace: [ep_size, size_per_rank] workspace tensor + metainfo: Metadata tensor from initialize + runtime_max_tokens_per_rank: Max tokens per rank in this batch + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + top_k: Number of experts per token + combine_payload_offset: Offset from dispatch + payload_in_workspace: If True, payload is workspace-backed + + Returns: + output: [local_num_tokens, elements_per_token] tensor + """ + return module.moe_a2a_combine( + payload, + local_num_tokens, + workspace, + metainfo, + runtime_max_tokens_per_rank, + ep_rank, + ep_size, + top_k, + combine_payload_offset, + payload_in_workspace, + ) + + @register_custom_op( + "flashinfer::moe_a2a_sanitize_expert_ids", + mutates_args=[], + ) + def moe_a2a_sanitize_expert_ids( + expert_ids: torch.Tensor, + workspace: torch.Tensor, + metainfo: torch.Tensor, + ep_rank: int, + invalid_expert_id: int, + ): + return module.moe_a2a_sanitize_expert_ids( + expert_ids, workspace, metainfo, ep_rank, invalid_expert_id + ) + + @register_custom_op( + "flashinfer::moe_a2a_get_combine_payload_tensor", + mutates_args=[], + ) + def moe_a2a_get_combine_payload_tensor( + workspace: torch.Tensor, + ep_rank: int, + ep_size: int, + runtime_max_tokens_per_rank: int, + combine_payload_offset: int, + dtype: torch.dtype, + hidden_size: int, + ) -> torch.Tensor: + """ + Get combine payload tensor backed by workspace (zero-copy). + + Args: + workspace: [ep_size, size_per_rank] workspace tensor + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + runtime_max_tokens_per_rank: Max tokens per rank in this batch + combine_payload_offset: Offset from dispatch + dtype: Data type for the tensor + hidden_size: Hidden dimension size + + Returns: + tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor + """ + return module.moe_a2a_get_combine_payload_tensor( + workspace, + ep_rank, + ep_size, + runtime_max_tokens_per_rank, + combine_payload_offset, + dtype, + hidden_size, + ) + + return SimpleNamespace( + moe_a2a_initialize=moe_a2a_initialize, + moe_a2a_dispatch=moe_a2a_dispatch, + moe_a2a_combine=moe_a2a_combine, + moe_a2a_get_combine_payload_tensor=moe_a2a_get_combine_payload_tensor, + ) def moe_a2a_initialize( @@ -36,19 +223,7 @@ def moe_a2a_initialize( ep_rank: int, ep_size: int, max_num_tokens: int, -) -> torch.Tensor: - """ - Initialize MoE A2A workspace. - - Args: - workspace: [ep_size, size_per_rank] uint8 tensor (MNNVL memory) - ep_rank: Current expert parallel rank - ep_size: Total expert parallel size - max_num_tokens: Maximum number of tokens supported - - Returns: - metainfo: Tensor containing workspace offsets - """ +): return get_mnnvl_a2a_module().moe_a2a_initialize( workspace, ep_rank, ep_size, max_num_tokens ) @@ -65,24 +240,6 @@ def moe_a2a_dispatch( top_k: int, num_experts: int, ): - """ - Dispatch tokens and payloads to expert ranks. 
- - Args: - token_selected_experts: [local_num_tokens, top_k] int32 tensor - input_payloads: List of [local_num_tokens, *] tensors to dispatch - workspace: [ep_size, size_per_rank] workspace tensor - metainfo: Metadata tensor from initialize - runtime_max_tokens_per_rank: Max tokens per rank in this batch - ep_rank: Current expert parallel rank - ep_size: Total expert parallel size - top_k: Number of experts per token - num_experts: Total number of experts - - Returns: - recv_tensors: List of [ep_size, max_tokens, *] tensors - combine_payload_offset: Offset for combine payload region - """ return get_mnnvl_a2a_module().moe_a2a_dispatch( token_selected_experts, input_payloads, @@ -108,24 +265,6 @@ def moe_a2a_combine( combine_payload_offset: int, payload_in_workspace: bool = False, ) -> torch.Tensor: - """ - Combine expert outputs back to originating tokens. - - Args: - payload: [ep_size, max_tokens, elements_per_token] tensor - local_num_tokens: Number of tokens on this rank - workspace: [ep_size, size_per_rank] workspace tensor - metainfo: Metadata tensor from initialize - runtime_max_tokens_per_rank: Max tokens per rank in this batch - ep_rank: Current expert parallel rank - ep_size: Total expert parallel size - top_k: Number of experts per token - combine_payload_offset: Offset from dispatch - payload_in_workspace: If True, payload is workspace-backed - - Returns: - output: [local_num_tokens, elements_per_token] tensor - """ return get_mnnvl_a2a_module().moe_a2a_combine( payload, local_num_tokens, @@ -146,18 +285,8 @@ def moe_a2a_sanitize_expert_ids( metainfo: torch.Tensor, ep_rank: int, invalid_expert_id: int, -) -> None: - """ - Sanitize expert IDs for invalid tokens. - - Args: - expert_ids: [ep_size, max_tokens, top_k] int32 tensor (modified in-place) - workspace: [ep_size, size_per_rank] workspace tensor - metainfo: Metadata tensor from initialize - ep_rank: Current expert parallel rank - invalid_expert_id: Value to fill for invalid tokens - """ - get_mnnvl_a2a_module().moe_a2a_sanitize_expert_ids( +): + return get_mnnvl_a2a_module().moe_a2a_sanitize_expert_ids( expert_ids, workspace, metainfo, ep_rank, invalid_expert_id ) @@ -171,21 +300,6 @@ def moe_a2a_get_combine_payload_tensor( dtype: torch.dtype, hidden_size: int, ) -> torch.Tensor: - """ - Get combine payload tensor backed by workspace (zero-copy). 
- - Args: - workspace: [ep_size, size_per_rank] workspace tensor - ep_rank: Current expert parallel rank - ep_size: Total expert parallel size - runtime_max_tokens_per_rank: Max tokens per rank in this batch - combine_payload_offset: Offset from dispatch - dtype: Data type for the tensor - hidden_size: Hidden dimension size - - Returns: - tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor - """ return get_mnnvl_a2a_module().moe_a2a_get_combine_payload_tensor( workspace, ep_rank, From e6a34578e5b4f5a15820d9c6d49a4e1a84e4caf5 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 19 Nov 2025 13:24:41 +1300 Subject: [PATCH 03/25] Cleanup prints --- flashinfer/comm/trtllm_moe_a2a.py | 6 ------ tests/comm/test_mnnvl_a2a.py | 4 ---- 2 files changed, 10 deletions(-) diff --git a/flashinfer/comm/trtllm_moe_a2a.py b/flashinfer/comm/trtllm_moe_a2a.py index cdb72e09e8..caf73f14d3 100644 --- a/flashinfer/comm/trtllm_moe_a2a.py +++ b/flashinfer/comm/trtllm_moe_a2a.py @@ -79,9 +79,6 @@ def moe_a2a_dispatch( recv_tensors: List of [ep_size, max_tokens, *] tensors combine_payload_offset: Offset for combine payload region """ - print( - f"moe_a2a_dispatch: token_selected_experts.shape={token_selected_experts.shape}, input_payloads={input_payloads}, workspace.shape={workspace.shape}, metainfo.shape={metainfo.shape}, runtime_max_tokens_per_rank={runtime_max_tokens_per_rank}, ep_rank={ep_rank}, ep_size={ep_size}, top_k={top_k}, num_experts={num_experts}" - ) recv_offsets, recv_sizes, combine_payload_offset = module.moe_a2a_dispatch( token_selected_experts, input_payloads, @@ -93,9 +90,6 @@ def moe_a2a_dispatch( top_k, num_experts, ) - print( - f"moe_a2a_dispatch: recv_offsets={recv_offsets}, recv_sizes={recv_sizes}, combine_payload_offset={combine_payload_offset}" - ) workspace_base = workspace.flatten().view(dtype=torch.uint8) output_payloads = [] for input_payload, offset, size in zip( diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index aa6a675b71..13c597518b 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -193,8 +193,6 @@ def run_moe_a2a_dispatch_single_rank( assert recv_tensor.shape[1] == max_num_tokens assert recv_tensor.shape[2] == payloads[i].shape[1] - print(f"[Rank {rank}] Dispatch test passed") - def run_moe_a2a_dispatch_moe_combine_single_rank( ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size @@ -312,8 +310,6 @@ def run_moe_a2a_dispatch_moe_combine_single_rank( # Verify against reference torch.testing.assert_close(combined_output, reference_output, rtol=1e-2, atol=1e-2) - print(f"[Rank {rank}] Full cycle test passed") - @pytest.mark.parametrize("ep_size", [2, 4]) @pytest.mark.parametrize("all_num_tokens", [[64, 64], [32, 48, 64, 80]]) From fa69945189c6e3b1cba04903b91d3d764b003d20 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 19 Nov 2025 17:32:34 +1300 Subject: [PATCH 04/25] Combine tests run without crashing --- .../moeAlltoAllKernels.cu | 2 +- csrc/trtllm_moe_a2a.cu | 38 +- docs/api/comm.rst | 1 - flashinfer/comm/__init__.py | 3 - flashinfer/comm/trtllm_moe_a2a.py | 190 +++--- tests/comm/test_mnnvl_a2a.py | 597 +++++++++++++++--- 6 files changed, 616 insertions(+), 215 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index 91dccdc33f..e46c0e9e63 100644 --- 
a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -400,7 +400,7 @@ __global__ void moeA2ADispatchKernel( asm volatile("ld.relaxed.sys.u32 %0, [%1];" : "=r"(flag_value) : "l"(flag_ptr)); #if ENABLE_DEBUG_PRINT printf( - "combine: ---Rank %d received completion flag from rank %d, flag_value: %d, " + "dispatch: ---Rank %d received completion flag from rank %d, flag_value: %d, " "expected_value: " "%d, address: %p\n", rank_id, peer_rank, flag_value, expected_value, flag_ptr); diff --git a/csrc/trtllm_moe_a2a.cu b/csrc/trtllm_moe_a2a.cu index b980f74020..0407eb201a 100644 --- a/csrc/trtllm_moe_a2a.cu +++ b/csrc/trtllm_moe_a2a.cu @@ -31,10 +31,9 @@ #include "tensorrt_llm/thop/moeAlltoAllMeta.h" #include "tvm_ffi_utils.h" -// TODO Review - using tvm::ffi::Array; using tvm::ffi::Shape; +using tvm::ffi::String; using tvm::ffi::Tensor; using tvm::ffi::TensorView; using tvm::ffi::Tuple; @@ -105,6 +104,10 @@ Tensor moeA2AInitializeOp(TensorView workspace, int64_t epRank, int64_t epSize, Tensor metainfo = alloc_tensor({fi_throughput::NUM_METAINFO_FIELDS}, dl_int64, cpu); auto* metaPtr = static_cast(metainfo.data_ptr()); std::copy(offsets.begin(), offsets.end(), metaPtr); + + auto err = cudaStreamSynchronize(stream); + TVM_FFI_ICHECK(err == cudaSuccess) << "cudaStreamSynchronize failed: " << cudaGetErrorString(err); + return metainfo; } @@ -371,25 +374,20 @@ void moeA2ASanitizeExpertIdsOp(TensorView expertIds, TensorView workspace, Tenso static_cast(runtimeMaxTokensPerRank), static_cast(topK), get_current_stream()); } -int64_t moeA2AGetCombinePayloadPtrOp(TensorView workspace, int64_t epRank, int64_t epSize, - int64_t runtimeMaxTokensPerRank, int64_t combinePayloadOffset, - int64_t elementsPerToken, int64_t elementSizeBytes) { - CHECK_INPUT_TYPE(workspace, dl_uint8); - TVM_FFI_ICHECK_EQ(workspace.ndim(), 2); - TVM_FFI_ICHECK_EQ(workspace.size(0), epSize); - TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize); - TVM_FFI_ICHECK(runtimeMaxTokensPerRank > 0); - TVM_FFI_ICHECK(elementsPerToken > 0); - TVM_FFI_ICHECK(elementSizeBytes > 0); +// Expose metainfo index constants for Python access +// Returns a tuple of (names, values) for all metainfo constants +Tuple, Array> getMoeA2AMetaInfoIndexPairs() { + auto pairs = fi_throughput::getMoeA2AMetaInfoIndexPairs(); - int64_t sizePerRank = workspace.size(1); - int64_t bytesNeeded = epSize * runtimeMaxTokensPerRank * elementsPerToken * elementSizeBytes; - TVM_FFI_ICHECK(combinePayloadOffset >= 0 && combinePayloadOffset + bytesNeeded <= sizePerRank) - << "combine payload exceeds workspace capacity"; + Array names; + Array values; - auto* basePtr = static_cast(workspace.data_ptr()); - auto* rankPtr = basePtr + epRank * workspace.stride(0); - return reinterpret_cast(rankPtr + combinePayloadOffset); + for (const auto& pair : pairs) { + names.push_back(pair.first); + values.push_back(pair.second); + } + + return Tuple{names, values}; } } // namespace @@ -398,4 +396,4 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_initialize, moeA2AInitializeOp); TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_dispatch, moeA2ADispatchOp); TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_combine, moeA2ACombineOp); TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_sanitize_expert_ids, moeA2ASanitizeExpertIdsOp); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_get_combine_payload_ptr, moeA2AGetCombinePayloadPtrOp); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_get_metainfo_index_pairs, getMoeA2AMetaInfoIndexPairs); diff --git 
a/docs/api/comm.rst b/docs/api/comm.rst index 9d38285ad0..f852073ae4 100644 --- a/docs/api/comm.rst +++ b/docs/api/comm.rst @@ -142,4 +142,3 @@ MNNVL A2A (Throughput Backend) moe_a2a_dispatch moe_a2a_combine moe_a2a_sanitize_expert_ids - moe_a2a_get_combine_payload_tensor diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 4860187fc5..7aa1ff1cfe 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -43,9 +43,6 @@ from .trtllm_moe_a2a import MoeAlltoAll as MoeAlltoAll from .trtllm_moe_a2a import moe_a2a_combine as moe_a2a_combine from .trtllm_moe_a2a import moe_a2a_dispatch as moe_a2a_dispatch -from .trtllm_moe_a2a import ( - moe_a2a_get_combine_payload_tensor as moe_a2a_get_combine_payload_tensor, -) from .trtllm_moe_a2a import moe_a2a_initialize as moe_a2a_initialize from .trtllm_moe_a2a import ( moe_a2a_sanitize_expert_ids as moe_a2a_sanitize_expert_ids, diff --git a/flashinfer/comm/trtllm_moe_a2a.py b/flashinfer/comm/trtllm_moe_a2a.py index caf73f14d3..3f5018cbad 100644 --- a/flashinfer/comm/trtllm_moe_a2a.py +++ b/flashinfer/comm/trtllm_moe_a2a.py @@ -76,10 +76,11 @@ def moe_a2a_dispatch( num_experts: Total number of experts Returns: - recv_tensors: List of [ep_size, max_tokens, *] tensors + recv_offsets: List of offsets for each payload in the workspace + recv_sizes: List of sizes for each payload in the workspace combine_payload_offset: Offset for combine payload region """ - recv_offsets, recv_sizes, combine_payload_offset = module.moe_a2a_dispatch( + return module.moe_a2a_dispatch( token_selected_experts, input_payloads, workspace, @@ -90,19 +91,6 @@ def moe_a2a_dispatch( top_k, num_experts, ) - workspace_base = workspace.flatten().view(dtype=torch.uint8) - output_payloads = [] - for input_payload, offset, size in zip( - input_payloads, recv_offsets, recv_sizes, strict=True - ): - output_payload = ( - workspace_base[offset : offset + size] - .view([ep_size, runtime_max_tokens_per_rank, -1]) - .view(dtype=input_payload.dtype) - ) - output_payloads.append(output_payload) - - return output_payloads, combine_payload_offset @register_custom_op( "flashinfer::moe_a2a_combine", @@ -167,48 +155,25 @@ def moe_a2a_sanitize_expert_ids( ) @register_custom_op( - "flashinfer::moe_a2a_get_combine_payload_tensor", + "flashinfer::moe_a2a_get_metainfo_index_pairs", mutates_args=[], ) - def moe_a2a_get_combine_payload_tensor( - workspace: torch.Tensor, - ep_rank: int, - ep_size: int, - runtime_max_tokens_per_rank: int, - combine_payload_offset: int, - dtype: torch.dtype, - hidden_size: int, - ) -> torch.Tensor: + def moe_a2a_get_metainfo_index_pairs(): """ - Get combine payload tensor backed by workspace (zero-copy). - - Args: - workspace: [ep_size, size_per_rank] workspace tensor - ep_rank: Current expert parallel rank - ep_size: Total expert parallel size - runtime_max_tokens_per_rank: Max tokens per rank in this batch - combine_payload_offset: Offset from dispatch - dtype: Data type for the tensor - hidden_size: Hidden dimension size + Get all metainfo index constants from C++. 
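+
+        The names mirror the C++ constants, e.g. "MOE_A2A_SEND_COUNTERS_OFFSET_INDEX";
+        MoeAlltoAll._init_constants strips the "MOE_A2A_" prefix before caching the
+        values in _METAINFO_INDEX.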
Returns: - tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor + Tuple of (names, values) where names is a list of constant names + and values is a list of their corresponding integer values """ - return module.moe_a2a_get_combine_payload_tensor( - workspace, - ep_rank, - ep_size, - runtime_max_tokens_per_rank, - combine_payload_offset, - dtype, - hidden_size, - ) + return module.moe_a2a_get_metainfo_index_pairs() return SimpleNamespace( moe_a2a_initialize=moe_a2a_initialize, moe_a2a_dispatch=moe_a2a_dispatch, moe_a2a_combine=moe_a2a_combine, - moe_a2a_get_combine_payload_tensor=moe_a2a_get_combine_payload_tensor, + moe_a2a_sanitize_expert_ids=moe_a2a_sanitize_expert_ids, + moe_a2a_get_metainfo_index_pairs=moe_a2a_get_metainfo_index_pairs, ) @@ -223,6 +188,37 @@ def moe_a2a_initialize( ) +def moe_a2a_wrap_payload_tensor_in_workspace( + workspace: torch.Tensor, + leading_shape: list[int], + slice_start: int, + slice_end: int, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Wrap an offset in the workspace into a tensor. + + Args: + workspace: [ep_size, size_per_rank] workspace tensor + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + runtime_max_tokens_per_rank: Max tokens per rank in this batch + total_size: Total size of the payload + offset: Offset from dispatch + dtype: Data type for the tensor + + Returns: + tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor + """ + workspace_base = workspace.flatten().view(dtype=torch.uint8) + result = ( + workspace_base[slice_start:slice_end] + .view(leading_shape + [-1]) + .view(dtype=dtype) + ) + return result + + def moe_a2a_dispatch( token_selected_experts: torch.Tensor, input_payloads: list[torch.Tensor], @@ -234,18 +230,37 @@ def moe_a2a_dispatch( top_k: int, num_experts: int, ): - return get_mnnvl_a2a_module().moe_a2a_dispatch( - token_selected_experts, - input_payloads, - workspace, - metainfo, - runtime_max_tokens_per_rank, - ep_rank, - ep_size, - top_k, - num_experts, + recv_offsets, recv_sizes, combine_payload_offset = ( + get_mnnvl_a2a_module().moe_a2a_dispatch( + token_selected_experts, + input_payloads, + workspace, + metainfo, + runtime_max_tokens_per_rank, + ep_rank, + ep_size, + top_k, + num_experts, + ) ) + output_payloads = [] + for input_payload, offset, size in zip( + input_payloads, recv_offsets, recv_sizes, strict=True + ): + # This uses absolute offsets in the workspace, so skip indexing into the workspace + output_payloads.append( + moe_a2a_wrap_payload_tensor_in_workspace( + workspace, + [ep_size, runtime_max_tokens_per_rank], + offset, + offset + size, + input_payload.dtype, + ) + ) + + return output_payloads, combine_payload_offset + def moe_a2a_combine( payload: torch.Tensor, @@ -285,26 +300,6 @@ def moe_a2a_sanitize_expert_ids( ) -def moe_a2a_get_combine_payload_tensor( - workspace: torch.Tensor, - ep_rank: int, - ep_size: int, - runtime_max_tokens_per_rank: int, - combine_payload_offset: int, - dtype: torch.dtype, - hidden_size: int, -) -> torch.Tensor: - return get_mnnvl_a2a_module().moe_a2a_get_combine_payload_tensor( - workspace, - ep_rank, - ep_size, - runtime_max_tokens_per_rank, - combine_payload_offset, - dtype, - hidden_size, - ) - - class MoeAlltoAll: """ Manages MoE All-to-All operations with proper workspace allocation and synchronization. 
@@ -321,6 +316,29 @@ class MoeAlltoAll: # Single shared workspace across the process _WORKSPACE: Optional[dict] = None + # Metainfo index constants (loaded dynamically from C++) + # These offsets allow accessing internal workspace data for testing/debugging + _METAINFO_INDEX: Optional[dict] = None + + @classmethod + def _init_constants(cls): + """Initialize constants from C++ if not already done.""" + if cls._METAINFO_INDEX is None: + module = get_mnnvl_a2a_module() + names, values = module.moe_a2a_get_metainfo_index_pairs() + + # Convert TVM arrays to Python and build dictionary + # Strip "MOE_A2A_" prefix from names for cleaner API + cls._METAINFO_INDEX = {} + for name, value in zip(names, values, strict=True): + # Convert from "MOE_A2A_SEND_COUNTERS_OFFSET_INDEX" to "SEND_COUNTERS_OFFSET_INDEX" + clean_name = ( + name.replace("MOE_A2A_", "") + if name.startswith("MOE_A2A_") + else name + ) + cls._METAINFO_INDEX[clean_name] = int(value) + def __init__( self, mapping: Mapping, @@ -339,13 +357,16 @@ def __init__( num_experts: Total number of experts workspace_size_per_rank: Size of workspace per rank in bytes (default: 512MB) """ + # Initialize constants from C++ + self._init_constants() + # Initialize MNNVL memory system MnnvlMemory.initialize() self.workspace_size_per_rank = workspace_size_per_rank self.max_num_tokens = max_num_tokens - self.ep_size = mapping.tp_size - self.ep_rank = mapping.tp_rank + self.ep_size = mapping.moe_ep_size + self.ep_rank = mapping.moe_ep_rank self.top_k = top_k self.num_experts = num_experts @@ -515,14 +536,14 @@ def get_combine_payload_tensor_in_workspace( "get_combine_payload_tensor_in_workspace called before successful dispatch" ) - return moe_a2a_get_combine_payload_tensor( - self.workspace, - self.ep_rank, - self.ep_size, - runtime_max_tokens_per_rank, + element_size = torch.tensor([], dtype=dtype).element_size() + return moe_a2a_wrap_payload_tensor_in_workspace( + self.workspace[self.ep_rank, :], + [self.ep_size * runtime_max_tokens_per_rank], self._state.combine_payload_offset, + self._state.combine_payload_offset + + self.ep_size * runtime_max_tokens_per_rank * hidden_size * element_size, dtype, - hidden_size, ) @@ -532,5 +553,4 @@ def get_combine_payload_tensor_in_workspace( "moe_a2a_dispatch", "moe_a2a_combine", "moe_a2a_sanitize_expert_ids", - "moe_a2a_get_combine_payload_tensor", ] diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index 13c597518b..6f0de9e872 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -21,6 +21,7 @@ from flashinfer.comm import MoeAlltoAll from flashinfer.comm.mapping import Mapping +from flashinfer.comm.mnnvl import MnnvlMemory @pytest.fixture(autouse=True) @@ -29,14 +30,24 @@ def setup_test(): def compute_target_rank_id(expert_id, num_experts_per_rank): - """Compute the rank that owns a given expert using contiguous partitioning.""" + """Compute the rank that owns a given expert using contiguous partitioning. + Experts are divided evenly across ranks: + - Rank 0: experts [0, num_experts_per_rank) + - Rank 1: experts [num_experts_per_rank, 2 * num_experts_per_rank) + - ... 
+ For example, with 32 experts and 4 ranks (8 experts per rank): + - Rank 0: experts 0-7 + - Rank 1: experts 8-15 + - Rank 2: experts 16-23 + - Rank 3: experts 24-31 + """ return expert_id // num_experts_per_rank def generate_token_selected_experts( local_num_tokens: int, ep_size: int, num_experts_per_rank: int, top_k: int ) -> torch.Tensor: - """Generate global expert IDs tensor.""" + """Generate global expert IDs tensor, aligned with single-GPU test semantics.""" return torch.randint( 0, ep_size * num_experts_per_rank, @@ -52,14 +63,22 @@ def create_experts( """ Create a 3D tensor of expert weights for a given rank. + Args: + num_experts_per_rank: Number of experts on this rank + hidden_size: Hidden dimension size + ep_rank: EP rank ID + device: Device to create experts on + Returns: experts: Tensor of shape [num_experts_per_rank, hidden_size, hidden_size] """ + # For reproducibility, set the seed based on rank experts = torch.empty( (num_experts_per_rank, hidden_size, hidden_size), dtype=dtype, device=device ) for i in range(num_experts_per_rank): torch.manual_seed(ep_rank * 1000 + i) + # Xavier uniform initialization for each expert torch.nn.init.xavier_uniform_(experts[i]) return experts @@ -74,10 +93,19 @@ def fake_moe( num_experts_per_rank=None, ): """ - Emulate MoE computation. + Emulate MoE computation by scaling tokens based on which experts belong to this rank. + + Args: + hidden_states: [num_tokens, hidden_size] - input hidden states + token_selected_experts: [num_tokens, top_k] - selected expert indices + token_final_scales: [num_tokens, top_k] - scaling factors for each expert + experts: [num_experts_per_rank, hidden_size, hidden_size] if is_ep, otherwise [num_experts, hidden_size, hidden_size] - expert weights + is_ep: If true, emulate MoE on a EP rank; otherwise, emulate MoE with all experts + ep_rank: EP rank ID + num_experts_per_rank: Number of experts per rank Returns: - processed_states: [num_tokens, hidden_size] + processed_states: [num_tokens, hidden_size] - processed hidden states """ num_tokens, _ = hidden_states.shape _, top_k = token_selected_experts.shape @@ -85,9 +113,12 @@ def fake_moe( if is_ep: assert ep_rank is not None and num_experts_per_rank is not None + # Initialize output processed_states = torch.zeros_like(hidden_states) + # Process each token for token_idx in range(num_tokens): + # For each expert selected for this token/ for k in range(top_k): expert_id = token_selected_experts[token_idx, k].item() if is_ep: @@ -96,6 +127,7 @@ def fake_moe( and expert_id < (ep_rank + 1) * num_experts_per_rank ): continue + # Convert global expert ID to local expert ID for this rank local_expert_id = expert_id - ep_rank * num_experts_per_rank expert = experts[local_expert_id] else: @@ -107,6 +139,47 @@ def fake_moe( return processed_states +def make_nvfp4_payloads( + local_num_tokens: int, + hidden_size: int, + top_k: int, + rank: int, + token_selected_experts: torch.Tensor, +) -> tuple[list, int]: + """Create the four NV FP4 payloads exactly as in single-GPU test.""" + payloads = [] + # Payload 0: Packed FP4 tokens (uint8) + packed_hidden_size = hidden_size // 2 + packed_hidden_states = torch.randint( + 0, 256, (local_num_tokens, packed_hidden_size), dtype=torch.uint8, device="cuda" + ) + payloads.append(packed_hidden_states) + + # Payload 1: Scaling factors (fp8) + num_elts_per_sf = 16 + num_scaling_factors = hidden_size // num_elts_per_sf + scaling_factors = torch.randn( + local_num_tokens, num_scaling_factors, dtype=torch.float32, device="cuda" + ) # 
.to(torch.float8_e4m3fn) TODO: Test failed. + scaling_factors += rank + payloads.append(scaling_factors) + + # Payload 2: token_selected_experts + payloads.append(token_selected_experts) + + # Payload 3: token_final_scales (bfloat16) + token_final_scales = torch.rand( + local_num_tokens, top_k, dtype=torch.bfloat16, device="cuda" + ) + + # Construct the data to contain info about send rank and local_token_idx, which is used for debugging + # token_final_scales[:, 0] = rank + # token_final_scales[:, 1] = torch.linspace(0, local_num_tokens - 1, local_num_tokens, dtype=torch.bfloat16, device='cuda') + + payloads.append(token_final_scales) + return payloads, 2 + + def make_bfloat16_payloads( local_num_tokens: int, hidden_size: int, @@ -114,90 +187,436 @@ def make_bfloat16_payloads( rank: int, token_selected_experts: torch.Tensor, ) -> tuple[list, int]: - """Create bfloat16 test payloads.""" + """Create bfloat16 test payloads matching nvfp4 structure but without scaling factors.""" payloads = [] - # Payload 0: Hidden states + # Payload 0: Hidden states (bfloat16) hidden_states = torch.randn( local_num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda" ) - hidden_states += rank # Add rank offset for verification + # Add rank-specific pattern for verification + hidden_states += rank payloads.append(hidden_states) # Payload 1: token_selected_experts payloads.append(token_selected_experts) - # Payload 2: token_final_scales + # Payload 2: token_final_scales (bfloat16) - similar to nvfp4's payload 4 token_final_scales = torch.rand( local_num_tokens, top_k, dtype=torch.bfloat16, device="cuda" ) + + # Optional: Construct the data that is easier to debug + # token_final_scales[:, 0] = rank + # token_final_scales[:, 1] = torch.linspace(0, local_num_tokens - 1, local_num_tokens, dtype=torch.bfloat16, device='cuda') + payloads.append(token_final_scales) - return payloads, 1 # expert_id_payload_index = 1 + return payloads, 1 def run_moe_a2a_dispatch_single_rank( - ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size + ep_size, + all_num_tokens, + top_k, + workspace_size_per_rank, + num_experts_per_rank, + hidden_size, + invalid_token_expert_id, ): - """Test MoE A2A dispatch on a single rank.""" + """Worker function for MPI testing.""" comm = MPI.COMM_WORLD rank = comm.Get_rank() - world_size = comm.Get_size() + torch.cuda.set_device(rank) - if world_size != ep_size: - pytest.skip(f"Test requires exactly {ep_size} ranks") + try: + mapping = Mapping( + rank=rank, + tp_size=ep_size, + moe_ep_size=ep_size, + world_size=ep_size, + gpus_per_node=ep_size, + pp_size=1, + cp_size=1, + ) - torch.cuda.set_device(rank) + # Create MoeAlltoAll manager + max_num_tokens = max(all_num_tokens) - mapping = Mapping( - world_size=world_size, - rank=rank, - gpus_per_node=world_size, - tp_size=world_size, - pp_size=1, - cp_size=1, - ) + moe_a2a = MoeAlltoAll( + mapping, + max_num_tokens, + top_k, + ep_size * num_experts_per_rank, + workspace_size_per_rank, + ) + + # Get the number of tokens for this specific rank (same as single-GPU) + rank_local_tokens = all_num_tokens[rank] + + # Generate data using helper functions + token_selected_experts = generate_token_selected_experts( + rank_local_tokens, ep_size, num_experts_per_rank, top_k + ) + payloads, expert_id_payload_index = make_nvfp4_payloads( + rank_local_tokens, hidden_size, top_k, rank, token_selected_experts + ) + + recv_tensors = moe_a2a.dispatch( + token_selected_experts, + payloads, + max_num_tokens, + invalid_token_expert_id=invalid_token_expert_id, + 
expert_id_payload_index=expert_id_payload_index, + ) + + # Read counters and compact routing tensors from workspace + send_counters_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"] + ].item() + recv_counters_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"] + ].item() + topk_target_ranks_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["TOPK_TARGET_RANKS_OFFSET_INDEX"] + ].item() + topk_send_indices_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["TOPK_SEND_INDICES_OFFSET_INDEX"] + ].item() + + send_counters = ( + moe_a2a.workspace[ + rank, send_counters_offset : send_counters_offset + ep_size * 4 + ] + .view(torch.int32) + .cpu() + ) + recv_counters = ( + moe_a2a.workspace[ + rank, recv_counters_offset : recv_counters_offset + ep_size * 4 + ] + .view(torch.int32) + .cpu() + ) + topk_target_ranks = ( + moe_a2a.workspace[ + rank, + topk_target_ranks_offset : topk_target_ranks_offset + + max_num_tokens * top_k * 4, + ] + .view(torch.int32) + .view(max_num_tokens, top_k) + .cpu() + ) + topk_send_indices = ( + moe_a2a.workspace[ + rank, + topk_send_indices_offset : topk_send_indices_offset + + max_num_tokens * top_k * 4, + ] + .view(torch.int32) + .view(max_num_tokens, top_k) + .cpu() + ) + + # Return results to be collected (move to CPU for MPI transfer) + return ( + token_selected_experts.cpu(), + [p.cpu() for p in payloads], + [rt.cpu() for rt in recv_tensors], + send_counters, + topk_send_indices, + topk_target_ranks, + recv_counters, + expert_id_payload_index, + ) + except Exception: + traceback.print_exc() + raise + + +def verify_dispatch( + all_token_selected_experts, + all_payloads, + all_recv_tensors, + all_send_counters, + all_topk_send_indices, + all_topk_target_ranks, + all_recv_counters, + ep_size, + all_num_tokens, + top_k, + num_experts_per_rank, + expert_id_payload_index, + invalid_token_expert_id, +): + """Verify dispatch results including actual content verification""" - local_num_tokens = all_num_tokens[rank] max_num_tokens = max(all_num_tokens) + # Verify dimensions and dtypes + for send_rank in range(ep_size): + local_num_tokens = all_num_tokens[send_rank] - # Generate inputs - token_selected_experts = generate_token_selected_experts( - local_num_tokens, ep_size, num_experts_per_rank, top_k - ) + token_selected_experts = all_token_selected_experts[send_rank] + assert len(token_selected_experts.shape) == 2, ( + "token_selected_experts should be a 2D tensor" + ) + assert token_selected_experts.dtype == torch.int32, ( + "token_selected_experts should be a 32-bit integer tensor" + ) + assert token_selected_experts.shape[0] == local_num_tokens, ( + "token_selected_experts.shape[0] should be local_num_tokens" + ) + assert token_selected_experts.shape[1] == top_k, ( + "token_selected_experts.shape[1] should be top_k" + ) - payloads, expert_id_payload_index = make_bfloat16_payloads( - local_num_tokens, hidden_size, top_k, rank, token_selected_experts - ) + payloads = all_payloads[send_rank] + recv_tensors = all_recv_tensors[send_rank] + num_payloads = len(payloads) + assert len(recv_tensors) == num_payloads, ( + "recv_tensors should have the same number of payloads as payloads" + ) + for i in range(num_payloads): + payload = payloads[i] + assert len(payload.shape) == 2, "payload should be a 2D tensor" + assert payload.shape[0] == local_num_tokens, ( + "payload.shape[0] should be local_num_tokens" + ) - # Initialize MoeAlltoAll - moe_a2a = MoeAlltoAll( - mapping=mapping, - 
max_num_tokens=max_num_tokens, - top_k=top_k, - num_experts=ep_size * num_experts_per_rank, - workspace_size_per_rank=512 * 1024 * 1024, - ) + recv_tensor = recv_tensors[i] + assert len(recv_tensor.shape) == 3, "recv_tensor should be a 3D tensor" + assert recv_tensor.shape[0] == ep_size, ( + "recv_tensor.shape[0] should be ep_size" + ) + assert recv_tensor.shape[1] == max_num_tokens, ( + "recv_tensor.shape[1] should be max_num_tokens" + ) + assert recv_tensor.shape[2] == payload.shape[1], ( + "recv_tensor.shape[2] should be payload.shape[1]" + ) + assert recv_tensor.dtype == payload.dtype, ( + "recv_tensor.dtype should be payload.dtype" + ) - # Dispatch - recv_tensors = moe_a2a.dispatch( - token_selected_experts=token_selected_experts, - input_payloads=payloads, - runtime_max_tokens_per_rank=max_num_tokens, + # Verify counters and compact routing tensors + send_counters = all_send_counters[send_rank] + assert len(send_counters.shape) == 1, "send_counters should be a 1D tensor" + assert send_counters.shape[0] == ep_size + assert send_counters.dtype == torch.int32 + + recv_counters = all_recv_counters[send_rank] + assert len(recv_counters.shape) == 1, "recv_counters should be a 1D tensor" + assert recv_counters.shape[0] == ep_size + assert recv_counters.dtype == torch.int32 + + topk_send_indices = all_topk_send_indices[send_rank] + topk_target_ranks = all_topk_target_ranks[send_rank] + assert topk_send_indices.shape == (max_num_tokens, top_k), ( + "topk_send_indices shape" + ) + assert topk_target_ranks.shape == (max_num_tokens, top_k), ( + "topk_target_ranks shape" + ) + assert topk_send_indices.dtype == torch.int32 + assert topk_target_ranks.dtype == torch.int32 + + # Verify send_counters per (send_rank -> target_rank) + for send_rank in range(ep_size): + expected_sends = {} + token_experts = all_token_selected_experts[send_rank] + sent_to_rank = set() + + for token_idx in range(token_experts.shape[0]): + experts = token_experts[token_idx] + target_ranks = compute_target_rank_id(experts, num_experts_per_rank) + sent_to_rank.clear() + + for target_rank in target_ranks.tolist(): + if target_rank not in sent_to_rank: + if target_rank not in expected_sends: + expected_sends[target_rank] = 0 + expected_sends[target_rank] += 1 + sent_to_rank.add(target_rank) + + for target_rank in range(ep_size): + expected_to_rank = expected_sends.get(target_rank, 0) + actual_to_rank = all_send_counters[send_rank][target_rank].item() + assert actual_to_rank == expected_to_rank, ( + f"Rank {send_rank} sent {actual_to_rank} tokens to rank {target_rank}, expected {expected_to_rank}" + ) + + # Verify recv_counters match send_counters + for recv_rank in range(ep_size): + for send_rank in range(ep_size): + expected_recv = all_send_counters[send_rank][recv_rank].item() + actual_recv = all_recv_counters[recv_rank][send_rank].item() + assert actual_recv == expected_recv, ( + f"Rank {recv_rank} received {actual_recv} tokens from rank {send_rank}, expected {expected_recv}" + ) + + # Verify payload content using topk_send_indices and topk_target_ranks + for send_rank in range(ep_size): + token_selected_experts = all_token_selected_experts[send_rank] + payloads = all_payloads[send_rank] + topk_send_indices = all_topk_send_indices[send_rank] + topk_target_ranks = all_topk_target_ranks[send_rank] + local_num_tokens = all_num_tokens[send_rank] + + for token_idx in range(local_num_tokens): + experts = token_selected_experts[token_idx] + target_ranks = compute_target_rank_id(experts, num_experts_per_rank) + # Deduplicate target 
ranks per token + topk_target_ranks_ref = target_ranks.clone() + seen = set() + for kk in range(top_k): + tr = int(topk_target_ranks_ref[kk].item()) + if tr in seen: + topk_target_ranks_ref[kk] = -1 + else: + seen.add(tr) + + assert ( + topk_target_ranks[token_idx, :].tolist() + == topk_target_ranks_ref.tolist() + ) + + for k in range(top_k): + dst_pos = topk_send_indices[token_idx, k].item() + target_rank = topk_target_ranks[token_idx, k].item() + if dst_pos == -1: + assert target_rank == -1 + continue + recv_tensors = all_recv_tensors[target_rank] + for payload_idx, payload in enumerate(payloads): + recv_tensor = recv_tensors[payload_idx] + source_data = payload[token_idx] + received_data = recv_tensor[send_rank, dst_pos] + torch.testing.assert_close( + received_data, source_data, atol=0, rtol=0 + ) + + # Verify token_selected_experts of invalid tokens are correctly sanitized + for recv_rank in range(ep_size): + expert_ids_recv = all_recv_tensors[recv_rank][expert_id_payload_index] + for source_rank in range(ep_size): + valid = int(all_recv_counters[recv_rank][source_rank].item()) + for token_idx in range(max_num_tokens): + token_expert_ids = expert_ids_recv[source_rank, token_idx] + if token_idx >= valid: + assert torch.all(token_expert_ids == invalid_token_expert_id) + + +@pytest.mark.parametrize( + "ep_size,all_num_tokens,top_k", + [ + # Basic configurations + (4, [32, 32, 32, 32], 2), # Four ranks with uniform distribution + (4, [16, 32, 64, 48], 2), # Four ranks with non-uniform distribution + (2, [100, 50], 2), # Two ranks with different loads + (8, [10, 20, 30, 40, 50, 60, 70, 80], 2), # Eight ranks with increasing load + # Different top_k values + (4, [32, 32, 32, 32], 4), # Four ranks with top_k = 4 + (4, [32, 32, 32, 32], 8), # Four ranks with top_k = 8 + # Edge cases + (4, [1, 1, 1, 1], 2), # Four ranks with single token per rank + ], +) +def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): + """Test MoE A2A dispatch operation.""" + if len(all_num_tokens) != ep_size: + pytest.skip( + f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" + ) + + comm = MPI.COMM_WORLD + # rank = comm.Get_rank() + world_size = comm.Get_size() + + if world_size != ep_size: + pytest.skip(f"Test requires exactly {ep_size} ranks") + + try: + MnnvlMemory.initialize() + if not MnnvlMemory.supports_mnnvl(): + pytest.skip("MNNVL not supported on this system") + except Exception: + pytest.skip("MNNVL not supported on this system") + + hidden_size = 1024 + num_experts_per_rank = 8 + workspace_size_per_rank = 512 * 1024 * 1024 + invalid_token_expert_id = -1 + + # Run dispatch on this rank + try: + result = run_moe_a2a_dispatch_single_rank( + ep_size, + all_num_tokens, + top_k, + workspace_size_per_rank, + num_experts_per_rank, + hidden_size, + invalid_token_expert_id, + ) + except Exception as e: + traceback.print_exc() + raise e + + # Gather results from all ranks + all_results = comm.allgather(result) + + # Extract results + all_token_selected_experts = [r[0] for r in all_results] + all_payloads = [r[1] for r in all_results] + all_recv_tensors = [r[2] for r in all_results] + all_send_counters = [r[3] for r in all_results] + all_topk_send_indices = [r[4] for r in all_results] + all_topk_target_ranks = [r[5] for r in all_results] + all_recv_counters = [r[6] for r in all_results] + all_expert_id_payload_index = [r[7] for r in all_results] + expert_id_payload_index = all_expert_id_payload_index[0] + + assert all(i == expert_id_payload_index for i in all_expert_id_payload_index), ( 
+ "all_expert_id_payload_index should be the same" ) - # Verify shapes - assert len(recv_tensors) == len(payloads) - for i, recv_tensor in enumerate(recv_tensors): - assert recv_tensor.shape[0] == ep_size - assert recv_tensor.shape[1] == max_num_tokens - assert recv_tensor.shape[2] == payloads[i].shape[1] + # Verify dispatch results with full counter verification + verify_dispatch( + all_token_selected_experts, + all_payloads, + all_recv_tensors, + all_send_counters, + all_topk_send_indices, + all_topk_target_ranks, + all_recv_counters, + ep_size, + all_num_tokens, + top_k, + num_experts_per_rank, + expert_id_payload_index, + invalid_token_expert_id, + ) -def run_moe_a2a_dispatch_moe_combine_single_rank( - ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size -): +@pytest.mark.parametrize( + "ep_size,all_num_tokens,top_k", + [ + (4, [32, 32, 32, 32], 2), + (4, [16, 32, 64, 48], 2), + (2, [100, 50], 2), + (4, [32, 32, 32, 32], 4), + (4, [1, 1, 1, 1], 2), + (8, [640, 640, 640, 640, 640, 640, 640, 640], 4), + ], +) +def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): """Test full MoE A2A dispatch + expert processing + combine cycle.""" + if len(all_num_tokens) != ep_size: + pytest.skip( + f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" + ) + comm = MPI.COMM_WORLD rank = comm.Get_rank() world_size = comm.Get_size() @@ -205,15 +624,24 @@ def run_moe_a2a_dispatch_moe_combine_single_rank( if world_size != ep_size: pytest.skip(f"Test requires exactly {ep_size} ranks") + try: + MnnvlMemory.initialize() + if not MnnvlMemory.supports_mnnvl(): + pytest.skip("MNNVL not supported on this system") + except Exception: + pytest.skip("MNNVL not supported on this system") + torch.cuda.set_device(rank) + hidden_size = 2880 # gpt-oss + num_experts_per_rank = 8 + workspace_size_per_rank = 512 * 1024 * 1024 + mapping = Mapping( - world_size=world_size, rank=rank, - gpus_per_node=world_size, + moe_ep_size=world_size, tp_size=world_size, - pp_size=1, - cp_size=1, + world_size=world_size, ) local_num_tokens = all_num_tokens[rank] @@ -246,21 +674,26 @@ def run_moe_a2a_dispatch_moe_combine_single_rank( ], dim=0, ) + reference_output = fake_moe( hidden_states, token_selected_experts, token_final_scales, all_experts, - is_ep=False, + is_ep=True, + ep_rank=rank, + num_experts_per_rank=num_experts_per_rank, ) + torch.cuda.synchronize() + # Initialize MoeAlltoAll moe_a2a = MoeAlltoAll( mapping=mapping, max_num_tokens=max_num_tokens, top_k=top_k, num_experts=ep_size * num_experts_per_rank, - workspace_size_per_rank=512 * 1024 * 1024, + workspace_size_per_rank=workspace_size_per_rank, ) # Dispatch @@ -311,52 +744,6 @@ def run_moe_a2a_dispatch_moe_combine_single_rank( torch.testing.assert_close(combined_output, reference_output, rtol=1e-2, atol=1e-2) -@pytest.mark.parametrize("ep_size", [2, 4]) -@pytest.mark.parametrize("all_num_tokens", [[64, 64], [32, 48, 64, 80]]) -@pytest.mark.parametrize("top_k", [2, 4]) -@pytest.mark.parametrize("num_experts_per_rank", [2, 4]) -@pytest.mark.parametrize("hidden_size", [128, 256]) -def test_moe_a2a_dispatch( - ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size -): - """Test MoE A2A dispatch operation.""" - if len(all_num_tokens) != ep_size: - pytest.skip( - f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" - ) - - try: - run_moe_a2a_dispatch_single_rank( - ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size - ) - except Exception as e: - traceback.print_exc() - raise e - - 
-@pytest.mark.parametrize("ep_size", [2, 4]) -@pytest.mark.parametrize("all_num_tokens", [[64, 64], [32, 48, 64, 80]]) -@pytest.mark.parametrize("top_k", [2, 4]) -@pytest.mark.parametrize("num_experts_per_rank", [2, 4]) -@pytest.mark.parametrize("hidden_size", [128, 256]) -def test_moe_a2a_dispatch_moe_combine( - ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size -): - """Test full MoE A2A dispatch + expert processing + combine cycle.""" - if len(all_num_tokens) != ep_size: - pytest.skip( - f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" - ) - - try: - run_moe_a2a_dispatch_moe_combine_single_rank( - ep_size, all_num_tokens, top_k, num_experts_per_rank, hidden_size - ) - except Exception as e: - traceback.print_exc() - raise e - - if __name__ == "__main__": # Run with: mpirun -n 2 python -m pytest tests/comm/test_mnnvl_a2a.py -v pytest.main([__file__, "-v", "-s"]) From 3b5e0d36f6a9455ac40e2fd0216940e05adb26c8 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 20 Nov 2025 11:53:47 +1300 Subject: [PATCH 05/25] Update tests with fake_moe properly --- flashinfer/comm/trtllm_moe_a2a.py | 4 +-- tests/comm/test_mnnvl_a2a.py | 44 +++++++++++++++---------------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/flashinfer/comm/trtllm_moe_a2a.py b/flashinfer/comm/trtllm_moe_a2a.py index 3f5018cbad..58f0a1834a 100644 --- a/flashinfer/comm/trtllm_moe_a2a.py +++ b/flashinfer/comm/trtllm_moe_a2a.py @@ -529,7 +529,7 @@ def get_combine_payload_tensor_in_workspace( dtype: Data type for the tensor Returns: - tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor + tensor: [ep_size, max_tokens, hidden_size] workspace-backed tensor """ if self._state.phase != "dispatched": raise RuntimeError( @@ -539,7 +539,7 @@ def get_combine_payload_tensor_in_workspace( element_size = torch.tensor([], dtype=dtype).element_size() return moe_a2a_wrap_payload_tensor_in_workspace( self.workspace[self.ep_rank, :], - [self.ep_size * runtime_max_tokens_per_rank], + [self.ep_size, runtime_max_tokens_per_rank], self._state.combine_payload_offset, self._state.combine_payload_offset + self.ep_size * runtime_max_tokens_per_rank * hidden_size * element_size, diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index 6f0de9e872..ad6ce87c75 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -659,11 +659,6 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): hidden_states = payloads[0] token_final_scales = payloads[2] - # Create experts for this rank - experts = create_experts( - num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16 - ) - # Compute reference (single-GPU MoE) all_experts = torch.cat( [ @@ -675,14 +670,16 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): dim=0, ) + rank_experts = create_experts( + num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16 + ) + reference_output = fake_moe( hidden_states, token_selected_experts, token_final_scales, all_experts, - is_ep=True, - ep_rank=rank, - num_experts_per_rank=num_experts_per_rank, + is_ep=False, ) torch.cuda.synchronize() @@ -717,21 +714,22 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): moe_output.zero_() # Process each rank's tokens with local experts - for source_rank in range(ep_size): - source_num_tokens = all_num_tokens[source_rank] - for token_idx in range(source_num_tokens): - for k in range(top_k): - expert_id 
= token_selected_experts_recv[ - source_rank, token_idx, k - ].item() - local_expert_id = expert_id - rank * num_experts_per_rank - - if 0 <= local_expert_id < num_experts_per_rank: - token_hidden = hidden_states_recv[source_rank, token_idx] - scale = token_final_scales_recv[source_rank, token_idx, k] - expert_out = token_hidden @ experts[local_expert_id] - output_idx = source_rank * max_num_tokens + token_idx - moe_output[output_idx] += expert_out * scale + print( + f"hidden_states_recv.shape: {hidden_states_recv.shape}, token_selected_experts_recv.shape: {token_selected_experts_recv.shape}, token_final_scales_recv.shape: {token_final_scales_recv.shape}, rank_experts.shape: {rank_experts.shape}, moe_output.shape: {moe_output.shape}" + ) + moe_output[rank] = fake_moe( + hidden_states_recv.view(ep_size * max_num_tokens, hidden_states_recv.shape[-1]), + token_selected_experts_recv.view( + ep_size * max_num_tokens, token_selected_experts_recv.shape[-1] + ), + token_final_scales_recv.view( + ep_size * max_num_tokens, token_final_scales_recv.shape[-1] + ), + rank_experts, # experts for current rank + is_ep=True, + ep_rank=rank, + num_experts_per_rank=num_experts_per_rank, + ).view(ep_size, max_num_tokens, hidden_states_recv.shape[-1]) # Combine combined_output = moe_a2a.combine( From 3a84b49237ed3ddcd1e5111dc69589dfb9dc1408 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 20 Nov 2025 12:39:18 +1300 Subject: [PATCH 06/25] Clear MOE workspace before each run --- tests/comm/test_mnnvl_a2a.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index ad6ce87c75..c37965ec05 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -243,6 +243,7 @@ def run_moe_a2a_dispatch_single_rank( # Create MoeAlltoAll manager max_num_tokens = max(all_num_tokens) + MoeAlltoAll._WORKSPACE = None moe_a2a = MoeAlltoAll( mapping, max_num_tokens, @@ -685,6 +686,7 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): torch.cuda.synchronize() # Initialize MoeAlltoAll + MoeAlltoAll._WORKSPACE = None moe_a2a = MoeAlltoAll( mapping=mapping, max_num_tokens=max_num_tokens, @@ -714,26 +716,27 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): moe_output.zero_() # Process each rank's tokens with local experts - print( - f"hidden_states_recv.shape: {hidden_states_recv.shape}, token_selected_experts_recv.shape: {token_selected_experts_recv.shape}, token_final_scales_recv.shape: {token_final_scales_recv.shape}, rank_experts.shape: {rank_experts.shape}, moe_output.shape: {moe_output.shape}" + moe_output.copy_( + fake_moe( + hidden_states_recv.view( + ep_size * max_num_tokens, hidden_states_recv.shape[-1] + ), + token_selected_experts_recv.view( + ep_size * max_num_tokens, token_selected_experts_recv.shape[-1] + ), + token_final_scales_recv.view( + ep_size * max_num_tokens, token_final_scales_recv.shape[-1] + ), + rank_experts, # experts for current rank + is_ep=True, + ep_rank=rank, + num_experts_per_rank=num_experts_per_rank, + ).view(ep_size, max_num_tokens, hidden_size) ) - moe_output[rank] = fake_moe( - hidden_states_recv.view(ep_size * max_num_tokens, hidden_states_recv.shape[-1]), - token_selected_experts_recv.view( - ep_size * max_num_tokens, token_selected_experts_recv.shape[-1] - ), - token_final_scales_recv.view( - ep_size * max_num_tokens, token_final_scales_recv.shape[-1] - ), - rank_experts, # experts 
for current rank - is_ep=True, - ep_rank=rank, - num_experts_per_rank=num_experts_per_rank, - ).view(ep_size, max_num_tokens, hidden_states_recv.shape[-1]) # Combine combined_output = moe_a2a.combine( - payload=moe_output.view(ep_size, max_num_tokens, hidden_size), + payload=moe_output, runtime_max_tokens_per_rank=max_num_tokens, payload_in_workspace=True, ) From f894bce28620362d2f3477269e86e7813810a449 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:20:38 +1300 Subject: [PATCH 07/25] Cleanup MPI processes on test failures --- tests/comm/test_mnnvl_a2a.py | 207 ++++++++++++++++++++--------------- 1 file changed, 116 insertions(+), 91 deletions(-) diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index c37965ec05..5628c4ed1d 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -562,8 +562,13 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): ) except Exception as e: traceback.print_exc() + comm.allgather(e) raise e + exceptions = comm.allgather(None) + if any(exceptions): + raise filter(lambda x: x is not None, exceptions)[0] + # Gather results from all ranks all_results = comm.allgather(result) @@ -638,111 +643,131 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): num_experts_per_rank = 8 workspace_size_per_rank = 512 * 1024 * 1024 - mapping = Mapping( - rank=rank, - moe_ep_size=world_size, - tp_size=world_size, - world_size=world_size, - ) + try: + mapping = Mapping( + rank=rank, + moe_ep_size=world_size, + tp_size=world_size, + world_size=world_size, + ) - local_num_tokens = all_num_tokens[rank] - max_num_tokens = max(all_num_tokens) + local_num_tokens = all_num_tokens[rank] + max_num_tokens = max(all_num_tokens) - # Generate inputs - token_selected_experts = generate_token_selected_experts( - local_num_tokens, ep_size, num_experts_per_rank, top_k - ) + # Generate inputs + token_selected_experts = generate_token_selected_experts( + local_num_tokens, ep_size, num_experts_per_rank, top_k + ) - payloads, expert_id_payload_index = make_bfloat16_payloads( - local_num_tokens, hidden_size, top_k, rank, token_selected_experts - ) + payloads, expert_id_payload_index = make_bfloat16_payloads( + local_num_tokens, hidden_size, top_k, rank, token_selected_experts + ) - hidden_states = payloads[0] - token_final_scales = payloads[2] + hidden_states = payloads[0] + token_final_scales = payloads[2] + + # Compute reference (single-GPU MoE) + all_experts = torch.cat( + [ + create_experts( + num_experts_per_rank, hidden_size, r, "cuda", dtype=torch.bfloat16 + ) + for r in range(ep_size) + ], + dim=0, + ) - # Compute reference (single-GPU MoE) - all_experts = torch.cat( - [ - create_experts( - num_experts_per_rank, hidden_size, r, "cuda", dtype=torch.bfloat16 - ) - for r in range(ep_size) - ], - dim=0, - ) + rank_experts = create_experts( + num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16 + ) - rank_experts = create_experts( - num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16 - ) + reference_output = fake_moe( + hidden_states, + token_selected_experts, + token_final_scales, + all_experts, + is_ep=False, + ) - reference_output = fake_moe( - hidden_states, - token_selected_experts, - token_final_scales, - all_experts, - is_ep=False, - ) + torch.cuda.synchronize() - torch.cuda.synchronize() + # Initialize MoeAlltoAll + MoeAlltoAll._WORKSPACE = None + moe_a2a = MoeAlltoAll( + mapping=mapping, + max_num_tokens=max_num_tokens, + 
top_k=top_k, + num_experts=ep_size * num_experts_per_rank, + workspace_size_per_rank=workspace_size_per_rank, + ) - # Initialize MoeAlltoAll - MoeAlltoAll._WORKSPACE = None - moe_a2a = MoeAlltoAll( - mapping=mapping, - max_num_tokens=max_num_tokens, - top_k=top_k, - num_experts=ep_size * num_experts_per_rank, - workspace_size_per_rank=workspace_size_per_rank, - ) + # Dispatch + recv_tensors = moe_a2a.dispatch( + token_selected_experts=token_selected_experts, + input_payloads=payloads, + runtime_max_tokens_per_rank=max_num_tokens, + ) - # Dispatch - recv_tensors = moe_a2a.dispatch( - token_selected_experts=token_selected_experts, - input_payloads=payloads, - runtime_max_tokens_per_rank=max_num_tokens, - ) + # Unpack received tensors + hidden_states_recv = recv_tensors[0] # [ep_size, max_tokens, hidden_size] + token_selected_experts_recv = recv_tensors[1] # [ep_size, max_tokens, top_k] + token_final_scales_recv = recv_tensors[2] # [ep_size, max_tokens, top_k] - # Unpack received tensors - hidden_states_recv = recv_tensors[0] # [ep_size, max_tokens, hidden_size] - token_selected_experts_recv = recv_tensors[1] # [ep_size, max_tokens, top_k] - token_final_scales_recv = recv_tensors[2] # [ep_size, max_tokens, top_k] + # Get workspace-backed tensor for output + moe_output = moe_a2a.get_combine_payload_tensor_in_workspace( + runtime_max_tokens_per_rank=max_num_tokens, + hidden_size=hidden_size, + dtype=torch.bfloat16, + ) + moe_output.zero_() + + # Process each rank's tokens with local experts + moe_output.copy_( + fake_moe( + hidden_states_recv.view( + ep_size * max_num_tokens, hidden_states_recv.shape[-1] + ), + token_selected_experts_recv.view( + ep_size * max_num_tokens, token_selected_experts_recv.shape[-1] + ), + token_final_scales_recv.view( + ep_size * max_num_tokens, token_final_scales_recv.shape[-1] + ), + rank_experts, # experts for current rank + is_ep=True, + ep_rank=rank, + num_experts_per_rank=num_experts_per_rank, + ).view(ep_size, max_num_tokens, hidden_size) + ) + except Exception as e: + traceback.print_exc() + comm.allgather(e) + raise e - # Get workspace-backed tensor for output - moe_output = moe_a2a.get_combine_payload_tensor_in_workspace( - runtime_max_tokens_per_rank=max_num_tokens, - hidden_size=hidden_size, - dtype=torch.bfloat16, - ) - moe_output.zero_() - - # Process each rank's tokens with local experts - moe_output.copy_( - fake_moe( - hidden_states_recv.view( - ep_size * max_num_tokens, hidden_states_recv.shape[-1] - ), - token_selected_experts_recv.view( - ep_size * max_num_tokens, token_selected_experts_recv.shape[-1] - ), - token_final_scales_recv.view( - ep_size * max_num_tokens, token_final_scales_recv.shape[-1] - ), - rank_experts, # experts for current rank - is_ep=True, - ep_rank=rank, - num_experts_per_rank=num_experts_per_rank, - ).view(ep_size, max_num_tokens, hidden_size) - ) + exceptions = comm.allgather(None) + if any(exceptions): + raise filter(lambda x: x is not None, exceptions)[0] - # Combine - combined_output = moe_a2a.combine( - payload=moe_output, - runtime_max_tokens_per_rank=max_num_tokens, - payload_in_workspace=True, - ) + try: + # Combine + combined_output = moe_a2a.combine( + payload=moe_output, + runtime_max_tokens_per_rank=max_num_tokens, + payload_in_workspace=True, + ) + + # Verify against reference + torch.testing.assert_close( + combined_output, reference_output, rtol=1e-2, atol=1e-2 + ) + except Exception as e: + traceback.print_exc() + comm.allgather(e) + raise e - # Verify against reference - 
torch.testing.assert_close(combined_output, reference_output, rtol=1e-2, atol=1e-2) + exceptions = comm.allgather(None) + if any(exceptions): + raise filter(lambda x: x is not None, exceptions)[0] if __name__ == "__main__": From a7c427c0caaddfb20a671ee14ec38015fb1ebaec Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:07:25 +1300 Subject: [PATCH 08/25] More exit handling for rank failures --- tests/comm/test_mnnvl_a2a.py | 216 ++++++++++++++++++----------------- 1 file changed, 112 insertions(+), 104 deletions(-) diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index 5628c4ed1d..d7b898892f 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -226,10 +226,10 @@ def run_moe_a2a_dispatch_single_rank( ): """Worker function for MPI testing.""" comm = MPI.COMM_WORLD - rank = comm.Get_rank() - torch.cuda.set_device(rank) - try: + rank = comm.Get_rank() + torch.cuda.set_device(rank) + mapping = Mapping( rank=rank, tp_size=ep_size, @@ -262,79 +262,83 @@ def run_moe_a2a_dispatch_single_rank( payloads, expert_id_payload_index = make_nvfp4_payloads( rank_local_tokens, hidden_size, top_k, rank, token_selected_experts ) - - recv_tensors = moe_a2a.dispatch( - token_selected_experts, - payloads, - max_num_tokens, - invalid_token_expert_id=invalid_token_expert_id, - expert_id_payload_index=expert_id_payload_index, - ) - - # Read counters and compact routing tensors from workspace - send_counters_offset = moe_a2a.metainfo[ - MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"] - ].item() - recv_counters_offset = moe_a2a.metainfo[ - MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"] - ].item() - topk_target_ranks_offset = moe_a2a.metainfo[ - MoeAlltoAll._METAINFO_INDEX["TOPK_TARGET_RANKS_OFFSET_INDEX"] - ].item() - topk_send_indices_offset = moe_a2a.metainfo[ - MoeAlltoAll._METAINFO_INDEX["TOPK_SEND_INDICES_OFFSET_INDEX"] - ].item() - - send_counters = ( - moe_a2a.workspace[ - rank, send_counters_offset : send_counters_offset + ep_size * 4 - ] - .view(torch.int32) - .cpu() - ) - recv_counters = ( - moe_a2a.workspace[ - rank, recv_counters_offset : recv_counters_offset + ep_size * 4 - ] - .view(torch.int32) - .cpu() - ) - topk_target_ranks = ( - moe_a2a.workspace[ - rank, - topk_target_ranks_offset : topk_target_ranks_offset - + max_num_tokens * top_k * 4, - ] - .view(torch.int32) - .view(max_num_tokens, top_k) - .cpu() - ) - topk_send_indices = ( - moe_a2a.workspace[ - rank, - topk_send_indices_offset : topk_send_indices_offset - + max_num_tokens * top_k * 4, - ] - .view(torch.int32) - .view(max_num_tokens, top_k) - .cpu() - ) - - # Return results to be collected (move to CPU for MPI transfer) - return ( - token_selected_experts.cpu(), - [p.cpu() for p in payloads], - [rt.cpu() for rt in recv_tensors], - send_counters, - topk_send_indices, - topk_target_ranks, - recv_counters, - expert_id_payload_index, - ) except Exception: traceback.print_exc() + comm.allgather(True) raise + if any(comm.allgather(False)): + raise Exception("Another rank failed") + + recv_tensors = moe_a2a.dispatch( + token_selected_experts, + payloads, + max_num_tokens, + invalid_token_expert_id=invalid_token_expert_id, + expert_id_payload_index=expert_id_payload_index, + ) + + # Read counters and compact routing tensors from workspace + send_counters_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["SEND_COUNTERS_OFFSET_INDEX"] + ].item() + recv_counters_offset = moe_a2a.metainfo[ + 
MoeAlltoAll._METAINFO_INDEX["RECV_COUNTERS_OFFSET_INDEX"] + ].item() + topk_target_ranks_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["TOPK_TARGET_RANKS_OFFSET_INDEX"] + ].item() + topk_send_indices_offset = moe_a2a.metainfo[ + MoeAlltoAll._METAINFO_INDEX["TOPK_SEND_INDICES_OFFSET_INDEX"] + ].item() + + send_counters = ( + moe_a2a.workspace[ + rank, send_counters_offset : send_counters_offset + ep_size * 4 + ] + .view(torch.int32) + .cpu() + ) + recv_counters = ( + moe_a2a.workspace[ + rank, recv_counters_offset : recv_counters_offset + ep_size * 4 + ] + .view(torch.int32) + .cpu() + ) + topk_target_ranks = ( + moe_a2a.workspace[ + rank, + topk_target_ranks_offset : topk_target_ranks_offset + + max_num_tokens * top_k * 4, + ] + .view(torch.int32) + .view(max_num_tokens, top_k) + .cpu() + ) + topk_send_indices = ( + moe_a2a.workspace[ + rank, + topk_send_indices_offset : topk_send_indices_offset + + max_num_tokens * top_k * 4, + ] + .view(torch.int32) + .view(max_num_tokens, top_k) + .cpu() + ) + + # Return results to be collected (move to CPU for MPI transfer) + return ( + token_selected_experts.cpu(), + [p.cpu() for p in payloads], + [rt.cpu() for rt in recv_tensors], + send_counters, + topk_send_indices, + topk_target_ranks, + recv_counters, + expert_id_payload_index, + ) + def verify_dispatch( all_token_selected_experts, @@ -538,19 +542,19 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): pytest.skip(f"Test requires exactly {ep_size} ranks") try: - MnnvlMemory.initialize() - if not MnnvlMemory.supports_mnnvl(): + try: + MnnvlMemory.initialize() + if not MnnvlMemory.supports_mnnvl(): + pytest.skip("MNNVL not supported on this system") + except Exception: pytest.skip("MNNVL not supported on this system") - except Exception: - pytest.skip("MNNVL not supported on this system") - hidden_size = 1024 - num_experts_per_rank = 8 - workspace_size_per_rank = 512 * 1024 * 1024 - invalid_token_expert_id = -1 + hidden_size = 1024 + num_experts_per_rank = 8 + workspace_size_per_rank = 512 * 1024 * 1024 + invalid_token_expert_id = -1 - # Run dispatch on this rank - try: + # Run dispatch on this rank result = run_moe_a2a_dispatch_single_rank( ep_size, all_num_tokens, @@ -562,12 +566,11 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): ) except Exception as e: traceback.print_exc() - comm.allgather(e) + comm.allgather(True) raise e - exceptions = comm.allgather(None) - if any(exceptions): - raise filter(lambda x: x is not None, exceptions)[0] + if any(comm.allgather(False)): + raise Exception("Another rank failed") # Gather results from all ranks all_results = comm.allgather(result) @@ -631,19 +634,18 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): pytest.skip(f"Test requires exactly {ep_size} ranks") try: - MnnvlMemory.initialize() - if not MnnvlMemory.supports_mnnvl(): + try: + MnnvlMemory.initialize() + if not MnnvlMemory.supports_mnnvl(): + pytest.skip("MNNVL not supported on this system") + except Exception: pytest.skip("MNNVL not supported on this system") - except Exception: - pytest.skip("MNNVL not supported on this system") - - torch.cuda.set_device(rank) - hidden_size = 2880 # gpt-oss - num_experts_per_rank = 8 - workspace_size_per_rank = 512 * 1024 * 1024 + torch.cuda.set_device(rank) - try: + hidden_size = 2880 # gpt-oss + num_experts_per_rank = 8 + workspace_size_per_rank = 512 * 1024 * 1024 mapping = Mapping( rank=rank, moe_ep_size=world_size, @@ -700,7 +702,15 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, 
top_k): num_experts=ep_size * num_experts_per_rank, workspace_size_per_rank=workspace_size_per_rank, ) + except Exception as e: + traceback.print_exc() + comm.allgather(True) + raise e + if any(comm.allgather(False)): + raise Exception("Another rank failed") + + try: # Dispatch recv_tensors = moe_a2a.dispatch( token_selected_experts=token_selected_experts, @@ -741,12 +751,11 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): ) except Exception as e: traceback.print_exc() - comm.allgather(e) + comm.allgather(True) raise e - exceptions = comm.allgather(None) - if any(exceptions): - raise filter(lambda x: x is not None, exceptions)[0] + if any(comm.allgather(False)): + raise Exception("Another rank failed") try: # Combine @@ -762,12 +771,11 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): ) except Exception as e: traceback.print_exc() - comm.allgather(e) + comm.allgather(True) raise e - exceptions = comm.allgather(None) - if any(exceptions): - raise filter(lambda x: x is not None, exceptions)[0] + if any(comm.allgather(False)): + raise Exception("Another rank failed") if __name__ == "__main__": From fb2b9b2555b034f401b21944595d17abcb388311 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 20 Nov 2025 17:36:44 +1300 Subject: [PATCH 09/25] Cleaner test implementation --- tests/comm/test_mnnvl_a2a.py | 424 ++++++++++++++++++----------------- 1 file changed, 214 insertions(+), 210 deletions(-) diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index d7b898892f..8a1e84c94e 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -24,6 +24,28 @@ from flashinfer.comm.mnnvl import MnnvlMemory +class MPIExit(Exception): + pass + + +def check_any_rank_failed(): + comm = MPI.COMM_WORLD + if any(comm.allgather(False)): + raise MPIExit("Another rank failed") + + +def safe_run(func, *args, **kwargs): + comm = MPI.COMM_WORLD + try: + func(*args, **kwargs) + except MPIExit as e: + raise e + except Exception as e: + traceback.print_exc() + comm.allgather(True) + raise e + + @pytest.fixture(autouse=True) def setup_test(): torch.manual_seed(0x1234) @@ -226,49 +248,45 @@ def run_moe_a2a_dispatch_single_rank( ): """Worker function for MPI testing.""" comm = MPI.COMM_WORLD - try: - rank = comm.Get_rank() - torch.cuda.set_device(rank) - - mapping = Mapping( - rank=rank, - tp_size=ep_size, - moe_ep_size=ep_size, - world_size=ep_size, - gpus_per_node=ep_size, - pp_size=1, - cp_size=1, - ) + rank = comm.Get_rank() + torch.cuda.set_device(rank) + + check_any_rank_failed() + + mapping = Mapping( + rank=rank, + tp_size=ep_size, + moe_ep_size=ep_size, + world_size=ep_size, + gpus_per_node=ep_size, + pp_size=1, + cp_size=1, + ) - # Create MoeAlltoAll manager - max_num_tokens = max(all_num_tokens) + # Create MoeAlltoAll manager + max_num_tokens = max(all_num_tokens) - MoeAlltoAll._WORKSPACE = None - moe_a2a = MoeAlltoAll( - mapping, - max_num_tokens, - top_k, - ep_size * num_experts_per_rank, - workspace_size_per_rank, - ) + MoeAlltoAll._WORKSPACE = None + moe_a2a = MoeAlltoAll( + mapping, + max_num_tokens, + top_k, + ep_size * num_experts_per_rank, + workspace_size_per_rank, + ) - # Get the number of tokens for this specific rank (same as single-GPU) - rank_local_tokens = all_num_tokens[rank] + # Get the number of tokens for this specific rank (same as single-GPU) + rank_local_tokens = all_num_tokens[rank] - # Generate data using helper functions - token_selected_experts = 
generate_token_selected_experts( - rank_local_tokens, ep_size, num_experts_per_rank, top_k - ) - payloads, expert_id_payload_index = make_nvfp4_payloads( - rank_local_tokens, hidden_size, top_k, rank, token_selected_experts - ) - except Exception: - traceback.print_exc() - comm.allgather(True) - raise + # Generate data using helper functions + token_selected_experts = generate_token_selected_experts( + rank_local_tokens, ep_size, num_experts_per_rank, top_k + ) + payloads, expert_id_payload_index = make_nvfp4_payloads( + rank_local_tokens, hidden_size, top_k, rank, token_selected_experts + ) - if any(comm.allgather(False)): - raise Exception("Another rank failed") + check_any_rank_failed() recv_tensors = moe_a2a.dispatch( token_selected_experts, @@ -512,22 +530,7 @@ def verify_dispatch( assert torch.all(token_expert_ids == invalid_token_expert_id) -@pytest.mark.parametrize( - "ep_size,all_num_tokens,top_k", - [ - # Basic configurations - (4, [32, 32, 32, 32], 2), # Four ranks with uniform distribution - (4, [16, 32, 64, 48], 2), # Four ranks with non-uniform distribution - (2, [100, 50], 2), # Two ranks with different loads - (8, [10, 20, 30, 40, 50, 60, 70, 80], 2), # Eight ranks with increasing load - # Different top_k values - (4, [32, 32, 32, 32], 4), # Four ranks with top_k = 4 - (4, [32, 32, 32, 32], 8), # Four ranks with top_k = 8 - # Edge cases - (4, [1, 1, 1, 1], 2), # Four ranks with single token per rank - ], -) -def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): +def test_moe_a2a_dispatch_impl(ep_size, all_num_tokens, top_k): """Test MoE A2A dispatch operation.""" if len(all_num_tokens) != ep_size: pytest.skip( @@ -542,35 +545,31 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): pytest.skip(f"Test requires exactly {ep_size} ranks") try: - try: - MnnvlMemory.initialize() - if not MnnvlMemory.supports_mnnvl(): - pytest.skip("MNNVL not supported on this system") - except Exception: + MnnvlMemory.initialize() + if not MnnvlMemory.supports_mnnvl(): pytest.skip("MNNVL not supported on this system") + except Exception: + pytest.skip("MNNVL not supported on this system") - hidden_size = 1024 - num_experts_per_rank = 8 - workspace_size_per_rank = 512 * 1024 * 1024 - invalid_token_expert_id = -1 - - # Run dispatch on this rank - result = run_moe_a2a_dispatch_single_rank( - ep_size, - all_num_tokens, - top_k, - workspace_size_per_rank, - num_experts_per_rank, - hidden_size, - invalid_token_expert_id, - ) - except Exception as e: - traceback.print_exc() - comm.allgather(True) - raise e + hidden_size = 1024 + num_experts_per_rank = 8 + workspace_size_per_rank = 512 * 1024 * 1024 + invalid_token_expert_id = -1 - if any(comm.allgather(False)): - raise Exception("Another rank failed") + check_any_rank_failed() + + # Run dispatch on this rank + result = run_moe_a2a_dispatch_single_rank( + ep_size, + all_num_tokens, + top_k, + workspace_size_per_rank, + num_experts_per_rank, + hidden_size, + invalid_token_expert_id, + ) + + check_any_rank_failed() # Gather results from all ranks all_results = comm.allgather(result) @@ -611,15 +610,24 @@ def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): @pytest.mark.parametrize( "ep_size,all_num_tokens,top_k", [ - (4, [32, 32, 32, 32], 2), - (4, [16, 32, 64, 48], 2), - (2, [100, 50], 2), - (4, [32, 32, 32, 32], 4), - (4, [1, 1, 1, 1], 2), - (8, [640, 640, 640, 640, 640, 640, 640, 640], 4), + # Basic configurations + (4, [32, 32, 32, 32], 2), # Four ranks with uniform distribution + (4, [16, 32, 64, 48], 2), # Four ranks with 
non-uniform distribution + (2, [100, 50], 2), # Two ranks with different loads + (8, [10, 20, 30, 40, 50, 60, 70, 80], 2), # Eight ranks with increasing load + # Different top_k values + (4, [32, 32, 32, 32], 4), # Four ranks with top_k = 4 + (4, [32, 32, 32, 32], 8), # Four ranks with top_k = 8 + # Edge cases + (4, [1, 1, 1, 1], 2), # Four ranks with single token per rank ], ) -def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): +def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): + """Test MoE A2A dispatch operation.""" + safe_run(test_moe_a2a_dispatch_impl, ep_size, all_num_tokens, top_k) + + +def test_moe_a2a_dispatch_moe_combine_impl(ep_size, all_num_tokens, top_k): """Test full MoE A2A dispatch + expert processing + combine cycle.""" if len(all_num_tokens) != ep_size: pytest.skip( @@ -634,148 +642,144 @@ def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): pytest.skip(f"Test requires exactly {ep_size} ranks") try: - try: - MnnvlMemory.initialize() - if not MnnvlMemory.supports_mnnvl(): - pytest.skip("MNNVL not supported on this system") - except Exception: + MnnvlMemory.initialize() + if not MnnvlMemory.supports_mnnvl(): pytest.skip("MNNVL not supported on this system") + except Exception: + pytest.skip("MNNVL not supported on this system") - torch.cuda.set_device(rank) + torch.cuda.set_device(rank) - hidden_size = 2880 # gpt-oss - num_experts_per_rank = 8 - workspace_size_per_rank = 512 * 1024 * 1024 - mapping = Mapping( - rank=rank, - moe_ep_size=world_size, - tp_size=world_size, - world_size=world_size, - ) + check_any_rank_failed() - local_num_tokens = all_num_tokens[rank] - max_num_tokens = max(all_num_tokens) + hidden_size = 2880 # gpt-oss + num_experts_per_rank = 8 + workspace_size_per_rank = 512 * 1024 * 1024 + mapping = Mapping( + rank=rank, + moe_ep_size=world_size, + tp_size=world_size, + world_size=world_size, + ) - # Generate inputs - token_selected_experts = generate_token_selected_experts( - local_num_tokens, ep_size, num_experts_per_rank, top_k - ) + local_num_tokens = all_num_tokens[rank] + max_num_tokens = max(all_num_tokens) - payloads, expert_id_payload_index = make_bfloat16_payloads( - local_num_tokens, hidden_size, top_k, rank, token_selected_experts - ) + # Generate inputs + token_selected_experts = generate_token_selected_experts( + local_num_tokens, ep_size, num_experts_per_rank, top_k + ) - hidden_states = payloads[0] - token_final_scales = payloads[2] - - # Compute reference (single-GPU MoE) - all_experts = torch.cat( - [ - create_experts( - num_experts_per_rank, hidden_size, r, "cuda", dtype=torch.bfloat16 - ) - for r in range(ep_size) - ], - dim=0, - ) + payloads, expert_id_payload_index = make_bfloat16_payloads( + local_num_tokens, hidden_size, top_k, rank, token_selected_experts + ) - rank_experts = create_experts( - num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16 - ) + hidden_states = payloads[0] + token_final_scales = payloads[2] - reference_output = fake_moe( - hidden_states, - token_selected_experts, - token_final_scales, - all_experts, - is_ep=False, - ) + # Compute reference (single-GPU MoE) + all_experts = torch.cat( + [ + create_experts( + num_experts_per_rank, hidden_size, r, "cuda", dtype=torch.bfloat16 + ) + for r in range(ep_size) + ], + dim=0, + ) - torch.cuda.synchronize() + rank_experts = create_experts( + num_experts_per_rank, hidden_size, rank, "cuda", dtype=torch.bfloat16 + ) - # Initialize MoeAlltoAll - MoeAlltoAll._WORKSPACE = None - moe_a2a = MoeAlltoAll( - 
mapping=mapping, - max_num_tokens=max_num_tokens, - top_k=top_k, - num_experts=ep_size * num_experts_per_rank, - workspace_size_per_rank=workspace_size_per_rank, - ) - except Exception as e: - traceback.print_exc() - comm.allgather(True) - raise e + reference_output = fake_moe( + hidden_states, + token_selected_experts, + token_final_scales, + all_experts, + is_ep=False, + ) - if any(comm.allgather(False)): - raise Exception("Another rank failed") + # Initialize MoeAlltoAll + MoeAlltoAll._WORKSPACE = None + moe_a2a = MoeAlltoAll( + mapping=mapping, + max_num_tokens=max_num_tokens, + top_k=top_k, + num_experts=ep_size * num_experts_per_rank, + workspace_size_per_rank=workspace_size_per_rank, + ) - try: - # Dispatch - recv_tensors = moe_a2a.dispatch( - token_selected_experts=token_selected_experts, - input_payloads=payloads, - runtime_max_tokens_per_rank=max_num_tokens, - ) + check_any_rank_failed() - # Unpack received tensors - hidden_states_recv = recv_tensors[0] # [ep_size, max_tokens, hidden_size] - token_selected_experts_recv = recv_tensors[1] # [ep_size, max_tokens, top_k] - token_final_scales_recv = recv_tensors[2] # [ep_size, max_tokens, top_k] + # Dispatch + recv_tensors = moe_a2a.dispatch( + token_selected_experts=token_selected_experts, + input_payloads=payloads, + runtime_max_tokens_per_rank=max_num_tokens, + ) - # Get workspace-backed tensor for output - moe_output = moe_a2a.get_combine_payload_tensor_in_workspace( - runtime_max_tokens_per_rank=max_num_tokens, - hidden_size=hidden_size, - dtype=torch.bfloat16, - ) - moe_output.zero_() - - # Process each rank's tokens with local experts - moe_output.copy_( - fake_moe( - hidden_states_recv.view( - ep_size * max_num_tokens, hidden_states_recv.shape[-1] - ), - token_selected_experts_recv.view( - ep_size * max_num_tokens, token_selected_experts_recv.shape[-1] - ), - token_final_scales_recv.view( - ep_size * max_num_tokens, token_final_scales_recv.shape[-1] - ), - rank_experts, # experts for current rank - is_ep=True, - ep_rank=rank, - num_experts_per_rank=num_experts_per_rank, - ).view(ep_size, max_num_tokens, hidden_size) - ) - except Exception as e: - traceback.print_exc() - comm.allgather(True) - raise e + # Unpack received tensors + hidden_states_recv = recv_tensors[0] # [ep_size, max_tokens, hidden_size] + token_selected_experts_recv = recv_tensors[1] # [ep_size, max_tokens, top_k] + token_final_scales_recv = recv_tensors[2] # [ep_size, max_tokens, top_k] - if any(comm.allgather(False)): - raise Exception("Another rank failed") + # Get workspace-backed tensor for output + moe_output = moe_a2a.get_combine_payload_tensor_in_workspace( + runtime_max_tokens_per_rank=max_num_tokens, + hidden_size=hidden_size, + dtype=torch.bfloat16, + ) + moe_output.zero_() + + # Process each rank's tokens with local experts + moe_output.copy_( + fake_moe( + hidden_states_recv.view( + ep_size * max_num_tokens, hidden_states_recv.shape[-1] + ), + token_selected_experts_recv.view( + ep_size * max_num_tokens, token_selected_experts_recv.shape[-1] + ), + token_final_scales_recv.view( + ep_size * max_num_tokens, token_final_scales_recv.shape[-1] + ), + rank_experts, # experts for current rank + is_ep=True, + ep_rank=rank, + num_experts_per_rank=num_experts_per_rank, + ).view(ep_size, max_num_tokens, hidden_size) + ) - try: - # Combine - combined_output = moe_a2a.combine( - payload=moe_output, - runtime_max_tokens_per_rank=max_num_tokens, - payload_in_workspace=True, - ) + check_any_rank_failed() - # Verify against reference - torch.testing.assert_close( - 
combined_output, reference_output, rtol=1e-2, atol=1e-2 - ) - except Exception as e: - traceback.print_exc() - comm.allgather(True) - raise e + # Combine + combined_output = moe_a2a.combine( + payload=moe_output, + runtime_max_tokens_per_rank=max_num_tokens, + payload_in_workspace=True, + ) - if any(comm.allgather(False)): - raise Exception("Another rank failed") + # Verify against reference + torch.testing.assert_close(combined_output, reference_output, rtol=1e-2, atol=1e-2) + + check_any_rank_failed() + + +@pytest.mark.parametrize( + "ep_size,all_num_tokens,top_k", + [ + (4, [32, 32, 32, 32], 2), + (4, [16, 32, 64, 48], 2), + (2, [100, 50], 2), + (4, [32, 32, 32, 32], 4), + (4, [1, 1, 1, 1], 2), + (8, [640, 640, 640, 640, 640, 640, 640, 640], 4), + ], +) +def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): + """Test full MoE A2A dispatch + expert processing + combine cycle.""" + safe_run(test_moe_a2a_dispatch_moe_combine_impl, ep_size, all_num_tokens, top_k) if __name__ == "__main__": From 7d55b4900a7f850bd422cbd2ff4246ee1105b76a Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:55:12 +1300 Subject: [PATCH 10/25] Update MNNVL config setup --- flashinfer/comm/trtllm_moe_a2a.py | 5 ++++- tests/comm/test_mnnvl_a2a.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/flashinfer/comm/trtllm_moe_a2a.py b/flashinfer/comm/trtllm_moe_a2a.py index 58f0a1834a..df47a8f0fd 100644 --- a/flashinfer/comm/trtllm_moe_a2a.py +++ b/flashinfer/comm/trtllm_moe_a2a.py @@ -14,7 +14,7 @@ import torch import functools -from .mnnvl import MnnvlMemory +from .mnnvl import MnnvlMemory, MnnvlConfig from .mapping import Mapping from ..jit.comm import gen_mnnvl_a2a_module from ..utils import register_custom_op @@ -346,6 +346,7 @@ def __init__( top_k: int, num_experts: int, workspace_size_per_rank: int = 512 * 1024 * 1024, + mnnvl_config: Optional[MnnvlConfig] = None, ): """ Initialize MoeAlltoAll with workspace allocation. 
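# --- Editor's note: illustrative usage sketch, not part of the patch itself ---
# The new optional `mnnvl_config` argument lets callers hand MoeAlltoAll a pre-built
# MnnvlConfig instead of relying on the default communicator created by
# MnnvlMemory.initialize(). The MnnvlConfig construction and Mapping fields below are
# assumptions for illustration only; check flashinfer/comm/mnnvl.py and
# flashinfer/comm/mapping.py for the actual signatures.
from flashinfer.comm import MoeAlltoAll
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig


def build_moe_a2a(rank: int, world_size: int) -> MoeAlltoAll:
    mapping = Mapping(rank=rank, world_size=world_size)
    config = MnnvlConfig()  # assumed default-constructible; adjust for your fabric setup
    return MoeAlltoAll(
        mapping=mapping,
        max_num_tokens=256,
        top_k=2,
        num_experts=64,
        mnnvl_config=config,  # forwarded to MnnvlMemory.set_comm_from_config(mapping, config)
    )
# --- end editor's note ---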
@@ -362,6 +363,8 @@ def __init__( # Initialize MNNVL memory system MnnvlMemory.initialize() + if mnnvl_config: + MnnvlMemory.set_comm_from_config(mapping, mnnvl_config) # type: ignore[attr-defined] self.workspace_size_per_rank = workspace_size_per_rank self.max_num_tokens = max_num_tokens diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index 8a1e84c94e..904d7d9596 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -530,7 +530,7 @@ def verify_dispatch( assert torch.all(token_expert_ids == invalid_token_expert_id) -def test_moe_a2a_dispatch_impl(ep_size, all_num_tokens, top_k): +def moe_a2a_dispatch_test_impl(ep_size, all_num_tokens, top_k): """Test MoE A2A dispatch operation.""" if len(all_num_tokens) != ep_size: pytest.skip( @@ -624,10 +624,10 @@ def test_moe_a2a_dispatch_impl(ep_size, all_num_tokens, top_k): ) def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): """Test MoE A2A dispatch operation.""" - safe_run(test_moe_a2a_dispatch_impl, ep_size, all_num_tokens, top_k) + safe_run(moe_a2a_dispatch_test_impl, ep_size, all_num_tokens, top_k) -def test_moe_a2a_dispatch_moe_combine_impl(ep_size, all_num_tokens, top_k): +def moe_a2a_dispatch_moe_combine_test_impl(ep_size, all_num_tokens, top_k): """Test full MoE A2A dispatch + expert processing + combine cycle.""" if len(all_num_tokens) != ep_size: pytest.skip( @@ -779,7 +779,7 @@ def test_moe_a2a_dispatch_moe_combine_impl(ep_size, all_num_tokens, top_k): ) def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): """Test full MoE A2A dispatch + expert processing + combine cycle.""" - safe_run(test_moe_a2a_dispatch_moe_combine_impl, ep_size, all_num_tokens, top_k) + safe_run(moe_a2a_dispatch_moe_combine_test_impl, ep_size, all_num_tokens, top_k) if __name__ == "__main__": From a033a94303e5e601c7d1d00b497c5db3b57172fe Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 26 Nov 2025 14:24:57 +1300 Subject: [PATCH 11/25] Update test to get ep size from MPI --- scripts/task_test_multi_node_comm_kernels.sh | 1 + tests/comm/test_mnnvl_a2a.py | 89 +++++++++++--------- tests/comm/test_mnnvl_memory.py | 2 +- 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/scripts/task_test_multi_node_comm_kernels.sh b/scripts/task_test_multi_node_comm_kernels.sh index f1dcedc93b..7a60f463a5 100644 --- a/scripts/task_test_multi_node_comm_kernels.sh +++ b/scripts/task_test_multi_node_comm_kernels.sh @@ -16,3 +16,4 @@ pip install -e . -v pytest -s tests/comm/test_mnnvl_memory.py pytest -s tests/comm/test_trtllm_mnnvl_allreduce.py +pytest -s tests/comm/test_mnnvl_a2a.py diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index 904d7d9596..01e3141b45 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -14,6 +14,7 @@ # limitations under the License. 
import traceback +import random import pytest import torch @@ -249,7 +250,11 @@ def run_moe_a2a_dispatch_single_rank( """Worker function for MPI testing.""" comm = MPI.COMM_WORLD rank = comm.Get_rank() - torch.cuda.set_device(rank) + + # get local rank + node_comm = comm.Split_type(MPI.COMM_TYPE_SHARED) + node_local_rank = node_comm.Get_rank() + torch.cuda.set_device(node_local_rank) check_any_rank_failed() @@ -530,19 +535,19 @@ def verify_dispatch( assert torch.all(token_expert_ids == invalid_token_expert_id) -def moe_a2a_dispatch_test_impl(ep_size, all_num_tokens, top_k): +def moe_a2a_dispatch_test_impl(distribution, top_k): """Test MoE A2A dispatch operation.""" - if len(all_num_tokens) != ep_size: - pytest.skip( - f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" - ) - comm = MPI.COMM_WORLD - # rank = comm.Get_rank() world_size = comm.Get_size() + ep_size = world_size - if world_size != ep_size: - pytest.skip(f"Test requires exactly {ep_size} ranks") + if distribution == "random": + random.seed(0xD5) + all_num_tokens = [random.randint(1, 100) for _ in range(world_size)] + elif distribution == "uniform": + all_num_tokens = [50] * world_size + else: + pytest.skip(f"Invalid distribution: {distribution}") try: MnnvlMemory.initialize() @@ -608,38 +613,37 @@ def moe_a2a_dispatch_test_impl(ep_size, all_num_tokens, top_k): @pytest.mark.parametrize( - "ep_size,all_num_tokens,top_k", + "distribution,top_k", [ - # Basic configurations - (4, [32, 32, 32, 32], 2), # Four ranks with uniform distribution - (4, [16, 32, 64, 48], 2), # Four ranks with non-uniform distribution - (2, [100, 50], 2), # Two ranks with different loads - (8, [10, 20, 30, 40, 50, 60, 70, 80], 2), # Eight ranks with increasing load - # Different top_k values - (4, [32, 32, 32, 32], 4), # Four ranks with top_k = 4 - (4, [32, 32, 32, 32], 8), # Four ranks with top_k = 8 - # Edge cases - (4, [1, 1, 1, 1], 2), # Four ranks with single token per rank + ("random", 1), # topk=1 with random distribution + ("uniform", 1), # topk=1 with uniform distribution + ("random", 2), # topk=2 with random distribution + ("uniform", 2), # topk=2 with uniform distribution + ("random", 8), # topk=8 with random distribution + ("uniform", 8), # topk=8 with uniform distribution + ("random", 64), # topk=64 with random distribution + ("uniform", 64), # topk=64 with uniform distribution ], ) -def test_moe_a2a_dispatch(ep_size, all_num_tokens, top_k): +def test_moe_a2a_dispatch(distribution, top_k): """Test MoE A2A dispatch operation.""" - safe_run(moe_a2a_dispatch_test_impl, ep_size, all_num_tokens, top_k) + safe_run(moe_a2a_dispatch_test_impl, distribution, top_k) -def moe_a2a_dispatch_moe_combine_test_impl(ep_size, all_num_tokens, top_k): +def moe_a2a_dispatch_moe_combine_test_impl(distribution, top_k): """Test full MoE A2A dispatch + expert processing + combine cycle.""" - if len(all_num_tokens) != ep_size: - pytest.skip( - f"all_num_tokens length {len(all_num_tokens)} must match ep_size {ep_size}" - ) comm = MPI.COMM_WORLD rank = comm.Get_rank() world_size = comm.Get_size() + ep_size = world_size - if world_size != ep_size: - pytest.skip(f"Test requires exactly {ep_size} ranks") + if distribution == "random": + all_num_tokens = [random.randint(1, 100) for _ in range(world_size)] + elif distribution == "uniform": + all_num_tokens = [50] * world_size + else: + pytest.skip(f"Invalid distribution: {distribution}") try: MnnvlMemory.initialize() @@ -648,7 +652,10 @@ def moe_a2a_dispatch_moe_combine_test_impl(ep_size, all_num_tokens, 
top_k): except Exception: pytest.skip("MNNVL not supported on this system") - torch.cuda.set_device(rank) + # get local rank + node_comm = comm.Split_type(MPI.COMM_TYPE_SHARED) + node_local_rank = node_comm.Get_rank() + torch.cuda.set_device(node_local_rank) check_any_rank_failed() @@ -767,19 +774,21 @@ def moe_a2a_dispatch_moe_combine_test_impl(ep_size, all_num_tokens, top_k): @pytest.mark.parametrize( - "ep_size,all_num_tokens,top_k", + "distribution,top_k", [ - (4, [32, 32, 32, 32], 2), - (4, [16, 32, 64, 48], 2), - (2, [100, 50], 2), - (4, [32, 32, 32, 32], 4), - (4, [1, 1, 1, 1], 2), - (8, [640, 640, 640, 640, 640, 640, 640, 640], 4), + ("random", 1), # topk=1 with random distribution + ("uniform", 1), # topk=1 with uniform distribution + ("random", 2), # topk=2 with random distribution + ("uniform", 2), # topk=2 with uniform distribution + ("random", 8), # topk=8 with random distribution + ("uniform", 8), # topk=8 with uniform distribution + ("random", 64), # topk=64 with random distribution + ("uniform", 64), # topk=64 with uniform distribution ], ) -def test_moe_a2a_dispatch_moe_combine(ep_size, all_num_tokens, top_k): +def test_moe_a2a_dispatch_moe_combine(distribution, top_k): """Test full MoE A2A dispatch + expert processing + combine cycle.""" - safe_run(moe_a2a_dispatch_moe_combine_test_impl, ep_size, all_num_tokens, top_k) + safe_run(moe_a2a_dispatch_moe_combine_test_impl, distribution, top_k) if __name__ == "__main__": diff --git a/tests/comm/test_mnnvl_memory.py b/tests/comm/test_mnnvl_memory.py index bbda852f06..06e0d7fe8a 100644 --- a/tests/comm/test_mnnvl_memory.py +++ b/tests/comm/test_mnnvl_memory.py @@ -122,7 +122,7 @@ def test_mnnvl_memory(self): reason="Mnnvl memory is not supported on this platform", ) def test_moe_alltoall_multi_rank_single_gpu(self): - torch.cuda.set_device(self.rank) + torch.cuda.set_device(self.local_rank) max_world_size = 8 assert self.world_size <= max_world_size, ( f"should run with world_size at most {max_world_size}" From e879354ea72d626dccd278239f4f0d6bf6c67009 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 26 Nov 2025 17:09:10 +1300 Subject: [PATCH 12/25] Update tests with better test bounds --- .../moeAlltoAllKernels.cu | 30 +++++++- tests/comm/test_mnnvl_a2a.py | 76 ++++++++++++++----- 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index e46c0e9e63..a448a09763 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -29,6 +29,9 @@ namespace tensorrt_llm::kernels::mnnvl_throughput { #define ENABLE_DEBUG_PRINT 0 #define DISABLE_SYNC_FOR_PROFILING 0 +#ifndef DISABLE_TIMEOUT +#define DISABLE_TIMEOUT 0 +#endif // Helper function for ceiling division template @@ -104,6 +107,13 @@ __host__ __device__ inline T ceilDiv(T m, T n) { __VA_ARGS__ \ } +#if DISABLE_TIMEOUT +#define check_timeout(s) false +#else +// 300 * 2000 MHz - should be high enough on any GPU but will prevent a hang +#define check_timeout(s) ((clock64() - (s)) > (300ll * 2000ll * 1000ll * 1000ll)) +#endif + // ============================================================================ // Helper Functions for Expert-to-Rank Mapping // ============================================================================ @@ -393,6 +403,7 @@ 
__global__ void moeA2ADispatchKernel( #pragma unroll 1 // No unroll for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { bool flag_set = false; + [[maybe_unused]] clock_t s = clock64(); do { uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; uint32_t flag_value; @@ -406,7 +417,14 @@ __global__ void moeA2ADispatchKernel( rank_id, peer_rank, flag_value, expected_value, flag_ptr); #endif flag_set = flag_value == expected_value; - } while (!flag_set); + } while (!flag_set || check_timeout(s)); + + if (__builtin_expect(!flag_set, 0)) { + printf("dispatch: ---Rank %d timed out waiting for completion flag from rank %d\n", + rank_id, peer_rank); + asm volatile("trap;"); + return; + } } // asm volatile("fence.acquire.sys;"); #endif @@ -690,6 +708,7 @@ __global__ void moeA2ACombineKernel( #pragma unroll 1 // No unroll for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { bool flag_set = false; + [[maybe_unused]] clock_t s = clock64(); do { uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; uint32_t flag_value; @@ -703,7 +722,14 @@ __global__ void moeA2ACombineKernel( rank_id, peer_rank, flag_value, expected_value, flag_ptr); #endif flag_set = flag_value == expected_value; - } while (!flag_set); + } while (!flag_set || check_timeout(s)); + + if (__builtin_expect(!flag_set, 0)) { + printf("combine: ---Rank %d timed out waiting for completion flag from rank %d\n", rank_id, + peer_rank); + asm volatile("trap;"); + return; + } } asm volatile("fence.acquire.sys;"); } diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_a2a.py index 01e3141b45..a006844bee 100644 --- a/tests/comm/test_mnnvl_a2a.py +++ b/tests/comm/test_mnnvl_a2a.py @@ -14,7 +14,6 @@ # limitations under the License. 
import traceback -import random import pytest import torch @@ -71,12 +70,20 @@ def generate_token_selected_experts( local_num_tokens: int, ep_size: int, num_experts_per_rank: int, top_k: int ) -> torch.Tensor: """Generate global expert IDs tensor, aligned with single-GPU test semantics.""" - return torch.randint( - 0, - ep_size * num_experts_per_rank, - (local_num_tokens, top_k), - dtype=torch.int32, - device="cuda", + if local_num_tokens == 0: + return torch.empty(0, top_k, dtype=torch.int32, device="cuda") + + # Select topk random experts for each token + def select_experts(items, topk): + perm = torch.randperm(items, dtype=torch.int32, device="cuda") + return perm[:topk] + + return torch.stack( + [ + select_experts(ep_size * num_experts_per_rank, top_k) + for _ in range(local_num_tokens) + ], + dim=0, ) @@ -95,6 +102,11 @@ def create_experts( Returns: experts: Tensor of shape [num_experts_per_rank, hidden_size, hidden_size] """ + + # A simpler to debug initialization + # identity = torch.eye(hidden_size, dtype=dtype, device=device) + # return torch.stack([identity * (i + 1) for i in range(num_experts_per_rank)], dim=0) + # For reproducibility, set the seed based on rank experts = torch.empty( (num_experts_per_rank, hidden_size, hidden_size), dtype=dtype, device=device @@ -141,6 +153,7 @@ def fake_moe( # Process each token for token_idx in range(num_tokens): + results = [] # For each expert selected for this token/ for k in range(top_k): expert_id = token_selected_experts[token_idx, k].item() @@ -157,7 +170,13 @@ def fake_moe( expert = experts[expert_id] scale = token_final_scales[token_idx, k] - processed_states[token_idx] += hidden_states[token_idx] @ expert * scale + results.append(hidden_states[token_idx] @ expert * scale) + + # Summing the results after is closer to the actual implementation as we do a tree reduction. 
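+        # Accumulating the partial results in float32 and casting once at the end also
+        # avoids the extra bf16 rounding that serial in-place accumulation would introduce.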
+ if results: + processed_states[token_idx] = torch.sum( + torch.stack(results, dim=0), dim=0, dtype=torch.float32 + ).to(processed_states.dtype) return processed_states @@ -542,8 +561,8 @@ def moe_a2a_dispatch_test_impl(distribution, top_k): ep_size = world_size if distribution == "random": - random.seed(0xD5) - all_num_tokens = [random.randint(1, 100) for _ in range(world_size)] + torch.manual_seed(0xD5) + all_num_tokens = torch.randint(1, 100, (world_size,)).tolist() elif distribution == "uniform": all_num_tokens = [50] * world_size else: @@ -557,12 +576,18 @@ def moe_a2a_dispatch_test_impl(distribution, top_k): pytest.skip("MNNVL not supported on this system") hidden_size = 1024 - num_experts_per_rank = 8 + num_experts_per_rank = max(8, (top_k + ep_size - 1) // ep_size) workspace_size_per_rank = 512 * 1024 * 1024 invalid_token_expert_id = -1 check_any_rank_failed() + # Check all ranks have the same all_num_tokens + gathered_all_num_tokens = comm.allgather(all_num_tokens) + assert all(i == all_num_tokens for i in gathered_all_num_tokens[1:]), ( + "all_num_tokens should be the same" + ) + # Run dispatch on this rank result = run_moe_a2a_dispatch_single_rank( ep_size, @@ -621,8 +646,6 @@ def moe_a2a_dispatch_test_impl(distribution, top_k): ("uniform", 2), # topk=2 with uniform distribution ("random", 8), # topk=8 with random distribution ("uniform", 8), # topk=8 with uniform distribution - ("random", 64), # topk=64 with random distribution - ("uniform", 64), # topk=64 with uniform distribution ], ) def test_moe_a2a_dispatch(distribution, top_k): @@ -639,7 +662,8 @@ def moe_a2a_dispatch_moe_combine_test_impl(distribution, top_k): ep_size = world_size if distribution == "random": - all_num_tokens = [random.randint(1, 100) for _ in range(world_size)] + torch.manual_seed(0xD5) + all_num_tokens = torch.randint(1, 100, (world_size,)).tolist() elif distribution == "uniform": all_num_tokens = [50] * world_size else: @@ -659,6 +683,12 @@ def moe_a2a_dispatch_moe_combine_test_impl(distribution, top_k): check_any_rank_failed() + # Check all ranks have the same all_num_tokens + gathered_all_num_tokens = comm.allgather(all_num_tokens) + assert all(i == all_num_tokens for i in gathered_all_num_tokens), ( + "all_num_tokens should be the same" + ) + hidden_size = 2880 # gpt-oss num_experts_per_rank = 8 workspace_size_per_rank = 512 * 1024 * 1024 @@ -768,7 +798,21 @@ def moe_a2a_dispatch_moe_combine_test_impl(distribution, top_k): ) # Verify against reference - torch.testing.assert_close(combined_output, reference_output, rtol=1e-2, atol=1e-2) + num_matches = ( + torch.isclose(combined_output, reference_output, atol=2e-2, rtol=2e-2) + .sum() + .item() + ) + match_rate = num_matches / combined_output.numel() + match_threshold = 0.99 + + # The accumulation order is not the same for the reference and the combine. For topk=8 this means that we see some accumulated errors for bf16. We tolerate up to 1% mismatches. 
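+    # torch.isclose counts an element as matching when |a - b| <= atol + rtol * |b|,
+    # so match_rate is the fraction of elements inside the 2e-2 absolute/relative band.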
+ assert match_rate >= match_threshold, ( + f"Sample match rate {match_rate:.2%} is below threshold " + f"({combined_output.numel() - num_matches}/{combined_output.numel()} mismatches, expected >={match_threshold:.2%})" + ) + + # torch.testing.assert_close(combined_output, reference_output, rtol=6e-2, atol=6e-2) check_any_rank_failed() @@ -782,8 +826,6 @@ def moe_a2a_dispatch_moe_combine_test_impl(distribution, top_k): ("uniform", 2), # topk=2 with uniform distribution ("random", 8), # topk=8 with random distribution ("uniform", 8), # topk=8 with uniform distribution - ("random", 64), # topk=64 with random distribution - ("uniform", 64), # topk=64 with uniform distribution ], ) def test_moe_a2a_dispatch_moe_combine(distribution, top_k): From c574e367d01d5fd62f82c82410f505e774d2a675 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 26 Nov 2025 17:28:20 +1300 Subject: [PATCH 13/25] Fix timeout logic --- .../kernels/communicationKernels/moeAlltoAllKernels.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index a448a09763..2751a129c6 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -403,7 +403,7 @@ __global__ void moeA2ADispatchKernel( #pragma unroll 1 // No unroll for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { bool flag_set = false; - [[maybe_unused]] clock_t s = clock64(); + [[maybe_unused]] auto s = clock64(); do { uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; uint32_t flag_value; @@ -417,7 +417,7 @@ __global__ void moeA2ADispatchKernel( rank_id, peer_rank, flag_value, expected_value, flag_ptr); #endif flag_set = flag_value == expected_value; - } while (!flag_set || check_timeout(s)); + } while (!flag_set && !check_timeout(s)); if (__builtin_expect(!flag_set, 0)) { printf("dispatch: ---Rank %d timed out waiting for completion flag from rank %d\n", @@ -708,7 +708,7 @@ __global__ void moeA2ACombineKernel( #pragma unroll 1 // No unroll for (int peer_rank = lane_id; peer_rank < ep_size; peer_rank += warpSize) { bool flag_set = false; - [[maybe_unused]] clock_t s = clock64(); + [[maybe_unused]] auto s = clock64(); do { uint32_t* flag_ptr = &ptrs.completion_flags[rank_id][peer_rank]; uint32_t flag_value; @@ -722,7 +722,7 @@ __global__ void moeA2ACombineKernel( rank_id, peer_rank, flag_value, expected_value, flag_ptr); #endif flag_set = flag_value == expected_value; - } while (!flag_set || check_timeout(s)); + } while (!flag_set && !check_timeout(s)); if (__builtin_expect(!flag_set, 0)) { printf("combine: ---Rank %d timed out waiting for completion flag from rank %d\n", rank_id, From f22e6a0371346b8a537aaecdedcb8e833e77ba73 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Wed, 26 Nov 2025 17:29:53 +1300 Subject: [PATCH 14/25] Disable python steps for MPI tests --- scripts/task_test_multi_node_comm_kernels.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/task_test_multi_node_comm_kernels.sh b/scripts/task_test_multi_node_comm_kernels.sh index 7a60f463a5..0d3f3d5c54 100644 --- a/scripts/task_test_multi_node_comm_kernels.sh +++ b/scripts/task_test_multi_node_comm_kernels.sh @@ -6,13 +6,13 @@ set -x : 
${CUDA_VISIBLE_DEVICES:=0} # Clean Python bytecode cache to avoid stale imports (e.g., after module refactoring) -echo "Cleaning Python bytecode cache..." -find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true -find . -type f -name '*.pyc' -delete 2>/dev/null || true -echo "Cache cleaned." -echo "" +# echo "Cleaning Python bytecode cache..." +# find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true +# find . -type f -name '*.pyc' -delete 2>/dev/null || true +# echo "Cache cleaned." +# echo "" -pip install -e . -v +# pip install -e . -v pytest -s tests/comm/test_mnnvl_memory.py pytest -s tests/comm/test_trtllm_mnnvl_allreduce.py From 2baac547ed5c5765567f6067d5768497c1939de9 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 27 Nov 2025 14:58:38 +1300 Subject: [PATCH 15/25] Standardise API name to match existing code better --- ...tllm_moe_a2a.py => trtllm_moe_alltoall.py} | 0 scripts/task_test_multi_node_comm_kernels.sh | 2 +- ...nnvl_a2a.py => test_mnnvl_moe_alltoall.py} | 0 tests/comm/test_trtllm_moe_alltoall.py | 852 ++++++++++++++++++ 4 files changed, 853 insertions(+), 1 deletion(-) rename flashinfer/comm/{trtllm_moe_a2a.py => trtllm_moe_alltoall.py} (100%) rename tests/comm/{test_mnnvl_a2a.py => test_mnnvl_moe_alltoall.py} (100%) create mode 100644 tests/comm/test_trtllm_moe_alltoall.py diff --git a/flashinfer/comm/trtllm_moe_a2a.py b/flashinfer/comm/trtllm_moe_alltoall.py similarity index 100% rename from flashinfer/comm/trtllm_moe_a2a.py rename to flashinfer/comm/trtllm_moe_alltoall.py diff --git a/scripts/task_test_multi_node_comm_kernels.sh b/scripts/task_test_multi_node_comm_kernels.sh index 0d3f3d5c54..0a7cad47ad 100644 --- a/scripts/task_test_multi_node_comm_kernels.sh +++ b/scripts/task_test_multi_node_comm_kernels.sh @@ -16,4 +16,4 @@ set -x pytest -s tests/comm/test_mnnvl_memory.py pytest -s tests/comm/test_trtllm_mnnvl_allreduce.py -pytest -s tests/comm/test_mnnvl_a2a.py +pytest -s tests/comm/test_mnnvl_moe_alltoall.py diff --git a/tests/comm/test_mnnvl_a2a.py b/tests/comm/test_mnnvl_moe_alltoall.py similarity index 100% rename from tests/comm/test_mnnvl_a2a.py rename to tests/comm/test_mnnvl_moe_alltoall.py diff --git a/tests/comm/test_trtllm_moe_alltoall.py b/tests/comm/test_trtllm_moe_alltoall.py new file mode 100644 index 0000000000..d03c8e1ad6 --- /dev/null +++ b/tests/comm/test_trtllm_moe_alltoall.py @@ -0,0 +1,852 @@ +# """ +# Copyright (c) 2024 by FlashInfer team. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# """ + +# import pytest +# import torch + +# from flashinfer.comm.mapping import Mapping + +# import flashinfer.comm.trtllm_moe_alltoall as tllm_moe_alltoall + + +# @pytest.fixture(autouse=True, scope="session") +# def setup_test_environment(): +# """Set up test environment and warm up JIT compilation.""" +# torch.manual_seed(0x1234) +# yield + + +# # Single GPU test parameters +# SINGLE_GPU_PARAMS = [ +# (902, 32768, 256, 8, torch.float16), # Large data, float16 +# (101, 288, 128, 4, torch.float16), # Medium data, float16 +# (902, 7168, 256, 8, torch.bfloat16), # Large data, bfloat16 +# (101, 288, 128, 4, torch.bfloat16), # Medium data, bfloat16 +# (10, 8, 8, 2, torch.bfloat16), # Small data, bfloat16 +# ] + +# MULTI_RANK_PARAMS = [ +# (2, 5, 8, torch.float16), # Small input, 2 ranks +# (4, 901, 32768, torch.bfloat16), # Large input, 4 ranks +# (8, 16384, 128, torch.float16), # Many small vectors, 8 ranks +# ] + +# PREPARE_INDICES_PARAMS = [ +# (0, 8, 256, 4, 3, False), # Rank 0, small config +# (1, 8, 256, 4, 3, True), # Rank 1, small config with real cumsum +# (7, 8, 256, 8, 1025, False), # High rank, medium config +# (7, 64, 1024, 32, 1029, True), # High rank, large config with real cumsum +# ] + +# LOCAL_GATHER_PARAMS = [ +# (0, 8, 256, 4, 3), # Rank 0, small config +# (7, 8, 256, 8, 32), # High rank, medium config +# (7, 64, 1024, 32, 1029), # High rank, large config +# ] + + +# # Real cross-GPU communication test parameters +# CROSS_GPU_PARAMS = [ +# (2, 100, 256, torch.float16), # 2 GPUs, 2 ranks +# (2, 300, 512, torch.bfloat16), # 2 GPUs, 2 ranks, larger data +# (4, 150, 256, torch.float16), # 4 GPUs, 4 ranks (if available) +# (4, 400, 512, torch.float16), # 4 GPUs, 4 ranks, larger data +# ] + + +# def get_available_gpu_count(): +# """Get the number of available GPUs.""" +# if not torch.cuda.is_available(): +# return 0 +# return torch.cuda.device_count() + + +# def requires_gpus(min_gpus): +# """Decorator to skip test if insufficient GPUs are available.""" + +# def decorator(func): +# return pytest.mark.skipif( +# get_available_gpu_count() < min_gpus, +# reason=f"Requires at least {min_gpus} GPUs, but only {get_available_gpu_count()} available", +# )(func) + +# return decorator + + +# @pytest.mark.parametrize( +# "num_tokens,vector_dim,num_experts,top_k,dtype", +# SINGLE_GPU_PARAMS, +# ) +# def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k, dtype): +# """Test MOE alltoall communication on single GPU.""" +# torch.cuda.set_device(0) +# # Create a random input tensor +# input_tensor1 = torch.randn( +# num_tokens, vector_dim, dtype=dtype, device=torch.device("cuda") +# ) +# input_tensor2 = torch.randn( +# num_tokens, vector_dim, dtype=dtype, device=torch.device("cuda") +# ) + +# token_selected_experts = torch.randperm( +# num_tokens, dtype=torch.int32, device=torch.device("cuda") +# )[:top_k] + +# # num_tokens * topk items * 2 buffers, with vector_dim * 2 bytes each. +# # num_tokens * topk token selected experts, 4 bytes each. +# # Add 1KiB for aux data. 
+# workspace_size = ( +# (num_tokens * top_k * vector_dim * 2) * 2 + (num_tokens * top_k * 4) + 1024 +# ) + +# mapping = Mapping(rank=0, world_size=1) +# moe_a2a = trtllm_moe_alltoall.MoeAlltoAll( +# mapping, +# num_tokens, +# top_k, +# num_experts, +# workspace_size, +# ) + +# output_tensor1, output_tensor2, token_selected_experts_output = moe_a2a.dispatch( +# token_selected_experts, +# [input_tensor1, input_tensor2, token_selected_experts], +# num_tokens, +# invalid_token_expert_id=0, # Tokens assigned to expert 0 are invalid +# expert_id_payload_index=2, +# ) + +# print( +# output_tensor1.shape, output_tensor2.shape, token_selected_experts_output.shape +# ) + + +# @pytest.mark.parametrize( +# "world_size,input_entry_per_rank,vector_dim,dtype", MULTI_RANK_PARAMS +# ) +# def test_moe_alltoall_multi_rank_single_gpu( +# world_size, input_entry_per_rank, vector_dim, dtype +# ): +# """Test MOE alltoall communication with multiple ranks on single GPU.""" +# torch.cuda.set_device(0) +# max_world_size = 8 +# assert world_size <= max_world_size, ( +# f"should run with world_size at most {max_world_size}" +# ) + +# # SM count is now set up globally in the fixture + +# # Create a random input tensor +# input_tensor = torch.randn( +# input_entry_per_rank * world_size, +# vector_dim, +# dtype=dtype, +# device=torch.device("cuda"), +# ) +# output_tensor = torch.zeros( +# input_entry_per_rank * world_size, +# vector_dim, +# dtype=dtype, +# device=torch.device("cuda"), +# ) +# ref_output_tensor = torch.zeros( +# input_entry_per_rank * world_size, +# vector_dim, +# dtype=dtype, +# device=torch.device("cuda"), +# ) +# target_rank_ids = torch.randint( +# 0, +# world_size, +# (input_entry_per_rank * world_size,), +# dtype=torch.int32, +# device=torch.device("cuda"), +# ) + +# input_tensors_all_ranks = list(torch.split(input_tensor, input_entry_per_rank)) +# target_rank_ids_all_ranks = list(torch.split(target_rank_ids, input_entry_per_rank)) + +# send_ids_all_ranks = [] +# send_counts_all_ranks = [] +# send_cumsum_all_ranks = [] +# send_start_end_all_ranks = [] + +# # each rank do its own local compute to get how to send data to other ranks. 
+# for rank in range(world_size): +# send_start_end = [] +# local_target_rank_ids = target_rank_ids_all_ranks[rank] +# sorted_local_target_rank_ids, local_send_id = torch.sort(local_target_rank_ids) +# local_send_id = local_send_id.to(torch.int32) +# padded_sorted_local_target_rank_ids = torch.cat( +# ( +# sorted_local_target_rank_ids, +# torch.arange( +# world_size, dtype=torch.int32, device=torch.device("cuda") +# ), +# ) +# ) +# unique_target_rank_ids, local_send_counts = torch.unique( +# padded_sorted_local_target_rank_ids, return_counts=True +# ) +# local_send_counts = local_send_counts.to(torch.int32) +# assert unique_target_rank_ids.numel() == world_size, ( +# "unique_target_rank_ids must be equal to world_size" +# ) +# local_send_counts -= 1 # remove padding +# local_send_cumsum = torch.cumsum(local_send_counts, dim=0).to(torch.int32) +# send_ids_all_ranks.append(local_send_id) +# send_counts_all_ranks.append(local_send_counts) +# send_cumsum_all_ranks.append(local_send_cumsum) +# local_send_cumsum_cpu = local_send_cumsum.cpu().tolist() +# for i in range(len(local_send_cumsum_cpu)): +# send_start_end.append( +# ( +# local_send_cumsum_cpu[i - 1] if i > 0 else 0, +# local_send_cumsum_cpu[i], +# ) +# ) +# send_start_end_all_ranks.append(send_start_end) + +# recv_ids_all_ranks = [] +# recv_cumsum_all_ranks = [] + +# output_tensors_all_ranks = [] + +# total_recv_all_ranks_cpu = [] +# output_indice_offset = 0 + +# output_start_current_rank = 0 +# # each rank do compute based on other ranks' send counts to get how to receive data from other ranks. +# for rank in range(world_size): +# local_recv_counts = torch.zeros( +# world_size, dtype=torch.int32, device=torch.device("cuda") +# ) +# for other_rank in range(world_size): +# local_recv_counts[other_rank] = send_counts_all_ranks[other_rank][rank] +# local_recv_count_pair = local_recv_counts[other_rank].cpu().item() +# send_rank_start_end = send_start_end_all_ranks[other_rank][rank] +# ref_output_tensor[ +# output_indice_offset : output_indice_offset + local_recv_count_pair +# ] = input_tensors_all_ranks[other_rank][ +# send_ids_all_ranks[other_rank][ +# send_rank_start_end[0] : send_rank_start_end[1] +# ] +# ] +# output_indice_offset += local_recv_count_pair +# local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32) +# recv_cumsum_all_ranks.append(local_recv_cumsum) +# total_recv_count = local_recv_cumsum[-1].cpu() +# total_recv_all_ranks_cpu.append(total_recv_count) +# output_tensors_all_ranks.append( +# output_tensor[ +# output_start_current_rank : output_start_current_rank + total_recv_count +# ] +# ) +# output_start_current_rank += total_recv_count +# local_recv_ids = torch.arange( +# total_recv_count, dtype=torch.int32, device=torch.device("cuda") +# ) +# recv_ids_all_ranks.append(local_recv_ids) + +# cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(world_size)] + +# workspace_size = tllm_alltoall.get_moe_commworkspace_size_per_rank(world_size) +# all_workspaces = torch.zeros( +# world_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") +# ) + +# # Synchronize before starting parallel communication +# torch.cuda.synchronize() + +# # do alltoall in parallel +# for rank in range(world_size): +# with torch.cuda.stream(cuda_streams_all_ranks[rank]): +# tllm_alltoall.moe_comm( +# input_tensors_all_ranks[rank], +# send_cumsum_all_ranks[rank], +# send_ids_all_ranks[rank], +# output_tensors_all_ranks[rank], +# recv_cumsum_all_ranks[rank], +# recv_ids_all_ranks[rank], +# all_workspaces, +# rank, 
+# world_size, +# ) +# for rank in range(world_size): +# cuda_streams_all_ranks[rank].synchronize() + +# torch.testing.assert_close(output_tensor, ref_output_tensor, atol=1e-5, rtol=1e-5) + + +# @pytest.mark.parametrize( +# "ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank,use_real_rank_token_count_cumsum", +# PREPARE_INDICES_PARAMS, +# ) +# def test_moe_alltoall_prepare_indices( +# ep_rank, +# ep_size, +# expert_count, +# top_k, +# max_token_count_per_rank, +# use_real_rank_token_count_cumsum, +# ): +# """Test MOE alltoall prepare indices functionality.""" +# torch.cuda.set_device(0) + +# def generate_references(): +# rank_token_count = max_token_count_per_rank +# if use_real_rank_token_count_cumsum: +# # Make sure we have at least 1 token in each rank except last rank +# rank_token_counts = [ +# max(1, torch.randint(1, max_token_count_per_rank + 1, (1,)).item()) +# for _ in range(ep_size - 1) +# ] +# rank_token_counts.append( +# max_token_count_per_rank +# ) # last rank has max tokens +# real_rank_token_count_cumsum = ( +# torch.tensor( +# rank_token_counts, dtype=torch.int32, device=torch.device("cuda") +# ) +# .cumsum(dim=0) +# .to(torch.int32) +# ) +# rank_token_count = rank_token_counts[ep_rank] +# else: +# real_rank_token_count_cumsum = None + +# # Generate target rank ids for this rank +# target_rank_ids = torch.randint( +# 0, +# ep_size, +# (rank_token_count, top_k), +# dtype=torch.int32, +# device=torch.device("cuda"), +# ) + +# if not use_real_rank_token_count_cumsum: +# gathered_target_rank_ids = torch.zeros( +# ep_size * max_token_count_per_rank, +# top_k, +# dtype=torch.int32, +# device=torch.device("cuda"), +# ) +# gathered_target_rank_ids[ +# ep_rank * max_token_count_per_rank : ep_rank * max_token_count_per_rank +# + rank_token_count +# ] = target_rank_ids +# else: +# total_tokens = real_rank_token_count_cumsum[-1].item() +# gathered_target_rank_ids = torch.zeros( +# total_tokens, top_k, dtype=torch.int32, device=torch.device("cuda") +# ) +# start_pos = ( +# 0 if ep_rank == 0 else real_rank_token_count_cumsum[ep_rank - 1].item() +# ) +# gathered_target_rank_ids[start_pos : start_pos + rank_token_count] = ( +# target_rank_ids +# ) + +# return gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids + +# gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids = ( +# generate_references() +# ) + +# ( +# local_gather_indices, +# send_rank_count_cumsum, +# send_rank_local_indices, +# recv_rank_count_cumsum, +# recv_rank_local_indices, +# backward_recv_rank_local_indices, +# ) = tllm_alltoall.moe_comm_prepare_indices( +# gathered_target_rank_ids, +# real_rank_token_count_cumsum, +# max_token_count_per_rank, +# expert_count, +# top_k, +# ep_rank, +# ep_size, +# ) + +# # Validate shapes +# assert local_gather_indices.shape[0] <= max_token_count_per_rank * ep_size +# assert send_rank_count_cumsum.shape[0] == ep_size +# assert recv_rank_count_cumsum.shape[0] == ep_size +# assert send_rank_local_indices.shape[0] <= max_token_count_per_rank * max( +# ep_size, top_k +# ) +# assert recv_rank_local_indices.shape[0] <= max_token_count_per_rank * ep_size +# assert backward_recv_rank_local_indices.shape[0] <= max_token_count_per_rank * max( +# ep_size, top_k +# ) + +# # Basic validation - cumulative sums should be non-decreasing +# assert torch.all(send_rank_count_cumsum[1:] >= send_rank_count_cumsum[:-1]) +# assert torch.all(recv_rank_count_cumsum[1:] >= recv_rank_count_cumsum[:-1]) + + +# @pytest.mark.parametrize( +# 
"ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank", LOCAL_GATHER_PARAMS +# ) +# def test_moe_local_gather( +# ep_rank, +# ep_size, +# expert_count, +# top_k, +# max_token_count_per_rank, +# ): +# """Test MOE local gather functionality.""" +# torch.cuda.set_device(0) + +# # Generate test data using the original method +# rank_token_count_cumsum = torch.randint( +# 0, +# max_token_count_per_rank + 1, +# (ep_size,), +# dtype=torch.int32, +# device=torch.device("cuda"), +# ) +# rank_token_count_cumsum = torch.cumsum(rank_token_count_cumsum, dim=0).to( +# torch.int32 +# ) +# local_token_count = rank_token_count_cumsum[ep_size - 1].cpu().item() +# local_max_token_count = max_token_count_per_rank * ep_size +# local_gather_indices = torch.randint( +# 0, +# max_token_count_per_rank * ep_size, +# (local_max_token_count,), +# dtype=torch.int32, +# device=torch.device("cuda"), +# ) + +# gathered_expert_ids = torch.randint( +# 0, +# expert_count, +# (max_token_count_per_rank * ep_size, top_k), +# dtype=torch.int32, +# device=torch.device("cuda"), +# ) +# gathered_scales = torch.rand( +# (max_token_count_per_rank * ep_size, top_k), +# dtype=torch.float32, +# device=torch.device("cuda"), +# ) + +# ref_local_expert_ids = torch.zeros( +# local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda") +# ) +# ref_local_scales = torch.zeros( +# local_max_token_count, +# top_k, +# dtype=torch.float32, +# device=torch.device("cuda"), +# ) + +# # compute reference +# ref_local_expert_ids += expert_count +# valid_local_gather_indices = local_gather_indices[:local_token_count] +# ref_local_expert_ids[:local_token_count] = gathered_expert_ids[ +# valid_local_gather_indices +# ] +# ref_local_scales[:local_token_count] = gathered_scales[valid_local_gather_indices] + +# local_expert_ids = torch.empty( +# local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda") +# ) +# local_scales = torch.empty( +# local_max_token_count, +# top_k, +# dtype=torch.float32, +# device=torch.device("cuda"), +# ) + +# tllm_alltoall.moe_local_gather( +# rank_token_count_cumsum, +# local_gather_indices, +# gathered_expert_ids, +# gathered_scales, +# local_expert_ids, +# local_scales, +# max_token_count_per_rank, +# expert_count, +# top_k, +# ep_rank, +# ep_size, +# ) + +# assert torch.equal(local_expert_ids, ref_local_expert_ids) +# assert torch.equal(local_scales, ref_local_scales) + + +# @pytest.mark.parametrize( +# "ep_rank, ep_size, expert_count, slot_count, top_k, max_token_count_per_rank", +# [ +# (0, 2, 16, 20, 8, 512), +# (0, 2, 16, 16, 3, 300), +# (0, 4, 20, 24, 8, 4000), +# (0, 8, 96, 96, 8, 1000), +# (3, 8, 128, 128, 8, 1000), +# (3, 8, 128, 144, 8, 1), +# (0, 4, 72, 80, 4, 2256), +# (0, 4, 72, 80, 6, 3333), +# # Hang with stream count > 8 +# # (0, 9, 90, 8, 100), +# ], +# ) +# def test_moe_alltoall_prepare( +# ep_rank: int, +# ep_size: int, +# expert_count: int, +# slot_count: int, +# top_k: int, +# max_token_count_per_rank: int, +# ): +# torch.cuda.set_device(0) + +# cpu_expert_ids_all_ranks_lists = [] +# cpu_token_count_lists = [] +# cpu_scales_all_ranks_lists = [] +# for _ in range(ep_size): +# token_count = torch.randint( +# max_token_count_per_rank // 2, +# max_token_count_per_rank + 1, +# (1,), +# dtype=torch.int32, +# device=torch.device("cpu"), +# ) +# token_count = 1 if token_count == 0 else token_count + +# token_count = max_token_count_per_rank + +# cpu_expert_ids_all_ranks_lists.append( +# torch.randint( +# 0, +# slot_count, +# (token_count, top_k), +# dtype=torch.int32, +# 
device=torch.device("cpu"), +# ) +# ) + +# cpu_scales_all_ranks_lists.append( +# torch.zeros( +# token_count, top_k, dtype=torch.float32, device=torch.device("cpu") +# ) +# + 0.5 +# ) + +# cpu_token_count_lists.append(token_count) + +# def compute_target_rank(expert_id): +# ep_per_rank = slot_count // ep_size +# return expert_id // ep_per_rank + +# def generate_references(): +# ref_prepared_local_expert_ids = [] +# ref_prepared_local_scales = [] +# ref_local_send_rank_count_cumsum = [0] * ep_size +# ref_local_recv_rank_count_cumsum = [0] * ep_size +# ref_local_recv_rank_indices = [] + +# local_token_count = cpu_token_count_lists[ep_rank] +# send_token_count_to_ranks = [0] * ep_size + +# # send part +# for token_id in range(local_token_count): +# target_set = set() +# for pos in range(top_k): +# expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) +# target_rank_id = compute_target_rank(expert_id) +# target_set.add(target_rank_id) + +# for target_rank_id in target_set: +# send_token_count_to_ranks[target_rank_id] += 1 + +# total_send_token_count = 0 +# for rank in range(ep_size): +# # print(f'rank: {rank}, send_token_count_to_ranks[rank]: {send_token_count_to_ranks[rank]}') +# base = ref_local_send_rank_count_cumsum[rank - 1] if rank > 0 else 0 +# ref_local_send_rank_count_cumsum[rank] = ( +# send_token_count_to_ranks[rank] + base +# ) +# total_send_token_count += send_token_count_to_ranks[rank] + +# ref_local_backward_send_rank_indices = [0] * (total_send_token_count) +# ref_local_send_rank_indices = [0] * (total_send_token_count) + +# current_send_token_ids = [0] * ep_size +# for token_id in range(local_token_count): +# target_set = set() +# for pos in range(top_k): +# expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) +# target_rank_id = compute_target_rank(expert_id) +# if target_rank_id not in target_set: +# cumsum_before = ( +# 0 +# if target_rank_id == 0 +# else ref_local_send_rank_count_cumsum[target_rank_id - 1] +# ) +# send_index = cumsum_before + current_send_token_ids[target_rank_id] +# ref_local_send_rank_indices[send_index] = token_id +# ref_local_backward_send_rank_indices[send_index] = ( +# token_id * top_k + pos +# ) +# current_send_token_ids[target_rank_id] += 1 +# target_set.add(target_rank_id) + +# # receive part +# total_recv_token_count = 0 +# for rank in range(ep_size): +# token_count = cpu_token_count_lists[rank] +# current_recv_token_count = 0 +# for token_id in range(token_count): +# token_is_received = False +# for pos in range(top_k): +# expert_id = int(cpu_expert_ids_all_ranks_lists[rank][token_id][pos]) +# sf = cpu_scales_all_ranks_lists[rank][token_id][pos] +# target_rank_id = compute_target_rank(expert_id) +# if target_rank_id == ep_rank: +# if not token_is_received: +# token_is_received = True +# ref_prepared_local_expert_ids.append([slot_count] * top_k) +# ref_prepared_local_scales.append([0.0] * top_k) +# ref_prepared_local_expert_ids[-1][pos] = expert_id +# ref_prepared_local_scales[-1][pos] = sf +# if token_is_received: +# ref_local_recv_rank_indices.append(total_recv_token_count) +# total_recv_token_count += 1 +# current_recv_token_count += 1 +# ref_local_recv_rank_count_cumsum[rank] = ( +# current_recv_token_count +# if rank == 0 +# else ref_local_recv_rank_count_cumsum[rank - 1] +# + current_recv_token_count +# ) + +# return ( +# ref_prepared_local_expert_ids, +# ref_prepared_local_scales, +# ref_local_send_rank_count_cumsum, +# ref_local_send_rank_indices, +# ref_local_recv_rank_count_cumsum, +# 
ref_local_recv_rank_indices, +# ref_local_backward_send_rank_indices, +# total_recv_token_count, +# ) + +# ( +# ref_prepared_local_expert_ids, +# ref_prepared_local_scales, +# ref_local_send_rank_count_cumsum, +# ref_local_send_rank_indices, +# ref_local_recv_rank_count_cumsum, +# ref_local_recv_rank_indices, +# ref_local_backward_send_rank_indices, +# total_recv_token_count, +# ) = generate_references() + +# cpu_experter_count_lists = [] +# for rank in range(ep_size): +# local_expert_count = [] +# for i in range(expert_count): +# local_expert_count.append(rank * expert_count + i) +# cpu_experter_count_lists.append(torch.IntTensor(local_expert_count)) + +# # expert_ids_all_ranks = torch.tensor(cpu_expert_ids_all_ranks_lists).cuda() +# expert_ids_all_ranks = [ +# cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size) +# ] +# # scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda() +# scales_all_ranks = [cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size)] + +# experter_count_lists = [cpu_experter_count_lists[i].cuda() for i in range(ep_size)] + +# cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(ep_size)] + +# workspace_size = tllm_alltoall.get_moe_prepare_workspace_size_per_rank(ep_size) + +# all_workspaces = torch.zeros( +# ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") +# ) + +# stream = torch.cuda.Stream() +# cur_stream = torch.cuda.current_stream() +# stream.wait_stream(cur_stream) +# with torch.cuda.stream(stream): +# tllm_alltoall.moe_prepare( +# expert_ids_all_ranks[0], +# scales_all_ranks[0], +# experter_count_lists[0], +# all_workspaces, +# max_token_count_per_rank, +# 0, +# 1, +# expert_count, +# slot_count, +# top_k, +# ) +# cur_stream.wait_stream(stream) + +# # Make torch alloc tensor to avoid cuda sync +# prepared_local_experts = [] +# prepared_local_scales = [] +# local_send_rank_count_cumsum = [] +# local_send_rank_indices = [] +# local_recv_rank_count_cumsum = [] +# local_recv_rank_indices = [] +# backward_local_recv_rank_indices = [] +# for _ in range(ep_size): +# prepared_local_experts.append( +# torch.empty( +# max_token_count_per_rank * ep_size, +# top_k, +# dtype=torch.int32, +# device=torch.device("cuda"), +# ) +# ) +# prepared_local_scales.append( +# torch.empty( +# max_token_count_per_rank * ep_size, +# top_k, +# dtype=torch.float32, +# device=torch.device("cuda"), +# ) +# ) +# local_send_rank_count_cumsum.append( +# torch.empty(ep_size, dtype=torch.int32, device=torch.device("cuda")) +# ) +# local_send_rank_indices.append( +# torch.empty( +# max_token_count_per_rank * ep_size, +# dtype=torch.int32, +# device=torch.device("cuda"), +# ) +# ) +# local_recv_rank_count_cumsum.append( +# torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) +# ) +# local_recv_rank_indices.append( +# torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) +# ) +# backward_local_recv_rank_indices.append( +# torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) +# ) + +# prepared_local_experts = None +# prepared_local_scales = None +# local_send_rank_count_cumsum = None +# local_send_rank_indices = None +# local_recv_rank_count_cumsum = None +# local_recv_rank_indices = None +# backward_local_recv_rank_indices = None + +# # reset the workspace +# all_workspaces = torch.zeros( +# ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") +# ) + +# # do prepare in parallel +# cur_stream = torch.cuda.current_stream() +# for rank in range(ep_size): +# s = cuda_streams_all_ranks[rank] 
+# s.wait_stream(cur_stream) +# with torch.cuda.stream(s): +# if rank == ep_rank: +# ( +# prepared_local_experts, +# prepared_local_scales, +# local_send_rank_count_cumsum, +# local_send_rank_indices, +# local_recv_rank_count_cumsum, +# local_recv_rank_indices, +# backward_local_recv_rank_indices, +# gathered_expert_statics, +# ) = tllm_alltoall.moe_prepare( +# expert_ids_all_ranks[rank], +# scales_all_ranks[rank], +# experter_count_lists[rank], +# all_workspaces, +# max_token_count_per_rank, +# rank, +# ep_size, +# expert_count, +# slot_count, +# top_k, +# ) +# else: +# tllm_alltoall.moe_prepare( +# expert_ids_all_ranks[rank], +# scales_all_ranks[rank], +# experter_count_lists[rank], +# all_workspaces, +# max_token_count_per_rank, +# rank, +# ep_size, +# expert_count, +# slot_count, +# top_k, +# ) +# for rank in range(ep_size): +# cuda_streams_all_ranks[rank].synchronize() + +# prepared_local_experts_cpu = prepared_local_experts[:total_recv_token_count].cpu() +# prepared_local_scales_cpu = prepared_local_scales[:total_recv_token_count].cpu() +# for i in range(total_recv_token_count): +# for j in range(top_k): +# expert_id = int(prepared_local_experts_cpu[i][j]) +# assert expert_id >= 0 and expert_id <= slot_count +# if expert_id < slot_count: +# assert compute_target_rank(expert_id) == ep_rank +# scale = float(prepared_local_scales_cpu[i][j]) +# assert scale > 1e-6 + +# gathered_expert_statics_cpu = gathered_expert_statics.cpu() +# for rank in range(ep_size): +# for i in range(expert_count): +# assert int(gathered_expert_statics_cpu[rank][i]) == rank * expert_count + i + +# ref_local_send_rank_count_cumsum = torch.IntTensor(ref_local_send_rank_count_cumsum) +# assert torch.equal( +# local_send_rank_count_cumsum.cpu(), ref_local_send_rank_count_cumsum +# ) + +# local_send_rank_indices = local_send_rank_indices.cpu() +# backward_local_recv_rank_indices = backward_local_recv_rank_indices.cpu() +# for i in range(ep_size): +# base = 0 if i == 0 else ref_local_send_rank_count_cumsum[i - 1] +# for j in range(base, ref_local_send_rank_count_cumsum[i]): +# token_id = local_send_rank_indices[j] +# lane_id = backward_local_recv_rank_indices[j] - token_id * top_k +# expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][lane_id]) +# assert compute_target_rank(expert_id) == i + +# ref_local_recv_rank_count_cumsum = torch.IntTensor(ref_local_recv_rank_count_cumsum) +# assert torch.equal( +# local_recv_rank_count_cumsum[: ref_local_recv_rank_count_cumsum.size(0)].cpu(), +# ref_local_recv_rank_count_cumsum, +# ) + +# ref_local_recv_rank_indices = torch.IntTensor(ref_local_recv_rank_indices) +# assert torch.equal( +# local_recv_rank_indices[: ref_local_recv_rank_indices.size(0)].cpu(), +# ref_local_recv_rank_indices, +# ) + + +# if __name__ == "__main__": +# pytest.main([__file__, "-v"]) From 6ac511e68db2c99266fea3952fff7a0e986705ff Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Thu, 27 Nov 2025 19:15:06 +1300 Subject: [PATCH 16/25] Enhance tests and add convenience APIs for more general usage --- .../moeAlltoAllKernels.cu | 1 + csrc/trtllm_moe_a2a.cu | 10 + flashinfer/comm/__init__.py | 13 +- flashinfer/comm/trtllm_moe_alltoall.py | 118 +- tests/comm/test_trtllm_moe_alltoall.py | 1126 ++++------------- 5 files changed, 382 insertions(+), 886 deletions(-) diff --git a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu index 
2751a129c6..d18ac0ea2f 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu +++ b/csrc/nv_internal/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu @@ -29,6 +29,7 @@ namespace tensorrt_llm::kernels::mnnvl_throughput { #define ENABLE_DEBUG_PRINT 0 #define DISABLE_SYNC_FOR_PROFILING 0 + #ifndef DISABLE_TIMEOUT #define DISABLE_TIMEOUT 0 #endif diff --git a/csrc/trtllm_moe_a2a.cu b/csrc/trtllm_moe_a2a.cu index 0407eb201a..7c0099c370 100644 --- a/csrc/trtllm_moe_a2a.cu +++ b/csrc/trtllm_moe_a2a.cu @@ -87,6 +87,15 @@ fi_throughput::MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) return offsets; } +int64_t getMoeA2AWorkspaceSizePerRank(int64_t epSize, int64_t maxNumTokens, + int64_t maxPayloadSizePerElement) { + int64_t metadata_size = + calculateOffsets(static_cast(epSize), + static_cast(maxNumTokens))[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX]; + int64_t payload_size = maxNumTokens * maxPayloadSizePerElement; + return alignOffset(metadata_size + payload_size); +} + Tensor moeA2AInitializeOp(TensorView workspace, int64_t epRank, int64_t epSize, int64_t maxNumTokens) { CHECK_INPUT_TYPE(workspace, dl_uint8); @@ -392,6 +401,7 @@ Tuple, Array> getMoeA2AMetaInfoIndexPairs() { } // namespace +TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_get_workspace_size_per_rank, getMoeA2AWorkspaceSizePerRank); TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_initialize, moeA2AInitializeOp); TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_dispatch, moeA2ADispatchOp); TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_a2a_combine, moeA2ACombineOp); diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 7aa1ff1cfe..496050cd00 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -40,12 +40,15 @@ from .vllm_ar import register_graph_buffers as vllm_register_graph_buffers # MNNVL A2A (Throughput Backend) -from .trtllm_moe_a2a import MoeAlltoAll as MoeAlltoAll -from .trtllm_moe_a2a import moe_a2a_combine as moe_a2a_combine -from .trtllm_moe_a2a import moe_a2a_dispatch as moe_a2a_dispatch -from .trtllm_moe_a2a import moe_a2a_initialize as moe_a2a_initialize -from .trtllm_moe_a2a import ( +from .trtllm_moe_alltoall import MoeAlltoAll as MoeAlltoAll +from .trtllm_moe_alltoall import moe_a2a_combine as moe_a2a_combine +from .trtllm_moe_alltoall import moe_a2a_dispatch as moe_a2a_dispatch +from .trtllm_moe_alltoall import moe_a2a_initialize as moe_a2a_initialize +from .trtllm_moe_alltoall import ( moe_a2a_sanitize_expert_ids as moe_a2a_sanitize_expert_ids, ) +from .trtllm_moe_alltoall import ( + moe_a2a_get_workspace_size_per_rank as moe_a2a_get_workspace_size_per_rank, +) # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/trtllm_moe_alltoall.py b/flashinfer/comm/trtllm_moe_alltoall.py index df47a8f0fd..3bedc57e01 100644 --- a/flashinfer/comm/trtllm_moe_alltoall.py +++ b/flashinfer/comm/trtllm_moe_alltoall.py @@ -168,12 +168,26 @@ def moe_a2a_get_metainfo_index_pairs(): """ return module.moe_a2a_get_metainfo_index_pairs() + @register_custom_op( + "flashinfer::moe_a2a_get_workspace_size_per_rank", + mutates_args=[], + ) + def moe_a2a_get_workspace_size_per_rank( + ep_size: int, + max_num_tokens: int, + payload_size_per_element: int, + ): + return module.moe_a2a_get_workspace_size_per_rank( + ep_size, max_num_tokens, payload_size_per_element + ) + return SimpleNamespace( moe_a2a_initialize=moe_a2a_initialize, moe_a2a_dispatch=moe_a2a_dispatch, moe_a2a_combine=moe_a2a_combine, 
moe_a2a_sanitize_expert_ids=moe_a2a_sanitize_expert_ids, moe_a2a_get_metainfo_index_pairs=moe_a2a_get_metainfo_index_pairs, + moe_a2a_get_workspace_size_per_rank=moe_a2a_get_workspace_size_per_rank, ) @@ -300,6 +314,16 @@ def moe_a2a_sanitize_expert_ids( ) +def moe_a2a_get_workspace_size_per_rank( + ep_size: int, + max_num_tokens: int, + payload_size_per_element: int, +): + return get_mnnvl_a2a_module().moe_a2a_get_workspace_size_per_rank( + ep_size, max_num_tokens, payload_size_per_element + ) + + class MoeAlltoAll: """ Manages MoE All-to-All operations with proper workspace allocation and synchronization. @@ -314,7 +338,40 @@ class MoeAlltoAll: """ # Single shared workspace across the process - _WORKSPACE: Optional[dict] = None + # _WORKSPACE: Optional[dict] = None + _WORKSPACE_CACHE: dict[tuple[int, int, int, int], dict] = {} + + @classmethod + def get_workspace( + cls, + workspace_size_per_rank: int, + ep_rank: int, + ep_size: int, + max_num_tokens: int, + mapping: Mapping, + ) -> dict: + key = (workspace_size_per_rank, ep_rank, ep_size, max_num_tokens) + if key in cls._WORKSPACE_CACHE: + return cls._WORKSPACE_CACHE[key] + else: + mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank) + workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) + metainfo = moe_a2a_initialize( + workspace, + ep_rank, + ep_size, + max_num_tokens, + ) + cls._WORKSPACE_CACHE[key] = { + "workspace_size_per_rank": workspace_size_per_rank, + "max_num_tokens": max_num_tokens, + "ep_rank": ep_rank, + "ep_size": ep_size, + "mnnvl_mem": mnnvl_mem, + "workspace": workspace, + "metainfo": metainfo, + } + return cls._WORKSPACE_CACHE[key] # Metainfo index constants (loaded dynamically from C++) # These offsets allow accessing internal workspace data for testing/debugging @@ -379,40 +436,42 @@ def __init__( raise ValueError("num_experts must be a positive int") # Allocate or reuse workspace - if self._WORKSPACE is None: - mnnvl_mem = MnnvlMemory(mapping, workspace_size_per_rank) - workspace = mnnvl_mem.as_torch_strided_tensor(torch.uint8) - metainfo = moe_a2a_initialize( - workspace, - self.ep_rank, - self.ep_size, - self.max_num_tokens, - ) - MoeAlltoAll._WORKSPACE = { - "workspace_size_per_rank": workspace_size_per_rank, - "max_num_tokens": self.max_num_tokens, - "ep_rank": self.ep_rank, - "ep_size": self.ep_size, - "mnnvl_mem": mnnvl_mem, - "workspace": workspace, - "metainfo": metainfo, - } - else: - # Validate workspace compatibility - assert ( - self._WORKSPACE["workspace_size_per_rank"] == workspace_size_per_rank - ), "Workspace size mismatch" - assert self._WORKSPACE["max_num_tokens"] == self.max_num_tokens, ( - "Max tokens mismatch" - ) - assert self._WORKSPACE["ep_rank"] == self.ep_rank, "EP rank mismatch" - assert self._WORKSPACE["ep_size"] == self.ep_size, "EP size mismatch" + self._WORKSPACE = self.get_workspace( + workspace_size_per_rank, + self.ep_rank, + self.ep_size, + self.max_num_tokens, + mapping, + ) + # Validate workspace compatibility + assert self._WORKSPACE["workspace_size_per_rank"] == workspace_size_per_rank, ( + "Workspace size mismatch" + ) + assert self._WORKSPACE["max_num_tokens"] == self.max_num_tokens, ( + "Max tokens mismatch" + ) + assert self._WORKSPACE["ep_rank"] == self.ep_rank, "EP rank mismatch" + assert self._WORKSPACE["ep_size"] == self.ep_size, "EP size mismatch" self.mnnvl_mem = self._WORKSPACE["mnnvl_mem"] self.workspace = self._WORKSPACE["workspace"] self.metainfo = self._WORKSPACE["metainfo"] self._state = _A2AState() + def _reset_workspace(self): + """Reset the 
workspace to free up its state. This is mainly used for testing. Use this with caution. This object is no longer usable after this.""" + torch.cuda.synchronize() + del self._WORKSPACE + del self._WORKSPACE_CACHE[ + ( + self.workspace_size_per_rank, + self.ep_rank, + self.ep_size, + self.max_num_tokens, + ) + ] + self._state.phase = "deleted" + def dispatch( self, token_selected_experts: torch.Tensor, @@ -556,4 +615,5 @@ def get_combine_payload_tensor_in_workspace( "moe_a2a_dispatch", "moe_a2a_combine", "moe_a2a_sanitize_expert_ids", + "moe_a2a_get_workspace_size_per_rank", ] diff --git a/tests/comm/test_trtllm_moe_alltoall.py b/tests/comm/test_trtllm_moe_alltoall.py index d03c8e1ad6..87b341a788 100644 --- a/tests/comm/test_trtllm_moe_alltoall.py +++ b/tests/comm/test_trtllm_moe_alltoall.py @@ -1,852 +1,274 @@ -# """ -# Copyright (c) 2024 by FlashInfer team. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# """ - -# import pytest -# import torch - -# from flashinfer.comm.mapping import Mapping - -# import flashinfer.comm.trtllm_moe_alltoall as tllm_moe_alltoall - - -# @pytest.fixture(autouse=True, scope="session") -# def setup_test_environment(): -# """Set up test environment and warm up JIT compilation.""" -# torch.manual_seed(0x1234) -# yield - - -# # Single GPU test parameters -# SINGLE_GPU_PARAMS = [ -# (902, 32768, 256, 8, torch.float16), # Large data, float16 -# (101, 288, 128, 4, torch.float16), # Medium data, float16 -# (902, 7168, 256, 8, torch.bfloat16), # Large data, bfloat16 -# (101, 288, 128, 4, torch.bfloat16), # Medium data, bfloat16 -# (10, 8, 8, 2, torch.bfloat16), # Small data, bfloat16 -# ] - -# MULTI_RANK_PARAMS = [ -# (2, 5, 8, torch.float16), # Small input, 2 ranks -# (4, 901, 32768, torch.bfloat16), # Large input, 4 ranks -# (8, 16384, 128, torch.float16), # Many small vectors, 8 ranks -# ] - -# PREPARE_INDICES_PARAMS = [ -# (0, 8, 256, 4, 3, False), # Rank 0, small config -# (1, 8, 256, 4, 3, True), # Rank 1, small config with real cumsum -# (7, 8, 256, 8, 1025, False), # High rank, medium config -# (7, 64, 1024, 32, 1029, True), # High rank, large config with real cumsum -# ] - -# LOCAL_GATHER_PARAMS = [ -# (0, 8, 256, 4, 3), # Rank 0, small config -# (7, 8, 256, 8, 32), # High rank, medium config -# (7, 64, 1024, 32, 1029), # High rank, large config -# ] - - -# # Real cross-GPU communication test parameters -# CROSS_GPU_PARAMS = [ -# (2, 100, 256, torch.float16), # 2 GPUs, 2 ranks -# (2, 300, 512, torch.bfloat16), # 2 GPUs, 2 ranks, larger data -# (4, 150, 256, torch.float16), # 4 GPUs, 4 ranks (if available) -# (4, 400, 512, torch.float16), # 4 GPUs, 4 ranks, larger data -# ] - - -# def get_available_gpu_count(): -# """Get the number of available GPUs.""" -# if not torch.cuda.is_available(): -# return 0 -# return torch.cuda.device_count() - - -# def requires_gpus(min_gpus): -# """Decorator to skip test if insufficient GPUs are available.""" - -# def decorator(func): -# return pytest.mark.skipif( -# get_available_gpu_count() < min_gpus, -# reason=f"Requires at least 
{min_gpus} GPUs, but only {get_available_gpu_count()} available", -# )(func) - -# return decorator - - -# @pytest.mark.parametrize( -# "num_tokens,vector_dim,num_experts,top_k,dtype", -# SINGLE_GPU_PARAMS, -# ) -# def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k, dtype): -# """Test MOE alltoall communication on single GPU.""" -# torch.cuda.set_device(0) -# # Create a random input tensor -# input_tensor1 = torch.randn( -# num_tokens, vector_dim, dtype=dtype, device=torch.device("cuda") -# ) -# input_tensor2 = torch.randn( -# num_tokens, vector_dim, dtype=dtype, device=torch.device("cuda") -# ) - -# token_selected_experts = torch.randperm( -# num_tokens, dtype=torch.int32, device=torch.device("cuda") -# )[:top_k] - -# # num_tokens * topk items * 2 buffers, with vector_dim * 2 bytes each. -# # num_tokens * topk token selected experts, 4 bytes each. -# # Add 1KiB for aux data. -# workspace_size = ( -# (num_tokens * top_k * vector_dim * 2) * 2 + (num_tokens * top_k * 4) + 1024 -# ) - -# mapping = Mapping(rank=0, world_size=1) -# moe_a2a = trtllm_moe_alltoall.MoeAlltoAll( -# mapping, -# num_tokens, -# top_k, -# num_experts, -# workspace_size, -# ) - -# output_tensor1, output_tensor2, token_selected_experts_output = moe_a2a.dispatch( -# token_selected_experts, -# [input_tensor1, input_tensor2, token_selected_experts], -# num_tokens, -# invalid_token_expert_id=0, # Tokens assigned to expert 0 are invalid -# expert_id_payload_index=2, -# ) - -# print( -# output_tensor1.shape, output_tensor2.shape, token_selected_experts_output.shape -# ) - - -# @pytest.mark.parametrize( -# "world_size,input_entry_per_rank,vector_dim,dtype", MULTI_RANK_PARAMS -# ) -# def test_moe_alltoall_multi_rank_single_gpu( -# world_size, input_entry_per_rank, vector_dim, dtype -# ): -# """Test MOE alltoall communication with multiple ranks on single GPU.""" -# torch.cuda.set_device(0) -# max_world_size = 8 -# assert world_size <= max_world_size, ( -# f"should run with world_size at most {max_world_size}" -# ) - -# # SM count is now set up globally in the fixture - -# # Create a random input tensor -# input_tensor = torch.randn( -# input_entry_per_rank * world_size, -# vector_dim, -# dtype=dtype, -# device=torch.device("cuda"), -# ) -# output_tensor = torch.zeros( -# input_entry_per_rank * world_size, -# vector_dim, -# dtype=dtype, -# device=torch.device("cuda"), -# ) -# ref_output_tensor = torch.zeros( -# input_entry_per_rank * world_size, -# vector_dim, -# dtype=dtype, -# device=torch.device("cuda"), -# ) -# target_rank_ids = torch.randint( -# 0, -# world_size, -# (input_entry_per_rank * world_size,), -# dtype=torch.int32, -# device=torch.device("cuda"), -# ) - -# input_tensors_all_ranks = list(torch.split(input_tensor, input_entry_per_rank)) -# target_rank_ids_all_ranks = list(torch.split(target_rank_ids, input_entry_per_rank)) - -# send_ids_all_ranks = [] -# send_counts_all_ranks = [] -# send_cumsum_all_ranks = [] -# send_start_end_all_ranks = [] - -# # each rank do its own local compute to get how to send data to other ranks. 
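# (For reference, the "local compute" mentioned above amounts to building a per-rank send plan. A
# minimal standalone sketch of the same bookkeeping, with illustrative names that are not part of
# the flashinfer API, assuming target_rank_ids holds the destination rank of each local token:)

import torch

def build_send_plan(target_rank_ids: torch.Tensor, world_size: int):
    # Sort local tokens by destination rank; send_order is the permutation applied to local tokens.
    sorted_targets, send_order = torch.sort(target_rank_ids)
    # Tokens per destination rank, then inclusive/exclusive prefix sums giving each rank's slice.
    send_counts = torch.bincount(sorted_targets, minlength=world_size)
    send_cumsum = torch.cumsum(send_counts, dim=0)
    send_offsets = send_cumsum - send_counts
    # Tokens destined for rank r are send_order[send_offsets[r]:send_cumsum[r]].
    return send_order.int(), send_counts.int(), send_cumsum.int(), send_offsets.int()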
-# for rank in range(world_size): -# send_start_end = [] -# local_target_rank_ids = target_rank_ids_all_ranks[rank] -# sorted_local_target_rank_ids, local_send_id = torch.sort(local_target_rank_ids) -# local_send_id = local_send_id.to(torch.int32) -# padded_sorted_local_target_rank_ids = torch.cat( -# ( -# sorted_local_target_rank_ids, -# torch.arange( -# world_size, dtype=torch.int32, device=torch.device("cuda") -# ), -# ) -# ) -# unique_target_rank_ids, local_send_counts = torch.unique( -# padded_sorted_local_target_rank_ids, return_counts=True -# ) -# local_send_counts = local_send_counts.to(torch.int32) -# assert unique_target_rank_ids.numel() == world_size, ( -# "unique_target_rank_ids must be equal to world_size" -# ) -# local_send_counts -= 1 # remove padding -# local_send_cumsum = torch.cumsum(local_send_counts, dim=0).to(torch.int32) -# send_ids_all_ranks.append(local_send_id) -# send_counts_all_ranks.append(local_send_counts) -# send_cumsum_all_ranks.append(local_send_cumsum) -# local_send_cumsum_cpu = local_send_cumsum.cpu().tolist() -# for i in range(len(local_send_cumsum_cpu)): -# send_start_end.append( -# ( -# local_send_cumsum_cpu[i - 1] if i > 0 else 0, -# local_send_cumsum_cpu[i], -# ) -# ) -# send_start_end_all_ranks.append(send_start_end) - -# recv_ids_all_ranks = [] -# recv_cumsum_all_ranks = [] - -# output_tensors_all_ranks = [] - -# total_recv_all_ranks_cpu = [] -# output_indice_offset = 0 - -# output_start_current_rank = 0 -# # each rank do compute based on other ranks' send counts to get how to receive data from other ranks. -# for rank in range(world_size): -# local_recv_counts = torch.zeros( -# world_size, dtype=torch.int32, device=torch.device("cuda") -# ) -# for other_rank in range(world_size): -# local_recv_counts[other_rank] = send_counts_all_ranks[other_rank][rank] -# local_recv_count_pair = local_recv_counts[other_rank].cpu().item() -# send_rank_start_end = send_start_end_all_ranks[other_rank][rank] -# ref_output_tensor[ -# output_indice_offset : output_indice_offset + local_recv_count_pair -# ] = input_tensors_all_ranks[other_rank][ -# send_ids_all_ranks[other_rank][ -# send_rank_start_end[0] : send_rank_start_end[1] -# ] -# ] -# output_indice_offset += local_recv_count_pair -# local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32) -# recv_cumsum_all_ranks.append(local_recv_cumsum) -# total_recv_count = local_recv_cumsum[-1].cpu() -# total_recv_all_ranks_cpu.append(total_recv_count) -# output_tensors_all_ranks.append( -# output_tensor[ -# output_start_current_rank : output_start_current_rank + total_recv_count -# ] -# ) -# output_start_current_rank += total_recv_count -# local_recv_ids = torch.arange( -# total_recv_count, dtype=torch.int32, device=torch.device("cuda") -# ) -# recv_ids_all_ranks.append(local_recv_ids) - -# cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(world_size)] - -# workspace_size = tllm_alltoall.get_moe_commworkspace_size_per_rank(world_size) -# all_workspaces = torch.zeros( -# world_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") -# ) - -# # Synchronize before starting parallel communication -# torch.cuda.synchronize() - -# # do alltoall in parallel -# for rank in range(world_size): -# with torch.cuda.stream(cuda_streams_all_ranks[rank]): -# tllm_alltoall.moe_comm( -# input_tensors_all_ranks[rank], -# send_cumsum_all_ranks[rank], -# send_ids_all_ranks[rank], -# output_tensors_all_ranks[rank], -# recv_cumsum_all_ranks[rank], -# recv_ids_all_ranks[rank], -# all_workspaces, -# rank, 
-# world_size, -# ) -# for rank in range(world_size): -# cuda_streams_all_ranks[rank].synchronize() - -# torch.testing.assert_close(output_tensor, ref_output_tensor, atol=1e-5, rtol=1e-5) - - -# @pytest.mark.parametrize( -# "ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank,use_real_rank_token_count_cumsum", -# PREPARE_INDICES_PARAMS, -# ) -# def test_moe_alltoall_prepare_indices( -# ep_rank, -# ep_size, -# expert_count, -# top_k, -# max_token_count_per_rank, -# use_real_rank_token_count_cumsum, -# ): -# """Test MOE alltoall prepare indices functionality.""" -# torch.cuda.set_device(0) - -# def generate_references(): -# rank_token_count = max_token_count_per_rank -# if use_real_rank_token_count_cumsum: -# # Make sure we have at least 1 token in each rank except last rank -# rank_token_counts = [ -# max(1, torch.randint(1, max_token_count_per_rank + 1, (1,)).item()) -# for _ in range(ep_size - 1) -# ] -# rank_token_counts.append( -# max_token_count_per_rank -# ) # last rank has max tokens -# real_rank_token_count_cumsum = ( -# torch.tensor( -# rank_token_counts, dtype=torch.int32, device=torch.device("cuda") -# ) -# .cumsum(dim=0) -# .to(torch.int32) -# ) -# rank_token_count = rank_token_counts[ep_rank] -# else: -# real_rank_token_count_cumsum = None - -# # Generate target rank ids for this rank -# target_rank_ids = torch.randint( -# 0, -# ep_size, -# (rank_token_count, top_k), -# dtype=torch.int32, -# device=torch.device("cuda"), -# ) - -# if not use_real_rank_token_count_cumsum: -# gathered_target_rank_ids = torch.zeros( -# ep_size * max_token_count_per_rank, -# top_k, -# dtype=torch.int32, -# device=torch.device("cuda"), -# ) -# gathered_target_rank_ids[ -# ep_rank * max_token_count_per_rank : ep_rank * max_token_count_per_rank -# + rank_token_count -# ] = target_rank_ids -# else: -# total_tokens = real_rank_token_count_cumsum[-1].item() -# gathered_target_rank_ids = torch.zeros( -# total_tokens, top_k, dtype=torch.int32, device=torch.device("cuda") -# ) -# start_pos = ( -# 0 if ep_rank == 0 else real_rank_token_count_cumsum[ep_rank - 1].item() -# ) -# gathered_target_rank_ids[start_pos : start_pos + rank_token_count] = ( -# target_rank_ids -# ) - -# return gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids - -# gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids = ( -# generate_references() -# ) - -# ( -# local_gather_indices, -# send_rank_count_cumsum, -# send_rank_local_indices, -# recv_rank_count_cumsum, -# recv_rank_local_indices, -# backward_recv_rank_local_indices, -# ) = tllm_alltoall.moe_comm_prepare_indices( -# gathered_target_rank_ids, -# real_rank_token_count_cumsum, -# max_token_count_per_rank, -# expert_count, -# top_k, -# ep_rank, -# ep_size, -# ) - -# # Validate shapes -# assert local_gather_indices.shape[0] <= max_token_count_per_rank * ep_size -# assert send_rank_count_cumsum.shape[0] == ep_size -# assert recv_rank_count_cumsum.shape[0] == ep_size -# assert send_rank_local_indices.shape[0] <= max_token_count_per_rank * max( -# ep_size, top_k -# ) -# assert recv_rank_local_indices.shape[0] <= max_token_count_per_rank * ep_size -# assert backward_recv_rank_local_indices.shape[0] <= max_token_count_per_rank * max( -# ep_size, top_k -# ) - -# # Basic validation - cumulative sums should be non-decreasing -# assert torch.all(send_rank_count_cumsum[1:] >= send_rank_count_cumsum[:-1]) -# assert torch.all(recv_rank_count_cumsum[1:] >= recv_rank_count_cumsum[:-1]) - - -# @pytest.mark.parametrize( -# 
"ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank", LOCAL_GATHER_PARAMS -# ) -# def test_moe_local_gather( -# ep_rank, -# ep_size, -# expert_count, -# top_k, -# max_token_count_per_rank, -# ): -# """Test MOE local gather functionality.""" -# torch.cuda.set_device(0) - -# # Generate test data using the original method -# rank_token_count_cumsum = torch.randint( -# 0, -# max_token_count_per_rank + 1, -# (ep_size,), -# dtype=torch.int32, -# device=torch.device("cuda"), -# ) -# rank_token_count_cumsum = torch.cumsum(rank_token_count_cumsum, dim=0).to( -# torch.int32 -# ) -# local_token_count = rank_token_count_cumsum[ep_size - 1].cpu().item() -# local_max_token_count = max_token_count_per_rank * ep_size -# local_gather_indices = torch.randint( -# 0, -# max_token_count_per_rank * ep_size, -# (local_max_token_count,), -# dtype=torch.int32, -# device=torch.device("cuda"), -# ) - -# gathered_expert_ids = torch.randint( -# 0, -# expert_count, -# (max_token_count_per_rank * ep_size, top_k), -# dtype=torch.int32, -# device=torch.device("cuda"), -# ) -# gathered_scales = torch.rand( -# (max_token_count_per_rank * ep_size, top_k), -# dtype=torch.float32, -# device=torch.device("cuda"), -# ) - -# ref_local_expert_ids = torch.zeros( -# local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda") -# ) -# ref_local_scales = torch.zeros( -# local_max_token_count, -# top_k, -# dtype=torch.float32, -# device=torch.device("cuda"), -# ) - -# # compute reference -# ref_local_expert_ids += expert_count -# valid_local_gather_indices = local_gather_indices[:local_token_count] -# ref_local_expert_ids[:local_token_count] = gathered_expert_ids[ -# valid_local_gather_indices -# ] -# ref_local_scales[:local_token_count] = gathered_scales[valid_local_gather_indices] - -# local_expert_ids = torch.empty( -# local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda") -# ) -# local_scales = torch.empty( -# local_max_token_count, -# top_k, -# dtype=torch.float32, -# device=torch.device("cuda"), -# ) - -# tllm_alltoall.moe_local_gather( -# rank_token_count_cumsum, -# local_gather_indices, -# gathered_expert_ids, -# gathered_scales, -# local_expert_ids, -# local_scales, -# max_token_count_per_rank, -# expert_count, -# top_k, -# ep_rank, -# ep_size, -# ) - -# assert torch.equal(local_expert_ids, ref_local_expert_ids) -# assert torch.equal(local_scales, ref_local_scales) - - -# @pytest.mark.parametrize( -# "ep_rank, ep_size, expert_count, slot_count, top_k, max_token_count_per_rank", -# [ -# (0, 2, 16, 20, 8, 512), -# (0, 2, 16, 16, 3, 300), -# (0, 4, 20, 24, 8, 4000), -# (0, 8, 96, 96, 8, 1000), -# (3, 8, 128, 128, 8, 1000), -# (3, 8, 128, 144, 8, 1), -# (0, 4, 72, 80, 4, 2256), -# (0, 4, 72, 80, 6, 3333), -# # Hang with stream count > 8 -# # (0, 9, 90, 8, 100), -# ], -# ) -# def test_moe_alltoall_prepare( -# ep_rank: int, -# ep_size: int, -# expert_count: int, -# slot_count: int, -# top_k: int, -# max_token_count_per_rank: int, -# ): -# torch.cuda.set_device(0) - -# cpu_expert_ids_all_ranks_lists = [] -# cpu_token_count_lists = [] -# cpu_scales_all_ranks_lists = [] -# for _ in range(ep_size): -# token_count = torch.randint( -# max_token_count_per_rank // 2, -# max_token_count_per_rank + 1, -# (1,), -# dtype=torch.int32, -# device=torch.device("cpu"), -# ) -# token_count = 1 if token_count == 0 else token_count - -# token_count = max_token_count_per_rank - -# cpu_expert_ids_all_ranks_lists.append( -# torch.randint( -# 0, -# slot_count, -# (token_count, top_k), -# dtype=torch.int32, -# 
device=torch.device("cpu"), -# ) -# ) - -# cpu_scales_all_ranks_lists.append( -# torch.zeros( -# token_count, top_k, dtype=torch.float32, device=torch.device("cpu") -# ) -# + 0.5 -# ) - -# cpu_token_count_lists.append(token_count) - -# def compute_target_rank(expert_id): -# ep_per_rank = slot_count // ep_size -# return expert_id // ep_per_rank - -# def generate_references(): -# ref_prepared_local_expert_ids = [] -# ref_prepared_local_scales = [] -# ref_local_send_rank_count_cumsum = [0] * ep_size -# ref_local_recv_rank_count_cumsum = [0] * ep_size -# ref_local_recv_rank_indices = [] - -# local_token_count = cpu_token_count_lists[ep_rank] -# send_token_count_to_ranks = [0] * ep_size - -# # send part -# for token_id in range(local_token_count): -# target_set = set() -# for pos in range(top_k): -# expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) -# target_rank_id = compute_target_rank(expert_id) -# target_set.add(target_rank_id) - -# for target_rank_id in target_set: -# send_token_count_to_ranks[target_rank_id] += 1 - -# total_send_token_count = 0 -# for rank in range(ep_size): -# # print(f'rank: {rank}, send_token_count_to_ranks[rank]: {send_token_count_to_ranks[rank]}') -# base = ref_local_send_rank_count_cumsum[rank - 1] if rank > 0 else 0 -# ref_local_send_rank_count_cumsum[rank] = ( -# send_token_count_to_ranks[rank] + base -# ) -# total_send_token_count += send_token_count_to_ranks[rank] - -# ref_local_backward_send_rank_indices = [0] * (total_send_token_count) -# ref_local_send_rank_indices = [0] * (total_send_token_count) - -# current_send_token_ids = [0] * ep_size -# for token_id in range(local_token_count): -# target_set = set() -# for pos in range(top_k): -# expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos]) -# target_rank_id = compute_target_rank(expert_id) -# if target_rank_id not in target_set: -# cumsum_before = ( -# 0 -# if target_rank_id == 0 -# else ref_local_send_rank_count_cumsum[target_rank_id - 1] -# ) -# send_index = cumsum_before + current_send_token_ids[target_rank_id] -# ref_local_send_rank_indices[send_index] = token_id -# ref_local_backward_send_rank_indices[send_index] = ( -# token_id * top_k + pos -# ) -# current_send_token_ids[target_rank_id] += 1 -# target_set.add(target_rank_id) - -# # receive part -# total_recv_token_count = 0 -# for rank in range(ep_size): -# token_count = cpu_token_count_lists[rank] -# current_recv_token_count = 0 -# for token_id in range(token_count): -# token_is_received = False -# for pos in range(top_k): -# expert_id = int(cpu_expert_ids_all_ranks_lists[rank][token_id][pos]) -# sf = cpu_scales_all_ranks_lists[rank][token_id][pos] -# target_rank_id = compute_target_rank(expert_id) -# if target_rank_id == ep_rank: -# if not token_is_received: -# token_is_received = True -# ref_prepared_local_expert_ids.append([slot_count] * top_k) -# ref_prepared_local_scales.append([0.0] * top_k) -# ref_prepared_local_expert_ids[-1][pos] = expert_id -# ref_prepared_local_scales[-1][pos] = sf -# if token_is_received: -# ref_local_recv_rank_indices.append(total_recv_token_count) -# total_recv_token_count += 1 -# current_recv_token_count += 1 -# ref_local_recv_rank_count_cumsum[rank] = ( -# current_recv_token_count -# if rank == 0 -# else ref_local_recv_rank_count_cumsum[rank - 1] -# + current_recv_token_count -# ) - -# return ( -# ref_prepared_local_expert_ids, -# ref_prepared_local_scales, -# ref_local_send_rank_count_cumsum, -# ref_local_send_rank_indices, -# ref_local_recv_rank_count_cumsum, -# 
ref_local_recv_rank_indices, -# ref_local_backward_send_rank_indices, -# total_recv_token_count, -# ) - -# ( -# ref_prepared_local_expert_ids, -# ref_prepared_local_scales, -# ref_local_send_rank_count_cumsum, -# ref_local_send_rank_indices, -# ref_local_recv_rank_count_cumsum, -# ref_local_recv_rank_indices, -# ref_local_backward_send_rank_indices, -# total_recv_token_count, -# ) = generate_references() - -# cpu_experter_count_lists = [] -# for rank in range(ep_size): -# local_expert_count = [] -# for i in range(expert_count): -# local_expert_count.append(rank * expert_count + i) -# cpu_experter_count_lists.append(torch.IntTensor(local_expert_count)) - -# # expert_ids_all_ranks = torch.tensor(cpu_expert_ids_all_ranks_lists).cuda() -# expert_ids_all_ranks = [ -# cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size) -# ] -# # scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda() -# scales_all_ranks = [cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size)] - -# experter_count_lists = [cpu_experter_count_lists[i].cuda() for i in range(ep_size)] - -# cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(ep_size)] - -# workspace_size = tllm_alltoall.get_moe_prepare_workspace_size_per_rank(ep_size) - -# all_workspaces = torch.zeros( -# ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") -# ) - -# stream = torch.cuda.Stream() -# cur_stream = torch.cuda.current_stream() -# stream.wait_stream(cur_stream) -# with torch.cuda.stream(stream): -# tllm_alltoall.moe_prepare( -# expert_ids_all_ranks[0], -# scales_all_ranks[0], -# experter_count_lists[0], -# all_workspaces, -# max_token_count_per_rank, -# 0, -# 1, -# expert_count, -# slot_count, -# top_k, -# ) -# cur_stream.wait_stream(stream) - -# # Make torch alloc tensor to avoid cuda sync -# prepared_local_experts = [] -# prepared_local_scales = [] -# local_send_rank_count_cumsum = [] -# local_send_rank_indices = [] -# local_recv_rank_count_cumsum = [] -# local_recv_rank_indices = [] -# backward_local_recv_rank_indices = [] -# for _ in range(ep_size): -# prepared_local_experts.append( -# torch.empty( -# max_token_count_per_rank * ep_size, -# top_k, -# dtype=torch.int32, -# device=torch.device("cuda"), -# ) -# ) -# prepared_local_scales.append( -# torch.empty( -# max_token_count_per_rank * ep_size, -# top_k, -# dtype=torch.float32, -# device=torch.device("cuda"), -# ) -# ) -# local_send_rank_count_cumsum.append( -# torch.empty(ep_size, dtype=torch.int32, device=torch.device("cuda")) -# ) -# local_send_rank_indices.append( -# torch.empty( -# max_token_count_per_rank * ep_size, -# dtype=torch.int32, -# device=torch.device("cuda"), -# ) -# ) -# local_recv_rank_count_cumsum.append( -# torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) -# ) -# local_recv_rank_indices.append( -# torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) -# ) -# backward_local_recv_rank_indices.append( -# torch.empty(0, dtype=torch.int32, device=torch.device("cuda")) -# ) - -# prepared_local_experts = None -# prepared_local_scales = None -# local_send_rank_count_cumsum = None -# local_send_rank_indices = None -# local_recv_rank_count_cumsum = None -# local_recv_rank_indices = None -# backward_local_recv_rank_indices = None - -# # reset the workspace -# all_workspaces = torch.zeros( -# ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda") -# ) - -# # do prepare in parallel -# cur_stream = torch.cuda.current_stream() -# for rank in range(ep_size): -# s = cuda_streams_all_ranks[rank] 
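# (Note on the emulation pattern here, which the new A2A tests below reuse: each emulated rank is
# launched on its own CUDA stream so every rank's kernel is enqueued before any of them finishes.
# The kernels rendezvous through synchronization flags in the shared workspace, so issuing them
# back-to-back on a single stream would serialize the ranks and can hang once a kernel waits on a
# peer that has not yet been launched.)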
-# s.wait_stream(cur_stream) -# with torch.cuda.stream(s): -# if rank == ep_rank: -# ( -# prepared_local_experts, -# prepared_local_scales, -# local_send_rank_count_cumsum, -# local_send_rank_indices, -# local_recv_rank_count_cumsum, -# local_recv_rank_indices, -# backward_local_recv_rank_indices, -# gathered_expert_statics, -# ) = tllm_alltoall.moe_prepare( -# expert_ids_all_ranks[rank], -# scales_all_ranks[rank], -# experter_count_lists[rank], -# all_workspaces, -# max_token_count_per_rank, -# rank, -# ep_size, -# expert_count, -# slot_count, -# top_k, -# ) -# else: -# tllm_alltoall.moe_prepare( -# expert_ids_all_ranks[rank], -# scales_all_ranks[rank], -# experter_count_lists[rank], -# all_workspaces, -# max_token_count_per_rank, -# rank, -# ep_size, -# expert_count, -# slot_count, -# top_k, -# ) -# for rank in range(ep_size): -# cuda_streams_all_ranks[rank].synchronize() - -# prepared_local_experts_cpu = prepared_local_experts[:total_recv_token_count].cpu() -# prepared_local_scales_cpu = prepared_local_scales[:total_recv_token_count].cpu() -# for i in range(total_recv_token_count): -# for j in range(top_k): -# expert_id = int(prepared_local_experts_cpu[i][j]) -# assert expert_id >= 0 and expert_id <= slot_count -# if expert_id < slot_count: -# assert compute_target_rank(expert_id) == ep_rank -# scale = float(prepared_local_scales_cpu[i][j]) -# assert scale > 1e-6 - -# gathered_expert_statics_cpu = gathered_expert_statics.cpu() -# for rank in range(ep_size): -# for i in range(expert_count): -# assert int(gathered_expert_statics_cpu[rank][i]) == rank * expert_count + i - -# ref_local_send_rank_count_cumsum = torch.IntTensor(ref_local_send_rank_count_cumsum) -# assert torch.equal( -# local_send_rank_count_cumsum.cpu(), ref_local_send_rank_count_cumsum -# ) - -# local_send_rank_indices = local_send_rank_indices.cpu() -# backward_local_recv_rank_indices = backward_local_recv_rank_indices.cpu() -# for i in range(ep_size): -# base = 0 if i == 0 else ref_local_send_rank_count_cumsum[i - 1] -# for j in range(base, ref_local_send_rank_count_cumsum[i]): -# token_id = local_send_rank_indices[j] -# lane_id = backward_local_recv_rank_indices[j] - token_id * top_k -# expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][lane_id]) -# assert compute_target_rank(expert_id) == i - -# ref_local_recv_rank_count_cumsum = torch.IntTensor(ref_local_recv_rank_count_cumsum) -# assert torch.equal( -# local_recv_rank_count_cumsum[: ref_local_recv_rank_count_cumsum.size(0)].cpu(), -# ref_local_recv_rank_count_cumsum, -# ) - -# ref_local_recv_rank_indices = torch.IntTensor(ref_local_recv_rank_indices) -# assert torch.equal( -# local_recv_rank_indices[: ref_local_recv_rank_indices.size(0)].cpu(), -# ref_local_recv_rank_indices, -# ) - - -# if __name__ == "__main__": -# pytest.main([__file__, "-v"]) +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import pytest +import torch + +from flashinfer.comm.mapping import Mapping + +import flashinfer.comm.trtllm_moe_alltoall as trtllm_moe_alltoall + + +@pytest.fixture(autouse=True, scope="session") +def setup_test_environment(): + """Set up test environment and warm up JIT compilation.""" + torch.manual_seed(0xD5) + yield + + +# Single GPU test parameters +SINGLE_GPU_PARAMS = [ + (902, 32768, 256, 8, torch.float16), # Large data, float16 + (101, 288, 128, 4, torch.float16), # Medium data, float16 + (902, 7168, 256, 8, torch.bfloat16), # Large data, bfloat16 + (101, 288, 128, 4, torch.bfloat16), # Medium data, bfloat16 + (10, 8, 8, 2, torch.bfloat16), # Small data, bfloat16 +] + +MULTI_RANK_PARAMS = [ + (2, 5, 8, torch.float16), # Small input, 2 ranks + (4, 901, 32768, torch.bfloat16), # Large input, 4 ranks + (8, 16384, 128, torch.float16), # Many small vectors, 8 ranks +] + +PREPARE_INDICES_PARAMS = [ + (0, 8, 256, 4, 3, False), # Rank 0, small config + (1, 8, 256, 4, 3, True), # Rank 1, small config with real cumsum + (7, 8, 256, 8, 1025, False), # High rank, medium config + (7, 64, 1024, 32, 1029, True), # High rank, large config with real cumsum +] + +LOCAL_GATHER_PARAMS = [ + (0, 8, 256, 4, 3), # Rank 0, small config + (7, 8, 256, 8, 32), # High rank, medium config + (7, 64, 1024, 32, 1029), # High rank, large config +] + + +# Real cross-GPU communication test parameters +CROSS_GPU_PARAMS = [ + (2, 100, 256, torch.float16), # 2 GPUs, 2 ranks + (2, 300, 512, torch.bfloat16), # 2 GPUs, 2 ranks, larger data + (4, 150, 256, torch.float16), # 4 GPUs, 4 ranks (if available) + (4, 400, 512, torch.float16), # 4 GPUs, 4 ranks, larger data +] + + +def get_available_gpu_count(): + """Get the number of available GPUs.""" + if not torch.cuda.is_available(): + return 0 + return torch.cuda.device_count() + + +def requires_gpus(min_gpus): + """Decorator to skip test if insufficient GPUs are available.""" + + def decorator(func): + return pytest.mark.skipif( + get_available_gpu_count() < min_gpus, + reason=f"Requires at least {min_gpus} GPUs, but only {get_available_gpu_count()} available", + )(func) + + return decorator + + +@pytest.mark.parametrize( + "num_tokens,vector_dim,num_experts,top_k,dtype", + SINGLE_GPU_PARAMS, +) +def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k, dtype): + """Test MOE alltoall communication on single GPU.""" + torch.cuda.set_device(0) + # Create a random input tensor + input_tensor1 = torch.randn( + num_tokens, vector_dim, dtype=dtype, device=torch.device("cuda") + ) + input_tensor2 = torch.randn( + num_tokens, vector_dim * 2, dtype=dtype, device=torch.device("cuda") + ) + + token_selected_experts = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=torch.device("cuda") + ) + for i in range(num_tokens): + # Include one extra expert to represent invalid expert IDs + token_selected_experts[i] = torch.randperm( + num_experts, dtype=torch.int32, device=torch.device("cuda") + )[:top_k] + token_selected_experts = token_selected_experts.contiguous() + + workspace_size = trtllm_moe_alltoall.moe_a2a_get_workspace_size_per_rank( + 1, + num_tokens, + input_tensor1.numel() * dtype.itemsize + + input_tensor2.numel() * dtype.itemsize + + token_selected_experts.numel() * torch.int32.itemsize, + ) + mapping = Mapping(rank=0, world_size=1) + moe_a2a = trtllm_moe_alltoall.MoeAlltoAll( + mapping, + num_tokens, + top_k, + num_experts, + workspace_size_per_rank=workspace_size, + ) + + output_tensor1, output_tensor2, token_selected_experts_output = 
moe_a2a.dispatch( + token_selected_experts, + [input_tensor1, input_tensor2, token_selected_experts], + num_tokens, + invalid_token_expert_id=-3, # Tokens assigned to invalid expert are set to -3 + expert_id_payload_index=2, + ) + + # Sort to undo the shuffling that happens in the dispatch kernel. + input_tensor1, _ = torch.sort(input_tensor1, dim=0) + input_tensor2, _ = torch.sort(input_tensor2, dim=0) + token_selected_experts, _ = torch.sort(token_selected_experts, dim=0) + output_tensor1, _ = torch.sort(output_tensor1[0], dim=0) + output_tensor2, _ = torch.sort(output_tensor2[0], dim=0) + token_selected_experts_output, _ = torch.sort( + token_selected_experts_output[0], dim=0 + ) + + torch.testing.assert_close(output_tensor1, input_tensor1, atol=0, rtol=0) + torch.testing.assert_close(output_tensor2, input_tensor2, atol=0, rtol=0) + torch.testing.assert_close( + token_selected_experts_output, token_selected_experts, atol=0, rtol=0 + ) + + moe_a2a._reset_workspace() + + +@pytest.mark.parametrize("world_size,num_tokens,vector_dim,dtype", MULTI_RANK_PARAMS) +def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim, dtype): + """Test MOE alltoall communication with multiple ranks on single GPU.""" + torch.cuda.set_device(0) + max_world_size = 8 + assert world_size <= max_world_size, ( + f"should run with world_size at most {max_world_size}" + ) + + # SM count is now set up globally in the fixture + + # Create a random input tensor + input_tensors = [ + torch.randn( + num_tokens * world_size, + vector_dim * (i + 1), + dtype=dtype, + device=torch.device("cuda"), + ) + for i in range(2) + ] + + token_selected_experts = torch.randint( + 0, + world_size, + (num_tokens * world_size, 1), + dtype=torch.int32, + device=torch.device("cuda"), + ) + + payloads = input_tensors + [token_selected_experts] + total_payload_size_per_element = [x[0].numel() * x.itemsize for x in payloads] + total_payload_size_per_element = sum(total_payload_size_per_element) + + workspace_size = trtllm_moe_alltoall.moe_a2a_get_workspace_size_per_rank( + world_size, num_tokens * world_size, total_payload_size_per_element + ) + + all_workspaces = torch.zeros( + world_size, workspace_size, dtype=torch.uint8, device=torch.device("cuda") + ) + + # Must be done before the synchronization so the state is cleared + metainfo = [] + for rank in range(world_size): + metainfo.append( + trtllm_moe_alltoall.moe_a2a_initialize( + all_workspaces, + rank, + world_size, + num_tokens * world_size, + ) + ) + + # Synchronize before starting parallel communication + torch.cuda.synchronize() + + output_tensors = [] + # do alltoall in parallel + cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(world_size)] + for rank in range(world_size): + with torch.cuda.stream(cuda_streams_all_ranks[rank]): + rank_payloads = [ + x[rank * num_tokens : (rank + 1) * num_tokens] for x in payloads + ] + output_tensors.append( + trtllm_moe_alltoall.moe_a2a_dispatch( + rank_payloads[2], + rank_payloads, + all_workspaces, + metainfo[rank], + num_tokens, + ep_rank=rank, + ep_size=world_size, + top_k=1, + num_experts=world_size, + )[0] + ) + + for rank in range(world_size): + cuda_streams_all_ranks[rank].synchronize() + + torch.cuda.synchronize() + + torch.set_printoptions(threshold=float("inf")) + print( + f"all_workspaces: {all_workspaces.shape} {all_workspaces.flatten().view(torch.uint8)[1152:1632].view(torch.bfloat16)}" + ) + + for rank in range(world_size): + print(f"output_tensors[{rank}]: {output_tensors[rank]}") + + for rank in 
range(world_size): + # Get the indices where token_selected_experts == rank + print( + f"token_selected_experts: {token_selected_experts.shape} {token_selected_experts}" + ) + token_selected_experts_indices = ( + token_selected_experts.flatten() == rank + ).nonzero(as_tuple=False) + + for actual, ref in zip(output_tensors[rank], payloads, strict=True): + print(f"token_selected_experts_indices: {token_selected_experts_indices}") + print(f"actual raw: {actual.shape} {actual}") + actual = actual[rank][: len(token_selected_experts_indices)] + print(f"actual filtered: {actual.shape} {actual}") + ref = ref[token_selected_experts_indices].squeeze() + actual, _ = torch.sort(actual, dim=0) + ref, _ = torch.sort(ref, dim=0) + print(f"actual: {actual}") + print(f"ref: {ref}") + torch.testing.assert_close(actual, ref, atol=0, rtol=0) + + +# TODO Add a combine test + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 49e882f3aaf55c6ea3e9df191e9b2efcadfd3b14 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 11:03:33 +1300 Subject: [PATCH 17/25] Fix existing dispatch tests --- flashinfer/comm/trtllm_moe_alltoall.py | 8 +- tests/comm/test_trtllm_moe_alltoall.py | 139 +++++++++---------------- 2 files changed, 55 insertions(+), 92 deletions(-) diff --git a/flashinfer/comm/trtllm_moe_alltoall.py b/flashinfer/comm/trtllm_moe_alltoall.py index 3bedc57e01..9a76cf495b 100644 --- a/flashinfer/comm/trtllm_moe_alltoall.py +++ b/flashinfer/comm/trtllm_moe_alltoall.py @@ -175,10 +175,10 @@ def moe_a2a_get_metainfo_index_pairs(): def moe_a2a_get_workspace_size_per_rank( ep_size: int, max_num_tokens: int, - payload_size_per_element: int, + payload_size_per_token: int, ): return module.moe_a2a_get_workspace_size_per_rank( - ep_size, max_num_tokens, payload_size_per_element + ep_size, max_num_tokens, payload_size_per_token ) return SimpleNamespace( @@ -317,10 +317,10 @@ def moe_a2a_sanitize_expert_ids( def moe_a2a_get_workspace_size_per_rank( ep_size: int, max_num_tokens: int, - payload_size_per_element: int, + payload_size_per_token: int, ): return get_mnnvl_a2a_module().moe_a2a_get_workspace_size_per_rank( - ep_size, max_num_tokens, payload_size_per_element + ep_size, max_num_tokens, payload_size_per_token ) diff --git a/tests/comm/test_trtllm_moe_alltoall.py b/tests/comm/test_trtllm_moe_alltoall.py index 87b341a788..b5e65d95f1 100644 --- a/tests/comm/test_trtllm_moe_alltoall.py +++ b/tests/comm/test_trtllm_moe_alltoall.py @@ -31,39 +31,15 @@ def setup_test_environment(): # Single GPU test parameters SINGLE_GPU_PARAMS = [ - (902, 32768, 256, 8, torch.float16), # Large data, float16 - (101, 288, 128, 4, torch.float16), # Medium data, float16 - (902, 7168, 256, 8, torch.bfloat16), # Large data, bfloat16 - (101, 288, 128, 4, torch.bfloat16), # Medium data, bfloat16 - (10, 8, 8, 2, torch.bfloat16), # Small data, bfloat16 + (902, 7168, 256, 8), # Large data + (10, 288, 128, 4), # Medium data + (10, 8, 8, 2), # Small data ] MULTI_RANK_PARAMS = [ - (2, 5, 8, torch.float16), # Small input, 2 ranks - (4, 901, 32768, torch.bfloat16), # Large input, 4 ranks - (8, 16384, 128, torch.float16), # Many small vectors, 8 ranks -] - -PREPARE_INDICES_PARAMS = [ - (0, 8, 256, 4, 3, False), # Rank 0, small config - (1, 8, 256, 4, 3, True), # Rank 1, small config with real cumsum - (7, 8, 256, 8, 1025, False), # High rank, medium config - (7, 64, 1024, 32, 1029, True), # High rank, large config with real cumsum -] - -LOCAL_GATHER_PARAMS = [ - (0, 8, 256, 4, 3), # Rank 0, 
small config - (7, 8, 256, 8, 32), # High rank, medium config - (7, 64, 1024, 32, 1029), # High rank, large config -] - - -# Real cross-GPU communication test parameters -CROSS_GPU_PARAMS = [ - (2, 100, 256, torch.float16), # 2 GPUs, 2 ranks - (2, 300, 512, torch.bfloat16), # 2 GPUs, 2 ranks, larger data - (4, 150, 256, torch.float16), # 4 GPUs, 4 ranks (if available) - (4, 400, 512, torch.float16), # 4 GPUs, 4 ranks, larger data + (2, 5, 8), # Small input, 2 ranks + (4, 901, 32768), # Large input, 4 ranks + (8, 16384, 128), # Many small vectors, 8 ranks ] @@ -86,20 +62,34 @@ def decorator(func): return decorator +def make_payload(num_tokens, vector_dim, dtype): + if dtype == torch.uint8 or dtype == torch.int32: + return torch.randint( + torch.iinfo(dtype).min, + torch.iinfo(dtype).max, + (num_tokens, vector_dim), + dtype=dtype, + device=torch.device("cuda"), + ) + else: + return torch.randn( + num_tokens, vector_dim, dtype=dtype, device=torch.device("cuda") + ) + + @pytest.mark.parametrize( - "num_tokens,vector_dim,num_experts,top_k,dtype", + "num_tokens,vector_dim,num_experts,top_k", SINGLE_GPU_PARAMS, ) -def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k, dtype): +def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k): """Test MOE alltoall communication on single GPU.""" torch.cuda.set_device(0) # Create a random input tensor - input_tensor1 = torch.randn( - num_tokens, vector_dim, dtype=dtype, device=torch.device("cuda") - ) - input_tensor2 = torch.randn( - num_tokens, vector_dim * 2, dtype=dtype, device=torch.device("cuda") - ) + dtypes = [torch.float16, torch.bfloat16, torch.int32, torch.uint8] + input_tensors = [ + make_payload(num_tokens, vector_dim * (i + 1), dtype) + for i, dtype in enumerate(dtypes) + ] token_selected_experts = torch.empty( num_tokens, top_k, dtype=torch.int32, device=torch.device("cuda") @@ -111,12 +101,12 @@ def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k, dty )[:top_k] token_selected_experts = token_selected_experts.contiguous() + payload_size_per_token = sum([x[0].numel() * x.itemsize for x in input_tensors]) + workspace_size = trtllm_moe_alltoall.moe_a2a_get_workspace_size_per_rank( 1, num_tokens, - input_tensor1.numel() * dtype.itemsize - + input_tensor2.numel() * dtype.itemsize - + token_selected_experts.numel() * torch.int32.itemsize, + payload_size_per_token, ) mapping = Mapping(rank=0, world_size=1) moe_a2a = trtllm_moe_alltoall.MoeAlltoAll( @@ -127,35 +117,25 @@ def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k, dty workspace_size_per_rank=workspace_size, ) - output_tensor1, output_tensor2, token_selected_experts_output = moe_a2a.dispatch( + output_tensors = moe_a2a.dispatch( token_selected_experts, - [input_tensor1, input_tensor2, token_selected_experts], + input_tensors, num_tokens, invalid_token_expert_id=-3, # Tokens assigned to invalid expert are set to -3 expert_id_payload_index=2, ) # Sort to undo the shuffling that happens in the dispatch kernel. 
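    # (torch.sort(..., dim=0) below sorts every column independently, so the comparison is
    # insensitive to how dispatch reordered the rows; with ep_size == 1 each input token arrives
    # back on this rank exactly once, so the sorted tensors must match element-wise.)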
- input_tensor1, _ = torch.sort(input_tensor1, dim=0) - input_tensor2, _ = torch.sort(input_tensor2, dim=0) - token_selected_experts, _ = torch.sort(token_selected_experts, dim=0) - output_tensor1, _ = torch.sort(output_tensor1[0], dim=0) - output_tensor2, _ = torch.sort(output_tensor2[0], dim=0) - token_selected_experts_output, _ = torch.sort( - token_selected_experts_output[0], dim=0 - ) - - torch.testing.assert_close(output_tensor1, input_tensor1, atol=0, rtol=0) - torch.testing.assert_close(output_tensor2, input_tensor2, atol=0, rtol=0) - torch.testing.assert_close( - token_selected_experts_output, token_selected_experts, atol=0, rtol=0 - ) + for input_tensor, output_tensor in zip(input_tensors, output_tensors, strict=True): + input_tensor, _ = torch.sort(input_tensor, dim=0) + output_tensor, _ = torch.sort(output_tensor.flatten(end_dim=1), dim=0) + torch.testing.assert_close(output_tensor, input_tensor, atol=0, rtol=0) moe_a2a._reset_workspace() -@pytest.mark.parametrize("world_size,num_tokens,vector_dim,dtype", MULTI_RANK_PARAMS) -def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim, dtype): +@pytest.mark.parametrize("world_size,num_tokens,vector_dim", MULTI_RANK_PARAMS) +def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): """Test MOE alltoall communication with multiple ranks on single GPU.""" torch.cuda.set_device(0) max_world_size = 8 @@ -163,17 +143,11 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim, f"should run with world_size at most {max_world_size}" ) - # SM count is now set up globally in the fixture - + dtypes = [torch.float16, torch.bfloat16, torch.int32, torch.uint8] # Create a random input tensor input_tensors = [ - torch.randn( - num_tokens * world_size, - vector_dim * (i + 1), - dtype=dtype, - device=torch.device("cuda"), - ) - for i in range(2) + make_payload(num_tokens * world_size, vector_dim * (i + 1), dtype) + for i, dtype in enumerate(dtypes) ] token_selected_experts = torch.randint( @@ -184,7 +158,7 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim, device=torch.device("cuda"), ) - payloads = input_tensors + [token_selected_experts] + payloads = input_tensors total_payload_size_per_element = [x[0].numel() * x.itemsize for x in payloads] total_payload_size_per_element = sum(total_payload_size_per_element) @@ -219,9 +193,12 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim, rank_payloads = [ x[rank * num_tokens : (rank + 1) * num_tokens] for x in payloads ] + rank_token_selected_experts = token_selected_experts[ + rank * num_tokens : (rank + 1) * num_tokens + ] output_tensors.append( trtllm_moe_alltoall.moe_a2a_dispatch( - rank_payloads[2], + rank_token_selected_experts, rank_payloads, all_workspaces, metainfo[rank], @@ -238,33 +215,19 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim, torch.cuda.synchronize() - torch.set_printoptions(threshold=float("inf")) - print( - f"all_workspaces: {all_workspaces.shape} {all_workspaces.flatten().view(torch.uint8)[1152:1632].view(torch.bfloat16)}" - ) - - for rank in range(world_size): - print(f"output_tensors[{rank}]: {output_tensors[rank]}") - for rank in range(world_size): # Get the indices where token_selected_experts == rank - print( - f"token_selected_experts: {token_selected_experts.shape} {token_selected_experts}" - ) token_selected_experts_indices = ( token_selected_experts.flatten() == rank ).nonzero(as_tuple=False) - for actual, ref 
in zip(output_tensors[rank], payloads, strict=True): - print(f"token_selected_experts_indices: {token_selected_experts_indices}") - print(f"actual raw: {actual.shape} {actual}") - actual = actual[rank][: len(token_selected_experts_indices)] - print(f"actual filtered: {actual.shape} {actual}") + for actual, ref in zip(output_tensors[rank][:-1], payloads[:-1], strict=True): + # Select the tensors that arent all zeros + actual = actual.flatten(end_dim=1) + actual = actual[actual.any(dim=1)] ref = ref[token_selected_experts_indices].squeeze() actual, _ = torch.sort(actual, dim=0) ref, _ = torch.sort(ref, dim=0) - print(f"actual: {actual}") - print(f"ref: {ref}") torch.testing.assert_close(actual, ref, atol=0, rtol=0) From f98530c88b4d1e84199aaa463027e1f1583a935b Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:34:31 +1300 Subject: [PATCH 18/25] Tests for sanitize and combine --- csrc/trtllm_moe_a2a.cu | 15 +- flashinfer/comm/trtllm_moe_alltoall.py | 25 +- tests/comm/test_trtllm_moe_alltoall.py | 380 ++++++++++++++++++++++--- 3 files changed, 376 insertions(+), 44 deletions(-) diff --git a/csrc/trtllm_moe_a2a.cu b/csrc/trtllm_moe_a2a.cu index 7c0099c370..a87875baf0 100644 --- a/csrc/trtllm_moe_a2a.cu +++ b/csrc/trtllm_moe_a2a.cu @@ -88,12 +88,14 @@ fi_throughput::MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens) } int64_t getMoeA2AWorkspaceSizePerRank(int64_t epSize, int64_t maxNumTokens, - int64_t maxPayloadSizePerElement) { + int64_t totalDispatchSizePerToken, + int64_t combineSizePerToken) { int64_t metadata_size = calculateOffsets(static_cast(epSize), static_cast(maxNumTokens))[fi_throughput::PAYLOAD_DATA_OFFSET_INDEX]; - int64_t payload_size = maxNumTokens * maxPayloadSizePerElement; - return alignOffset(metadata_size + payload_size); + int64_t payload_size = maxNumTokens * totalDispatchSizePerToken; + int64_t combine_size = maxNumTokens * combineSizePerToken; + return alignOffset(metadata_size + payload_size) + alignOffset(combine_size); } Tensor moeA2AInitializeOp(TensorView workspace, int64_t epRank, int64_t epSize, @@ -184,6 +186,10 @@ Tuple, Array, int64_t> moeA2ADispatchOp( static_cast(epSize) * runtimeMaxTokensPerRank * elementsPerToken * elementSize; payloadByteSizes[i] = bytesPerPayload; totalBytesNeeded += bytesPerPayload; + + TVM_FFI_ICHECK(totalBytesNeeded % elementSize == 0) + << "Misaligned payload buffer " << i << " with element size " << elementSize + << ". 
Consider putting ordering payloads by minimum element size"; } auto* workspaceBase = static_cast(workspace.data_ptr()); @@ -310,7 +316,8 @@ Tensor moeA2ACombineOp(TensorView payload, int64_t localNumTokens, TensorView wo if (payloadInWorkspace) { auto* expectedPtr = rankWorkspacePtr + combinePayloadOffset; TVM_FFI_ICHECK(payload.data_ptr() == expectedPtr) - << "payload_in_workspace is True but tensor pointer mismatch"; + << "payload_in_workspace is True but tensor pointer mismatch: " << payload.data_ptr() + << " != " << expectedPtr; } Tensor output = diff --git a/flashinfer/comm/trtllm_moe_alltoall.py b/flashinfer/comm/trtllm_moe_alltoall.py index 9a76cf495b..169c4de622 100644 --- a/flashinfer/comm/trtllm_moe_alltoall.py +++ b/flashinfer/comm/trtllm_moe_alltoall.py @@ -175,10 +175,14 @@ def moe_a2a_get_metainfo_index_pairs(): def moe_a2a_get_workspace_size_per_rank( ep_size: int, max_num_tokens: int, - payload_size_per_token: int, + total_dispatch_payload_size_per_token: int, + combine_payload_size_per_token: int, ): return module.moe_a2a_get_workspace_size_per_rank( - ep_size, max_num_tokens, payload_size_per_token + ep_size, + max_num_tokens, + total_dispatch_payload_size_per_token, + combine_payload_size_per_token, ) return SimpleNamespace( @@ -224,11 +228,12 @@ def moe_a2a_wrap_payload_tensor_in_workspace( Returns: tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor """ - workspace_base = workspace.flatten().view(dtype=torch.uint8) + workspace_base = workspace.view(-1).view(dtype=torch.uint8) + assert slice_end <= workspace.numel(), ( + f"slice_end {slice_end} exceeds workspace size {workspace.numel()}" + ) result = ( - workspace_base[slice_start:slice_end] - .view(leading_shape + [-1]) - .view(dtype=dtype) + workspace_base[slice_start:slice_end].view(dtype=dtype).view(*leading_shape, -1) ) return result @@ -317,10 +322,14 @@ def moe_a2a_sanitize_expert_ids( def moe_a2a_get_workspace_size_per_rank( ep_size: int, max_num_tokens: int, - payload_size_per_token: int, + total_dispatch_payload_size_per_token: int, + combine_payload_size_per_token: int, ): return get_mnnvl_a2a_module().moe_a2a_get_workspace_size_per_rank( - ep_size, max_num_tokens, payload_size_per_token + ep_size, + max_num_tokens, + total_dispatch_payload_size_per_token, + combine_payload_size_per_token, ) diff --git a/tests/comm/test_trtllm_moe_alltoall.py b/tests/comm/test_trtllm_moe_alltoall.py index b5e65d95f1..52e709f590 100644 --- a/tests/comm/test_trtllm_moe_alltoall.py +++ b/tests/comm/test_trtllm_moe_alltoall.py @@ -42,6 +42,20 @@ def setup_test_environment(): (8, 16384, 128), # Many small vectors, 8 ranks ] +SANITIZE_PARAMS = [ + (2, 5), # Few tokens, 2 ranks + (4, 901), # Many tokens, 4 ranks +] + +COMBINE_PARAMS = [ + (2, 5, 8, 2, torch.bfloat16), # Small input, 2 ranks + (4, 901, 32768, 4, torch.bfloat16), # Large input, 4 ranks + (8, 16384, 128, 8, torch.bfloat16), # Many small vectors, 8 ranks + (2, 5, 8, 2, torch.float16), # Small input, 2 ranks + (4, 901, 32768, 4, torch.float16), # Large input, 4 ranks + (8, 16384, 128, 8, torch.float16), # Many small vectors, 8 ranks +] + def get_available_gpu_count(): """Get the number of available GPUs.""" @@ -85,7 +99,8 @@ def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k): """Test MOE alltoall communication on single GPU.""" torch.cuda.set_device(0) # Create a random input tensor - dtypes = [torch.float16, torch.bfloat16, torch.int32, torch.uint8] + dtypes = [torch.bfloat16, torch.float16, torch.int32, torch.uint8] + 
hidden_state_index = 0 input_tensors = [ make_payload(num_tokens, vector_dim * (i + 1), dtype) for i, dtype in enumerate(dtypes) @@ -107,6 +122,7 @@ def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k): 1, num_tokens, payload_size_per_token, + input_tensors[0].shape[-1] * input_tensors[0].itemsize, ) mapping = Mapping(rank=0, world_size=1) moe_a2a = trtllm_moe_alltoall.MoeAlltoAll( @@ -131,41 +147,58 @@ def test_moe_alltoall_single_gpu(num_tokens, vector_dim, num_experts, top_k): output_tensor, _ = torch.sort(output_tensor.flatten(end_dim=1), dim=0) torch.testing.assert_close(output_tensor, input_tensor, atol=0, rtol=0) - moe_a2a._reset_workspace() + inplace_combine_tensor = moe_a2a.get_combine_payload_tensor_in_workspace( + num_tokens, + input_tensors[hidden_state_index].shape[-1], + input_tensors[hidden_state_index].dtype, + ) + # Copy first output tensor into inplace_combine_tensor + inplace_combine_tensor.copy_(output_tensors[hidden_state_index]) -@pytest.mark.parametrize("world_size,num_tokens,vector_dim", MULTI_RANK_PARAMS) -def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): - """Test MOE alltoall communication with multiple ranks on single GPU.""" - torch.cuda.set_device(0) - max_world_size = 8 - assert world_size <= max_world_size, ( - f"should run with world_size at most {max_world_size}" + output = moe_a2a.combine( + inplace_combine_tensor, num_tokens, payload_in_workspace=True ) - dtypes = [torch.float16, torch.bfloat16, torch.int32, torch.uint8] - # Create a random input tensor - input_tensors = [ - make_payload(num_tokens * world_size, vector_dim * (i + 1), dtype) - for i, dtype in enumerate(dtypes) - ] - - token_selected_experts = torch.randint( - 0, - world_size, - (num_tokens * world_size, 1), - dtype=torch.int32, - device=torch.device("cuda"), + # Should just be a direct copy for 1 GPU + torch.testing.assert_close( + output, input_tensors[hidden_state_index], atol=0, rtol=0 ) + +def dispatch_from_single_rank( + input_tensors, + token_selected_experts, + world_size, + num_experts, + num_tokens, + hidden_state_index=None, +): payloads = input_tensors total_payload_size_per_element = [x[0].numel() * x.itemsize for x in payloads] total_payload_size_per_element = sum(total_payload_size_per_element) + combine_size = 0 + if hidden_state_index is not None: + combine_size = ( + input_tensors[hidden_state_index].shape[-1] + * input_tensors[hidden_state_index].itemsize + ) + workspace_size = trtllm_moe_alltoall.moe_a2a_get_workspace_size_per_rank( - world_size, num_tokens * world_size, total_payload_size_per_element + world_size, + num_tokens * world_size, + total_payload_size_per_element, + combine_size, ) + print(f"world_size: {world_size}") + print(f"num_tokens: {num_tokens}") + print(f"world_size * num_tokens: {world_size * num_tokens}") + print(f"total_payload_size_per_element: {total_payload_size_per_element}") + print(f"combine_size: {combine_size}") + print(f"workspace_size: {workspace_size}") + all_workspaces = torch.zeros( world_size, workspace_size, dtype=torch.uint8, device=torch.device("cuda") ) @@ -186,6 +219,7 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): torch.cuda.synchronize() output_tensors = [] + combine_payload_offsets = [] # do alltoall in parallel cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(world_size)] for rank in range(world_size): @@ -196,18 +230,77 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): 
rank_token_selected_experts = token_selected_experts[ rank * num_tokens : (rank + 1) * num_tokens ] - output_tensors.append( - trtllm_moe_alltoall.moe_a2a_dispatch( - rank_token_selected_experts, - rank_payloads, + output, offset = trtllm_moe_alltoall.moe_a2a_dispatch( + rank_token_selected_experts, + rank_payloads, + all_workspaces, + metainfo[rank], + num_tokens, + ep_rank=rank, + ep_size=world_size, + top_k=rank_token_selected_experts.shape[-1], + num_experts=num_experts, + ) + output_tensors.append(output) + combine_payload_offsets.append(offset) + + for rank in range(world_size): + cuda_streams_all_ranks[rank].synchronize() + + torch.cuda.synchronize() + + return output_tensors, all_workspaces, metainfo, combine_payload_offsets + + +def sanitize_expert_ids_from_single_rank( + output_tensors, + expert_ids_index, + all_workspaces, + metainfo, + world_size, + invalid_expert_id, +): + for rank in range(world_size): + trtllm_moe_alltoall.moe_a2a_sanitize_expert_ids( + output_tensors[rank][expert_ids_index], + all_workspaces, + metainfo[rank], + rank, + invalid_expert_id, + ) + return output_tensors + + +def combine_from_single_rank( + combine_payload, + num_tokens, + top_k, + all_workspaces, + metainfo, + world_size, + combine_payload_offsets, + payload_in_workspace, +): + combine_results = [] + + torch.cuda.synchronize() + + cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(world_size)] + for rank in range(world_size): + with torch.cuda.stream(cuda_streams_all_ranks[rank]): + combine_results.append( + trtllm_moe_alltoall.moe_a2a_combine( + combine_payload[rank], + num_tokens, all_workspaces, metainfo[rank], num_tokens, ep_rank=rank, ep_size=world_size, - top_k=1, - num_experts=world_size, - )[0] + top_k=top_k, + combine_payload_offset=combine_payload_offsets[rank], + payload_in_workspace=payload_in_workspace, + ) ) for rank in range(world_size): @@ -215,13 +308,44 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): torch.cuda.synchronize() + return combine_results + + +@pytest.mark.parametrize("world_size,num_tokens,vector_dim", MULTI_RANK_PARAMS) +def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): + """Test MOE alltoall communication with multiple ranks on single GPU.""" + torch.cuda.set_device(0) + max_world_size = 8 + assert world_size <= max_world_size, ( + f"should run with world_size at most {max_world_size}" + ) + + dtypes = [torch.float16, torch.bfloat16, torch.int32, torch.uint8] + # Create a random input tensor + input_tensors = [ + make_payload(num_tokens * world_size, vector_dim * (i + 1), dtype) + for i, dtype in enumerate(dtypes) + ] + + token_selected_experts = torch.randint( + 0, + world_size, + (num_tokens * world_size, 1), + dtype=torch.int32, + device=torch.device("cuda"), + ) + + output_tensors, _, _, _ = dispatch_from_single_rank( + input_tensors, token_selected_experts, world_size, world_size, num_tokens + ) + for rank in range(world_size): # Get the indices where token_selected_experts == rank token_selected_experts_indices = ( token_selected_experts.flatten() == rank ).nonzero(as_tuple=False) - for actual, ref in zip(output_tensors[rank][:-1], payloads[:-1], strict=True): + for actual, ref in zip(output_tensors[rank], input_tensors, strict=True): # Select the tensors that arent all zeros actual = actual.flatten(end_dim=1) actual = actual[actual.any(dim=1)] @@ -231,7 +355,199 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): torch.testing.assert_close(actual, 
ref, atol=0, rtol=0) -# TODO Add a combine test +@pytest.mark.parametrize("world_size,num_tokens", SANITIZE_PARAMS) +def test_sanitize_expert_ids(world_size, num_tokens): + torch.cuda.set_device(0) + max_world_size = 8 + assert world_size <= max_world_size, ( + f"should run with world_size at most {max_world_size}" + ) + + flags = torch.ones( + num_tokens * world_size, 1, dtype=torch.bool, device=torch.device("cuda") + ) + token_selected_experts = torch.randint( + 0, + world_size, + (num_tokens * world_size, 1), + dtype=torch.int32, + device=torch.device("cuda"), + ) + + output_tensors, all_workspaces, metainfo, _ = dispatch_from_single_rank( + [token_selected_experts, flags], + token_selected_experts, + world_size, + world_size, + num_tokens, + ) + + # Clone since the tensors are modified in place + expected_output_tensors = [(x[0].clone(), x[1].clone()) for x in output_tensors] + output_tensors = sanitize_expert_ids_from_single_rank( + output_tensors, 0, all_workspaces, metainfo, world_size, -3 + ) + + for rank, (sanitized, raw) in enumerate( + zip(output_tensors, expected_output_tensors, strict=True) + ): + raw_tensor, flag_tensor = raw + valid_mask = (raw_tensor == rank) & flag_tensor + raw_tensor[~valid_mask] = -3 + torch.testing.assert_close(sanitized[0], raw_tensor, atol=0, rtol=0) + + +def fake_moe( + hidden_states, + token_selected_experts, + num_experts, + is_ep=False, + ep_rank=None, + num_experts_per_rank=None, +): + target_shape = hidden_states.shape + hidden_states = hidden_states.flatten(end_dim=-2) + token_selected_experts = token_selected_experts.flatten(end_dim=-2) + num_tokens, _ = hidden_states.shape + _, top_k = token_selected_experts.shape + + if is_ep: + assert ep_rank is not None and num_experts_per_rank is not None + + # Initialize output + processed_states = torch.zeros_like(hidden_states) + + # Process each token + for token_idx in range(num_tokens): + results = [] + # For each expert selected for this token/ + for k in range(top_k): + expert_id = token_selected_experts[token_idx, k].item() + if is_ep and not ( + expert_id >= ep_rank * num_experts_per_rank + and expert_id < (ep_rank + 1) * num_experts_per_rank + ): + continue + + scale = (expert_id + 1.0) / num_experts + 0.5 + results.append(hidden_states[token_idx] * scale) + + # Summing the results after is closer to the actual implementation as we do a tree reduction. 
+ if results: + processed_states[token_idx] = torch.sum( + torch.stack(results, dim=0), dim=0, dtype=torch.float32 + ).to(processed_states.dtype) + + print(f"processed_states shape: {processed_states.shape}") + print(f"target_shape: {target_shape}") + return processed_states.view(target_shape) + + +@pytest.mark.parametrize("world_size,num_tokens,vector_dim,top_k,dtype", COMBINE_PARAMS) +def test_moe_combine_multi_rank_single_gpu( + world_size, num_tokens, vector_dim, top_k, dtype +): + torch.cuda.set_device(0) + max_world_size = 8 + assert world_size <= max_world_size, ( + f"should run with world_size at most {max_world_size}" + ) + + num_experts = world_size * top_k + + token_selected_experts_index = 0 + hidden_state_index = 1 + + token_selected_experts = torch.empty( + num_tokens * world_size, top_k, dtype=torch.int32, device=torch.device("cuda") + ) + + for i in range(num_tokens * world_size): + # Include one extra expert to represent invalid expert IDs + token_selected_experts[i] = torch.randperm( + num_experts, dtype=torch.int32, device=torch.device("cuda") + )[:top_k] + token_selected_experts = token_selected_experts.contiguous() + + # Create a random input tensor + reference_tensor = make_payload(num_tokens * world_size, vector_dim, dtype) + input_tensors = [ + token_selected_experts, + reference_tensor, + make_payload( + num_tokens * world_size, 1, torch.uint8 + ), # Some extra payload to test combine alignment logic + ] + + output_tensors, all_workspaces, metainfo, combine_payload_offsets = ( + dispatch_from_single_rank( + input_tensors, + token_selected_experts, + world_size, + num_experts, + num_tokens, + hidden_state_index, + ) + ) + + # Sanitize expert ids for fake_moe + output_tensors = sanitize_expert_ids_from_single_rank( + output_tensors, + token_selected_experts_index, + all_workspaces, + metainfo, + world_size, + -1, + ) + + inplace_combine_tensors = [] + for rank in range(world_size): + inplace_combine_tensors.append( + trtllm_moe_alltoall.moe_a2a_wrap_payload_tensor_in_workspace( + all_workspaces, + [world_size, num_tokens], + combine_payload_offsets[rank], + combine_payload_offsets[rank] + + world_size * num_tokens * vector_dim * dtype.itemsize, + dtype, + ) + ) + + for rank in range(world_size): + inplace_combine_tensors[rank].copy_( + fake_moe( + output_tensors[rank][hidden_state_index], + output_tensors[rank][token_selected_experts_index], + num_experts, + is_ep=True, + ep_rank=rank, + num_experts_per_rank=num_experts // world_size, + ) + ) + + combine_results = combine_from_single_rank( + inplace_combine_tensors, + num_tokens, + top_k, + all_workspaces, + metainfo, + world_size, + combine_payload_offsets, + payload_in_workspace=True, + ) + + reference_result = fake_moe( + input_tensors[hidden_state_index], token_selected_experts, num_experts + ) + + for rank in range(world_size): + torch.testing.assert_close( + combine_results[rank], + reference_result[rank * num_tokens : (rank + 1) * num_tokens], + atol=1e-2, + rtol=1e-2, + ) + if __name__ == "__main__": pytest.main([__file__, "-v"]) From c4ee3c29eaac61f8da3c0b2d9dcd3c528b799398 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 13:57:45 +1300 Subject: [PATCH 19/25] Fix logic for inplace combine workspace setup --- csrc/trtllm_moe_a2a.cu | 4 ++-- tests/comm/test_trtllm_moe_alltoall.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/trtllm_moe_a2a.cu b/csrc/trtllm_moe_a2a.cu index a87875baf0..457e98deac 100644 --- 
a/csrc/trtllm_moe_a2a.cu +++ b/csrc/trtllm_moe_a2a.cu @@ -316,8 +316,8 @@ Tensor moeA2ACombineOp(TensorView payload, int64_t localNumTokens, TensorView wo if (payloadInWorkspace) { auto* expectedPtr = rankWorkspacePtr + combinePayloadOffset; TVM_FFI_ICHECK(payload.data_ptr() == expectedPtr) - << "payload_in_workspace is True but tensor pointer mismatch: " << payload.data_ptr() - << " != " << expectedPtr; + << "payload_in_workspace is True but tensor pointer mismatch: " << (void*)payload.data_ptr() + << " != " << (void*)expectedPtr; } Tensor output = diff --git a/tests/comm/test_trtllm_moe_alltoall.py b/tests/comm/test_trtllm_moe_alltoall.py index 52e709f590..d54719fa5b 100644 --- a/tests/comm/test_trtllm_moe_alltoall.py +++ b/tests/comm/test_trtllm_moe_alltoall.py @@ -504,7 +504,7 @@ def test_moe_combine_multi_rank_single_gpu( for rank in range(world_size): inplace_combine_tensors.append( trtllm_moe_alltoall.moe_a2a_wrap_payload_tensor_in_workspace( - all_workspaces, + all_workspaces[rank, :], [world_size, num_tokens], combine_payload_offsets[rank], combine_payload_offsets[rank] From febd13280bbf6c7d9378f84b510c24cff0c682b0 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 14:50:19 +1300 Subject: [PATCH 20/25] Limit num tokens to allow combine to successfully run on 1 GPU --- tests/comm/test_trtllm_moe_alltoall.py | 63 ++++++++++---------------- 1 file changed, 25 insertions(+), 38 deletions(-) diff --git a/tests/comm/test_trtllm_moe_alltoall.py b/tests/comm/test_trtllm_moe_alltoall.py index d54719fa5b..53ef955b20 100644 --- a/tests/comm/test_trtllm_moe_alltoall.py +++ b/tests/comm/test_trtllm_moe_alltoall.py @@ -38,42 +38,35 @@ def setup_test_environment(): MULTI_RANK_PARAMS = [ (2, 5, 8), # Small input, 2 ranks - (4, 901, 32768), # Large input, 4 ranks - (8, 16384, 128), # Many small vectors, 8 ranks + (4, 32, 32768), # Large input, 4 ranks + (8, 16, 2048), # Medium input, 8 ranks ] SANITIZE_PARAMS = [ - (2, 5), # Few tokens, 2 ranks - (4, 901), # Many tokens, 4 ranks + (2, 64), # 2 ranks + (4, 32), # 4 ranks + (8, 16), # 8 ranks ] COMBINE_PARAMS = [ - (2, 5, 8, 2, torch.bfloat16), # Small input, 2 ranks - (4, 901, 32768, 4, torch.bfloat16), # Large input, 4 ranks - (8, 16384, 128, 8, torch.bfloat16), # Many small vectors, 8 ranks - (2, 5, 8, 2, torch.float16), # Small input, 2 ranks - (4, 901, 32768, 4, torch.float16), # Large input, 4 ranks - (8, 16384, 128, 8, torch.float16), # Many small vectors, 8 ranks + (2, 64, 8, 2, torch.bfloat16), # Small input, 2 ranks + (4, 32, 32768, 4, torch.bfloat16), # Large input, 4 ranks + (8, 16, 2048, 8, torch.bfloat16), # Medium input, 8 ranks + (2, 64, 8, 2, torch.float16), # Small input, 2 ranks + (4, 32, 32768, 4, torch.float16), # Large input, 4 ranks + (8, 16, 2048, 8, torch.float16), # Medium input, 8 ranks ] -def get_available_gpu_count(): - """Get the number of available GPUs.""" - if not torch.cuda.is_available(): - return 0 - return torch.cuda.device_count() - - -def requires_gpus(min_gpus): - """Decorator to skip test if insufficient GPUs are available.""" - - def decorator(func): - return pytest.mark.skipif( - get_available_gpu_count() < min_gpus, - reason=f"Requires at least {min_gpus} GPUs, but only {get_available_gpu_count()} available", - )(func) - - return decorator +# This is a hack to ensure we get forward progress when running multiple kernels on a single GPU +def check_sufficient_sm_count(num_tokens, world_size): + if ( + num_tokens * world_size + > 
torch.cuda.get_device_properties(0).multi_processor_count + ): + pytest.skip( + f"Requires at least {num_tokens * world_size} SMs, but only {torch.cuda.get_device_properties(0).multi_processor_count} available" + ) def make_payload(num_tokens, vector_dim, dtype): @@ -192,13 +185,6 @@ def dispatch_from_single_rank( combine_size, ) - print(f"world_size: {world_size}") - print(f"num_tokens: {num_tokens}") - print(f"world_size * num_tokens: {world_size * num_tokens}") - print(f"total_payload_size_per_element: {total_payload_size_per_element}") - print(f"combine_size: {combine_size}") - print(f"workspace_size: {workspace_size}") - all_workspaces = torch.zeros( world_size, workspace_size, dtype=torch.uint8, device=torch.device("cuda") ) @@ -315,6 +301,7 @@ def combine_from_single_rank( def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): """Test MOE alltoall communication with multiple ranks on single GPU.""" torch.cuda.set_device(0) + check_sufficient_sm_count(num_tokens, world_size) max_world_size = 8 assert world_size <= max_world_size, ( f"should run with world_size at most {max_world_size}" @@ -358,6 +345,7 @@ def test_moe_alltoall_multi_rank_single_gpu(world_size, num_tokens, vector_dim): @pytest.mark.parametrize("world_size,num_tokens", SANITIZE_PARAMS) def test_sanitize_expert_ids(world_size, num_tokens): torch.cuda.set_device(0) + check_sufficient_sm_count(num_tokens, world_size) max_world_size = 8 assert world_size <= max_world_size, ( f"should run with world_size at most {max_world_size}" @@ -438,8 +426,6 @@ def fake_moe( torch.stack(results, dim=0), dim=0, dtype=torch.float32 ).to(processed_states.dtype) - print(f"processed_states shape: {processed_states.shape}") - print(f"target_shape: {target_shape}") return processed_states.view(target_shape) @@ -448,6 +434,7 @@ def test_moe_combine_multi_rank_single_gpu( world_size, num_tokens, vector_dim, top_k, dtype ): torch.cuda.set_device(0) + check_sufficient_sm_count(num_tokens, world_size) max_world_size = 8 assert world_size <= max_world_size, ( f"should run with world_size at most {max_world_size}" @@ -544,8 +531,8 @@ def test_moe_combine_multi_rank_single_gpu( torch.testing.assert_close( combine_results[rank], reference_result[rank * num_tokens : (rank + 1) * num_tokens], - atol=1e-2, - rtol=1e-2, + atol=1.5e-2, + rtol=1.5e-2, ) From 94df8453ce616e5f2b8f2bb3e27a2d60b447d11e Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 14:54:56 +1300 Subject: [PATCH 21/25] Unify naming --- ...rtllm_moe_a2a.cu => trtllm_moe_alltoall.cu} | 0 flashinfer/aot.py | 4 ++-- flashinfer/comm/trtllm_moe_alltoall.py | 18 +++++++++--------- flashinfer/jit/__init__.py | 2 +- flashinfer/jit/comm.py | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) rename csrc/{trtllm_moe_a2a.cu => trtllm_moe_alltoall.cu} (100%) diff --git a/csrc/trtllm_moe_a2a.cu b/csrc/trtllm_moe_alltoall.cu similarity index 100% rename from csrc/trtllm_moe_a2a.cu rename to csrc/trtllm_moe_alltoall.cu diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 84bc0ca199..271080f2bd 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -512,14 +512,14 @@ def gen_all_modules( from .jit.comm import gen_nvshmem_module from .jit.comm import gen_comm_alltoall_module from .jit.comm import gen_trtllm_mnnvl_comm_module - from .jit.comm import gen_mnnvl_a2a_module + from .jit.comm import gen_mnnvl_moe_alltoall_module jit_specs.append(gen_nvshmem_module()) jit_specs.append(gen_comm_alltoall_module()) if 
has_sm100: jit_specs.append(gen_trtllm_comm_module()) jit_specs.append(gen_trtllm_mnnvl_comm_module()) - jit_specs.append(gen_mnnvl_a2a_module()) + jit_specs.append(gen_mnnvl_moe_alltoall_module()) jit_specs.append(gen_vllm_comm_module()) if add_misc: diff --git a/flashinfer/comm/trtllm_moe_alltoall.py b/flashinfer/comm/trtllm_moe_alltoall.py index 169c4de622..1fe69e5c16 100644 --- a/flashinfer/comm/trtllm_moe_alltoall.py +++ b/flashinfer/comm/trtllm_moe_alltoall.py @@ -16,7 +16,7 @@ from .mnnvl import MnnvlMemory, MnnvlConfig from .mapping import Mapping -from ..jit.comm import gen_mnnvl_a2a_module +from ..jit.comm import gen_mnnvl_moe_alltoall_module from ..utils import register_custom_op @@ -30,9 +30,9 @@ class _A2AState: @functools.cache -def get_mnnvl_a2a_module(): +def get_mnnvl_moe_alltoall_module(): """Get or build the MNNVL A2A JIT module.""" - module = gen_mnnvl_a2a_module().build_and_load() + module = gen_mnnvl_moe_alltoall_module().build_and_load() @register_custom_op( "flashinfer::moe_a2a_initialize", @@ -201,7 +201,7 @@ def moe_a2a_initialize( ep_size: int, max_num_tokens: int, ): - return get_mnnvl_a2a_module().moe_a2a_initialize( + return get_mnnvl_moe_alltoall_module().moe_a2a_initialize( workspace, ep_rank, ep_size, max_num_tokens ) @@ -250,7 +250,7 @@ def moe_a2a_dispatch( num_experts: int, ): recv_offsets, recv_sizes, combine_payload_offset = ( - get_mnnvl_a2a_module().moe_a2a_dispatch( + get_mnnvl_moe_alltoall_module().moe_a2a_dispatch( token_selected_experts, input_payloads, workspace, @@ -293,7 +293,7 @@ def moe_a2a_combine( combine_payload_offset: int, payload_in_workspace: bool = False, ) -> torch.Tensor: - return get_mnnvl_a2a_module().moe_a2a_combine( + return get_mnnvl_moe_alltoall_module().moe_a2a_combine( payload, local_num_tokens, workspace, @@ -314,7 +314,7 @@ def moe_a2a_sanitize_expert_ids( ep_rank: int, invalid_expert_id: int, ): - return get_mnnvl_a2a_module().moe_a2a_sanitize_expert_ids( + return get_mnnvl_moe_alltoall_module().moe_a2a_sanitize_expert_ids( expert_ids, workspace, metainfo, ep_rank, invalid_expert_id ) @@ -325,7 +325,7 @@ def moe_a2a_get_workspace_size_per_rank( total_dispatch_payload_size_per_token: int, combine_payload_size_per_token: int, ): - return get_mnnvl_a2a_module().moe_a2a_get_workspace_size_per_rank( + return get_mnnvl_moe_alltoall_module().moe_a2a_get_workspace_size_per_rank( ep_size, max_num_tokens, total_dispatch_payload_size_per_token, @@ -390,7 +390,7 @@ def get_workspace( def _init_constants(cls): """Initialize constants from C++ if not already done.""" if cls._METAINFO_INDEX is None: - module = get_mnnvl_a2a_module() + module = get_mnnvl_moe_alltoall_module() names, values = module.moe_a2a_get_metainfo_index_pairs() # Convert TVM arrays to Python and build dictionary diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index bd1934ff62..bf6ea042c4 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -77,7 +77,7 @@ from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module from .comm import gen_vllm_comm_module as gen_vllm_comm_module from .comm import gen_nvshmem_module as gen_nvshmem_module -from .comm import gen_mnnvl_a2a_module as gen_mnnvl_a2a_module +from .comm import gen_mnnvl_moe_alltoall_module as gen_mnnvl_moe_alltoall_module from .dsv3_optimizations import ( gen_dsv3_router_gemm_module as gen_dsv3_router_gemm_module, ) diff --git a/flashinfer/jit/comm.py b/flashinfer/jit/comm.py index 4c350ddf22..e4fdf88f86 100644 --- a/flashinfer/jit/comm.py +++ 
b/flashinfer/jit/comm.py @@ -80,11 +80,11 @@ def gen_vllm_comm_module() -> JitSpec: ) -def gen_mnnvl_a2a_module() -> JitSpec: +def gen_mnnvl_moe_alltoall_module() -> JitSpec: return gen_jit_spec( - "mnnvl_a2a", + "mnnvl_moe_alltoall", [ - jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_a2a.cu", + jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_alltoall.cu", jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "tensorrt_llm" From 00d38ccff074b9951d9e89fc084c44b7793ae9f4 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:24:41 +1300 Subject: [PATCH 22/25] Add test for payload not in the workspace and fix coderabbit comments --- csrc/trtllm_moe_alltoall.cu | 5 ++- flashinfer/comm/trtllm_moe_alltoall.py | 42 +++++++++++++++++---- tests/comm/test_trtllm_moe_alltoall.py | 51 +++++++++++++++++--------- 3 files changed, 72 insertions(+), 26 deletions(-) diff --git a/csrc/trtllm_moe_alltoall.cu b/csrc/trtllm_moe_alltoall.cu index 457e98deac..2852a6ccd4 100644 --- a/csrc/trtllm_moe_alltoall.cu +++ b/csrc/trtllm_moe_alltoall.cu @@ -127,7 +127,6 @@ Tuple, Array, int64_t> moeA2ADispatchOp( TensorView metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts) { using tl_throughput::PayloadDescriptor; - fflush(stdout); CHECK_INPUT(tokenSelectedExperts); CHECK_INPUT_TYPE(tokenSelectedExperts, dl_int32); @@ -388,6 +387,10 @@ void moeA2ASanitizeExpertIdsOp(TensorView expertIds, TensorView workspace, Tenso static_cast(expertIds.data_ptr()), recvCounters, static_cast(invalidExpertId), static_cast(epSize), static_cast(runtimeMaxTokensPerRank), static_cast(topK), get_current_stream()); + + auto err = cudaGetLastError(); + TVM_FFI_ICHECK(err == cudaSuccess) + << "moe_a2a_sanitize_expert_ids launch failed: " << cudaGetErrorString(err); } // Expose metainfo index constants for Python access diff --git a/flashinfer/comm/trtllm_moe_alltoall.py b/flashinfer/comm/trtllm_moe_alltoall.py index 1fe69e5c16..72cdfb028c 100644 --- a/flashinfer/comm/trtllm_moe_alltoall.py +++ b/flashinfer/comm/trtllm_moe_alltoall.py @@ -178,6 +178,18 @@ def moe_a2a_get_workspace_size_per_rank( total_dispatch_payload_size_per_token: int, combine_payload_size_per_token: int, ): + """ + Get the workspace size per rank for the MoeAlltoAll operation. + + Args: + ep_size: Total expert parallel size + max_num_tokens: Maximum number of tokens across all ranks + total_dispatch_payload_size_per_token: The size of the payload per token in the dispatch phase. This should be the sum of all payloads tensors. + combine_payload_size_per_token: The size of the payload per token in the combine phase. 
+ + Returns: + workspace_size_per_rank: Size of the workspace per rank in bytes + """ return module.moe_a2a_get_workspace_size_per_rank( ep_size, max_num_tokens, @@ -218,15 +230,13 @@ def moe_a2a_wrap_payload_tensor_in_workspace( Args: workspace: [ep_size, size_per_rank] workspace tensor - ep_rank: Current expert parallel rank - ep_size: Total expert parallel size - runtime_max_tokens_per_rank: Max tokens per rank in this batch - total_size: Total size of the payload - offset: Offset from dispatch - dtype: Data type for the tensor + leading_shape: The leading shape to wrap the tensor with + slice_start: The start of the slice in the workspace + slice_end: The end of the slice in the workspace + dtype: Data type for the output tensor Returns: - tensor: [ep_size * max_tokens, hidden_size] workspace-backed tensor + tensor: [leading_shape, *] workspace-backed tensor """ workspace_base = workspace.view(-1).view(dtype=torch.uint8) assert slice_end <= workspace.numel(), ( @@ -249,6 +259,24 @@ def moe_a2a_dispatch( top_k: int, num_experts: int, ): + """ + Dispatch tokens and payloads to expert ranks. + + Args: + token_selected_experts: [local_num_tokens, top_k] int32 tensor + input_payloads: List of [local_num_tokens, *] tensors to dispatch + workspace: [ep_size, size_per_rank] workspace tensor + metainfo: Metadata tensor from initialize + runtime_max_tokens_per_rank: Max tokens per rank in this batch + ep_rank: Current expert parallel rank + ep_size: Total expert parallel size + top_k: Number of experts per token + num_experts: Total number of experts + + Returns: + output_payloads: List of payloads for this rank, backed by data in the workspace + combine_payload_offset: The offset to place the combine payload in the workspace + """ recv_offsets, recv_sizes, combine_payload_offset = ( get_mnnvl_moe_alltoall_module().moe_a2a_dispatch( token_selected_experts, diff --git a/tests/comm/test_trtllm_moe_alltoall.py b/tests/comm/test_trtllm_moe_alltoall.py index 53ef955b20..b7f4173dea 100644 --- a/tests/comm/test_trtllm_moe_alltoall.py +++ b/tests/comm/test_trtllm_moe_alltoall.py @@ -49,12 +49,14 @@ def setup_test_environment(): ] COMBINE_PARAMS = [ - (2, 64, 8, 2, torch.bfloat16), # Small input, 2 ranks - (4, 32, 32768, 4, torch.bfloat16), # Large input, 4 ranks - (8, 16, 2048, 8, torch.bfloat16), # Medium input, 8 ranks - (2, 64, 8, 2, torch.float16), # Small input, 2 ranks - (4, 32, 32768, 4, torch.float16), # Large input, 4 ranks - (8, 16, 2048, 8, torch.float16), # Medium input, 8 ranks + (2, 64, 8, 2, torch.bfloat16, True), # Small input, 2 ranks + (4, 32, 32768, 4, torch.bfloat16, True), # Large input, 4 ranks + (8, 16, 2048, 8, torch.bfloat16, True), # Medium input, 8 ranks + (8, 16, 2048, 8, torch.bfloat16, False), # Medium input, 8 ranks + (2, 64, 8, 2, torch.float16, True), # Small input, 2 ranks + (4, 32, 32768, 4, torch.float16, True), # Large input, 4 ranks + (8, 16, 2048, 8, torch.float16, True), # Medium input, 8 ranks + (8, 16, 2048, 8, torch.float16, False), # Medium input, 8 ranks ] @@ -429,9 +431,11 @@ def fake_moe( return processed_states.view(target_shape) -@pytest.mark.parametrize("world_size,num_tokens,vector_dim,top_k,dtype", COMBINE_PARAMS) +@pytest.mark.parametrize( + "world_size,num_tokens,vector_dim,top_k,dtype,payload_in_workspace", COMBINE_PARAMS +) def test_moe_combine_multi_rank_single_gpu( - world_size, num_tokens, vector_dim, top_k, dtype + world_size, num_tokens, vector_dim, top_k, dtype, payload_in_workspace ): torch.cuda.set_device(0) 
check_sufficient_sm_count(num_tokens, world_size) @@ -489,16 +493,27 @@ def test_moe_combine_multi_rank_single_gpu( inplace_combine_tensors = [] for rank in range(world_size): - inplace_combine_tensors.append( - trtllm_moe_alltoall.moe_a2a_wrap_payload_tensor_in_workspace( - all_workspaces[rank, :], - [world_size, num_tokens], - combine_payload_offsets[rank], - combine_payload_offsets[rank] - + world_size * num_tokens * vector_dim * dtype.itemsize, - dtype, + if payload_in_workspace: + inplace_combine_tensors.append( + trtllm_moe_alltoall.moe_a2a_wrap_payload_tensor_in_workspace( + all_workspaces[rank, :], + [world_size, num_tokens], + combine_payload_offsets[rank], + combine_payload_offsets[rank] + + world_size * num_tokens * vector_dim * dtype.itemsize, + dtype, + ) + ) + else: + inplace_combine_tensors.append( + torch.empty( + world_size, + num_tokens, + vector_dim, + dtype=dtype, + device=torch.device("cuda"), + ) ) - ) for rank in range(world_size): inplace_combine_tensors[rank].copy_( @@ -520,7 +535,7 @@ def test_moe_combine_multi_rank_single_gpu( metainfo, world_size, combine_payload_offsets, - payload_in_workspace=True, + payload_in_workspace=payload_in_workspace, ) reference_result = fake_moe( From 3e04d84e6f170a9d9cef7252cc7402fe566612ac Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:27:27 +1300 Subject: [PATCH 23/25] Update comm.rst --- docs/api/comm.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/api/comm.rst b/docs/api/comm.rst index f852073ae4..210261b0a4 100644 --- a/docs/api/comm.rst +++ b/docs/api/comm.rst @@ -142,3 +142,6 @@ MNNVL A2A (Throughput Backend) moe_a2a_dispatch moe_a2a_combine moe_a2a_sanitize_expert_ids + moe_a2a_get_metainfo_index_pairs + moe_a2a_get_workspace_size_per_rank + moe_a2a_wrap_payload_tensor_in_workspace From eaa5eb873d709a125b4b556c83b87d6d379fb3d0 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:38:52 +1300 Subject: [PATCH 24/25] Fix coderabbit nits --- csrc/trtllm_moe_alltoall.cu | 2 +- flashinfer/comm/trtllm_moe_alltoall.py | 8 +++----- tests/comm/test_mnnvl_moe_alltoall.py | 11 ++++++----- tests/comm/test_trtllm_moe_alltoall.py | 4 ++-- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/csrc/trtllm_moe_alltoall.cu b/csrc/trtllm_moe_alltoall.cu index 2852a6ccd4..2dd081534f 100644 --- a/csrc/trtllm_moe_alltoall.cu +++ b/csrc/trtllm_moe_alltoall.cu @@ -188,7 +188,7 @@ Tuple, Array, int64_t> moeA2ADispatchOp( TVM_FFI_ICHECK(totalBytesNeeded % elementSize == 0) << "Misaligned payload buffer " << i << " with element size " << elementSize - << ". Consider putting ordering payloads by minimum element size"; + << ". Consider reordering payloads by largest to smallest element size"; } auto* workspaceBase = static_cast(workspace.data_ptr()); diff --git a/flashinfer/comm/trtllm_moe_alltoall.py b/flashinfer/comm/trtllm_moe_alltoall.py index 72cdfb028c..9ff616443d 100644 --- a/flashinfer/comm/trtllm_moe_alltoall.py +++ b/flashinfer/comm/trtllm_moe_alltoall.py @@ -5,8 +5,6 @@ supporting multiple payloads per collective operation. 
""" -# TODO Review - from dataclasses import dataclass from types import SimpleNamespace from typing import Optional @@ -648,9 +646,9 @@ def get_combine_payload_tensor_in_workspace( __all__ = [ "MoeAlltoAll", - "moe_a2a_initialize", - "moe_a2a_dispatch", "moe_a2a_combine", - "moe_a2a_sanitize_expert_ids", + "moe_a2a_dispatch", "moe_a2a_get_workspace_size_per_rank", + "moe_a2a_initialize", + "moe_a2a_sanitize_expert_ids", ] diff --git a/tests/comm/test_mnnvl_moe_alltoall.py b/tests/comm/test_mnnvl_moe_alltoall.py index a006844bee..0f2a000a82 100644 --- a/tests/comm/test_mnnvl_moe_alltoall.py +++ b/tests/comm/test_mnnvl_moe_alltoall.py @@ -38,17 +38,18 @@ def safe_run(func, *args, **kwargs): comm = MPI.COMM_WORLD try: func(*args, **kwargs) - except MPIExit as e: - raise e - except Exception as e: + except MPIExit: + raise + except Exception: traceback.print_exc() comm.allgather(True) - raise e + raise @pytest.fixture(autouse=True) def setup_test(): torch.manual_seed(0x1234) + yield def compute_target_rank_id(expert_id, num_experts_per_rank): @@ -154,7 +155,7 @@ def fake_moe( # Process each token for token_idx in range(num_tokens): results = [] - # For each expert selected for this token/ + # For each expert selected for this token for k in range(top_k): expert_id = token_selected_experts[token_idx, k].item() if is_ep: diff --git a/tests/comm/test_trtllm_moe_alltoall.py b/tests/comm/test_trtllm_moe_alltoall.py index b7f4173dea..c089657878 100644 --- a/tests/comm/test_trtllm_moe_alltoall.py +++ b/tests/comm/test_trtllm_moe_alltoall.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="session") def setup_test_environment(): - """Set up test environment and warm up JIT compilation.""" + """Set up torch seed for deterministic tests.""" torch.manual_seed(0xD5) yield @@ -410,7 +410,7 @@ def fake_moe( # Process each token for token_idx in range(num_tokens): results = [] - # For each expert selected for this token/ + # For each expert selected for this token for k in range(top_k): expert_id = token_selected_experts[token_idx, k].item() if is_ep and not ( From a51b1ea8e406e96b50ea92dd88d028248652af2e Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:41:49 +1300 Subject: [PATCH 25/25] Properly export all functions --- flashinfer/comm/__init__.py | 5 ++++- flashinfer/comm/trtllm_moe_alltoall.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/flashinfer/comm/__init__.py b/flashinfer/comm/__init__.py index 496050cd00..98d77550a8 100644 --- a/flashinfer/comm/__init__.py +++ b/flashinfer/comm/__init__.py @@ -44,11 +44,14 @@ from .trtllm_moe_alltoall import moe_a2a_combine as moe_a2a_combine from .trtllm_moe_alltoall import moe_a2a_dispatch as moe_a2a_dispatch from .trtllm_moe_alltoall import moe_a2a_initialize as moe_a2a_initialize +from .trtllm_moe_alltoall import ( + moe_a2a_get_workspace_size_per_rank as moe_a2a_get_workspace_size_per_rank, +) from .trtllm_moe_alltoall import ( moe_a2a_sanitize_expert_ids as moe_a2a_sanitize_expert_ids, ) from .trtllm_moe_alltoall import ( - moe_a2a_get_workspace_size_per_rank as moe_a2a_get_workspace_size_per_rank, + moe_a2a_wrap_payload_tensor_in_workspace as moe_a2a_wrap_payload_tensor_in_workspace, ) # from .mnnvl import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo diff --git a/flashinfer/comm/trtllm_moe_alltoall.py b/flashinfer/comm/trtllm_moe_alltoall.py index 9ff616443d..323741642e 100644 --- a/flashinfer/comm/trtllm_moe_alltoall.py +++ b/flashinfer/comm/trtllm_moe_alltoall.py @@ -651,4 
+651,5 @@ def get_combine_payload_tensor_in_workspace( "moe_a2a_get_workspace_size_per_rank", "moe_a2a_initialize", "moe_a2a_sanitize_expert_ids", + "moe_a2a_wrap_payload_tensor_in_workspace", ]
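
For orientation, the following is a minimal single-rank sketch of how the renamed entry points fit together, mirroring the call sequence exercised by dispatch_from_single_rank and combine_from_single_rank in the tests above. It is illustrative only and not part of the patch: the positional/keyword arguments are taken from the test calls in this series, while the assumption that moe_a2a_initialize returns the metainfo tensor consumed by dispatch and combine, and the pass-through "expert output" placeholder, are hypothetical.

import torch
from flashinfer.comm import trtllm_moe_alltoall

ep_rank, ep_size = 0, 1
num_tokens, top_k, num_experts, hidden = 16, 2, 8, 1024
dtype = torch.bfloat16

hidden_states = torch.randn(num_tokens, hidden, dtype=dtype, device="cuda")
token_selected_experts = torch.randint(
    0, num_experts, (num_tokens, top_k), dtype=torch.int32, device="cuda"
)

# One dispatch payload (the hidden states) plus a same-sized combine payload.
dispatch_bytes_per_token = hidden * dtype.itemsize
combine_bytes_per_token = hidden * dtype.itemsize
workspace_size = trtllm_moe_alltoall.moe_a2a_get_workspace_size_per_rank(
    ep_size, num_tokens, dispatch_bytes_per_token, combine_bytes_per_token
)
all_workspaces = torch.zeros(
    ep_size, workspace_size, dtype=torch.uint8, device="cuda"
)

# Assumption: initialize returns the metainfo tensor that dispatch/combine consume.
metainfo = trtllm_moe_alltoall.moe_a2a_initialize(
    all_workspaces, ep_rank, ep_size, num_tokens
)

recv_payloads, combine_payload_offset = trtllm_moe_alltoall.moe_a2a_dispatch(
    token_selected_experts,
    [hidden_states],
    all_workspaces,
    metainfo,
    num_tokens,
    ep_rank=ep_rank,
    ep_size=ep_size,
    top_k=top_k,
    num_experts=num_experts,
)

# recv_payloads[0] is a workspace-backed receive buffer. A real MoE layer would run the
# local experts on it; here a plain copy stands in for that computation, mirroring the
# payload_in_workspace=False path in the combine test.
expert_output = torch.empty(
    ep_size, num_tokens, hidden, dtype=dtype, device="cuda"
)
expert_output.copy_(recv_payloads[0].reshape(ep_size, num_tokens, hidden))

combined = trtllm_moe_alltoall.moe_a2a_combine(
    expert_output,
    num_tokens,
    all_workspaces,
    metainfo,
    num_tokens,
    ep_rank=ep_rank,
    ep_size=ep_size,
    top_k=top_k,
    combine_payload_offset=combine_payload_offset,
    payload_in_workspace=False,
)

As the payload_in_workspace=True test cases show, the extra copy can be avoided by writing the expert output directly into the workspace region returned by moe_a2a_wrap_payload_tensor_in_workspace (or MoeAlltoAll.get_combine_payload_tensor_in_workspace) and calling combine with payload_in_workspace=True.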