From b5fe7bbd9719a2d4af85591747bb54bb12e79da5 Mon Sep 17 00:00:00 2001 From: kling Date: Wed, 12 Nov 2025 17:22:25 +0800 Subject: [PATCH] Support per-expert-overlap SBO. --- csrc/deep_ep.cpp | 132 ++++++++++++------ csrc/deep_ep.hpp | 17 ++- csrc/event.hpp | 4 +- csrc/kernels/api.cuh | 10 +- csrc/kernels/internode_ll.cu | 264 ++++++++++++++++++++++++++--------- deep_ep/buffer.py | 9 +- tests/test_low_latency.py | 57 ++++++-- 7 files changed, 356 insertions(+), 137 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 954c9ffb..a3d7599c 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -1524,17 +1524,18 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int #endif } -std::tuple, + std::optional, + std::optional, + std::optional, std::optional, - torch::Tensor, - torch::Tensor, - torch::Tensor, std::optional, std::optional>> Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, const std::optional& cumulative_local_expert_recv_stats, const std::optional& dispatch_wait_recv_cost_stats, + bool use_expert_overlap, int num_rounds, int round_id, int send_num_sms, int recv_num_sms, bool hook_use_comm_stream, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, @@ -1566,6 +1567,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks); } + if (use_expert_overlap) { + EP_HOST_ASSERT(num_rounds >= 1 && num_experts / num_ranks % num_rounds == 0 && return_recv_hook); + if (send_num_sms == -1) send_num_sms = num_device_sms; + if (recv_num_sms == -1) recv_num_sms = num_device_sms; + } + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); auto num_topk = static_cast(topk_idx.size(1)); auto num_local_experts = num_experts / num_ranks; @@ -1574,7 +1581,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); auto buffer = layout.buffers[low_latency_buffer_idx]; - auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^ 1]; + + if (not use_expert_overlap || round_id == (num_rounds - 1)) { + // Buffer control + low_latency_buffer_idx ^= 1; + } // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream @@ -1585,42 +1597,55 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, stream_wait(launch_stream, compute_stream); // Allocate packed tensors - auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(use_fp8 ? 
torch::kFloat8_e4m3fn : torch::kBFloat16)); - auto packed_recv_src_info = - torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); - auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); - auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); - - // Allocate column-majored scales + auto packed_recv_x = std::optional(); + auto packed_recv_src_info = std::optional(); + auto packed_recv_layout_range = std::optional(); + auto packed_recv_count = std::optional(); auto packed_recv_x_scales = std::optional(); - void* packed_recv_x_scales_ptr = nullptr; - EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); - - if (use_fp8) { - // TODO: support unaligned cases - EP_HOST_ASSERT(hidden % 512 == 0); - if (not use_ue8m0) { - packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank}, - torch::dtype(torch::kFloat32).device(torch::kCUDA)); - } else { - EP_HOST_ASSERT(round_scale); - packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank}, - torch::dtype(torch::kInt).device(torch::kCUDA)); + + if (not use_expert_overlap or round_id == 0) { + packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)); + packed_recv_src_info = + torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); + packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + // Allocate column-majored scales + void* packed_recv_x_scales_ptr = nullptr; + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); + + if (use_fp8) { + // TODO: support unaligned cases + EP_HOST_ASSERT(hidden % 512 == 0); + if (not use_ue8m0) { + packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } else { + EP_HOST_ASSERT(round_scale); + packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kInt).device(torch::kCUDA)); + } + packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); + packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); } - packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); - packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); + + cached_packed_recv_x_ptr = packed_recv_x->data_ptr(); + cached_packed_recv_x_scales_ptr = packed_recv_x_scales_ptr; + cached_packed_recv_count_ptr = packed_recv_count->data_ptr(); + cached_packed_recv_src_info_ptr = packed_recv_src_info->data_ptr(); + cached_packed_recv_layout_range_ptr = packed_recv_layout_range->data_ptr(); } // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); - auto launcher = [=](int phases) { + auto launcher = [=](int phases, int num_sms, cudaStream_t stream) { internode_ll::dispatch( - 
packed_recv_x.data_ptr(), - packed_recv_x_scales_ptr, - packed_recv_src_info.data_ptr(), - packed_recv_layout_range.data_ptr(), - packed_recv_count.data_ptr(), + cached_packed_recv_x_ptr, + cached_packed_recv_x_scales_ptr, + cached_packed_recv_src_info_ptr, + cached_packed_recv_layout_range_ptr, + cached_packed_recv_count_ptr, mask_buffer_ptr, cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr() : nullptr, dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr() : nullptr, @@ -1642,11 +1667,14 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, round_scale, use_ue8m0, workspace, - num_device_sms, - launch_stream, - phases); + num_sms, + stream, + phases, + use_expert_overlap, + num_rounds, + round_id); }; - launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE), use_expert_overlap ? send_num_sms : num_device_sms, launch_stream); // Wait streams std::optional event; @@ -1661,7 +1689,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, // Receiver callback std::optional> recv_hook = std::nullopt; if (return_recv_hook) - recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE, use_expert_overlap ? recv_num_sms : num_device_sms, hook_use_comm_stream ? comm_stream : launch_stream); }; // Return values return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; @@ -1684,6 +1712,7 @@ std::tuple, std::optional& out) { #ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); @@ -1711,6 +1740,12 @@ std::tuple, std::optionalsize(0) == num_ranks); } + if (use_expert_overlap) { + EP_HOST_ASSERT(num_rounds >= 1 && num_experts / num_ranks % num_rounds == 0 && return_recv_hook); + if (send_num_sms == -1) send_num_sms = num_device_sms; + if (recv_num_sms == -1) recv_num_sms = num_device_sms; + } + auto hidden = static_cast(x.size(2)); auto num_topk = static_cast(topk_weights.size(1)); auto num_combined_tokens = static_cast(topk_weights.size(0)); @@ -1719,7 +1754,11 @@ std::tuple, std::optional, std::optional, std::optional event; @@ -1785,7 +1825,7 @@ std::tuple, std::optional> recv_hook = std::nullopt; if (return_recv_hook) - recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE, use_expert_overlap ? recv_num_sms : num_device_sms, hook_use_comm_stream ? 
comm_stream : launch_stream, 1, 0); }; // Return values return {combined_x, event, recv_hook}; diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 5fb90bff..759f97d1 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -98,6 +98,13 @@ struct Buffer { // Workspace void* workspace = nullptr; + // cached per-expert overlap pointers + void* cached_packed_recv_x_ptr = nullptr; + void* cached_packed_recv_x_scales_ptr = nullptr; + int* cached_packed_recv_count_ptr = nullptr; + int* cached_packed_recv_src_info_ptr = nullptr; + int64_t* cached_packed_recv_layout_range_ptr = nullptr; + // Host-side MoE info volatile int* moe_recv_counter = nullptr; int* moe_recv_counter_mapped = nullptr; @@ -254,17 +261,18 @@ struct Buffer { void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); - std::tuple, + std::optional, + std::optional, + std::optional, std::optional, - torch::Tensor, - torch::Tensor, - torch::Tensor, std::optional, std::optional>> low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, const std::optional& cumulative_local_expert_recv_stats, const std::optional& dispatch_wait_recv_cost_stats, + bool use_expert_overlap, int num_rounds, int round_id,int send_num_sms, int recv_num_sms, bool hook_use_comm_stream, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_fp8, @@ -286,6 +294,7 @@ struct Buffer { bool zero_copy, bool async, bool return_recv_hook, + bool use_expert_overlap, int num_rounds, int send_round_id, int send_num_sms, int recv_num_sms, bool hook_use_comm_stream, const std::optional& out = std::nullopt); torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; diff --git a/csrc/event.hpp b/csrc/event.hpp index b0b4383b..564d7e2e 100644 --- a/csrc/event.hpp +++ b/csrc/event.hpp @@ -31,7 +31,9 @@ torch::Event create_event(const at::cuda::CUDAStream& s) { } void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) { - EP_HOST_ASSERT(s_0.id() != s_1.id()); + if (s_0.id() == s_1.id()) { + return; + } s_0.unwrap().wait(create_event(s_1)); } diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index 9bbe096a..fbf39052 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -310,7 +310,10 @@ void dispatch(void* packed_recv_x, void* workspace, int num_device_sms, cudaStream_t stream, - int phases); + int phases, + bool use_expert_overlap, + int num_rounds, + int round_id); void combine(void* combined_x, void* rdma_recv_x, @@ -337,7 +340,10 @@ void combine(void* combined_x, int num_device_sms, cudaStream_t stream, int phases, - bool zero_copy); + bool zero_copy, + bool use_expert_overlap, + int num_rounds, + int round_id); void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, cudaStream_t stream); diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index e9fd473b..bfbdb6b9 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -7,6 +7,34 @@ namespace deep_ep { namespace internode_ll { +__device__ __forceinline__ +std::pair convert_physical_to_expert_id_in_round( + int num_local_experts, int num_local_experts_per_round, int physical_expert_id) { + auto rank_id = physical_expert_id / num_local_experts; + auto local_expert_id = physical_expert_id % num_local_experts; + auto round_id = local_expert_id / num_local_experts_per_round; + auto expert_in_round = rank_id * num_local_experts_per_round + physical_expert_id % 
num_local_experts_per_round; + return std::make_pair(round_id, expert_in_round); +} + +__device__ __forceinline__ +int convert_expert_id_in_round_to_physical( + int num_local_experts, int num_local_experts_per_round, + int round_id, int expert_in_round) { + auto rank_id = expert_in_round / num_local_experts_per_round; + auto local_expert_id = round_id * num_local_experts_per_round + expert_in_round % num_local_experts_per_round; + auto physical_expert_id = num_local_experts * rank_id + local_expert_id; + return physical_expert_id; +} + +__device__ __forceinline__ +std::pair convert_expert_id_in_round_to_sm( + int num_experts_per_sm_per_round, int expert_in_round) { + auto actual_sm_id = expert_in_round / num_experts_per_sm_per_round; + auto expert_id_in_sm = expert_in_round % num_experts_per_sm_per_round; + return std::make_pair(actual_sm_id, expert_id_in_sm); +} + template __forceinline__ __device__ bool is_rank_masked(int* mask_buffer_ptr, int rank) { if (mask_buffer_ptr == nullptr) { @@ -153,7 +181,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, int num_warp_groups, int num_warps_per_group, bool round_scale, - int phases) { + int phases, + bool use_expert_overlap, + int num_rounds, + int round_id) { const auto sm_id = static_cast(blockIdx.x); const auto thread_id = static_cast(threadIdx.x); const auto warp_id = thread_id / 32, lane_id = get_lane_id(); @@ -162,7 +193,8 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, const auto num_local_experts = num_experts / num_ranks; const auto warp_group_id = warp_id / num_warps_per_group; const auto sub_warp_id = warp_id % num_warps_per_group; - const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + auto expert_id_in_round = sm_id * num_warp_groups + warp_group_id; // May extract UE8M0 from the scales using scale_t = std::conditional_t; @@ -186,6 +218,18 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, constexpr int kNumMaxWarpGroups = 32; __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; + if (use_expert_overlap) { + EP_DEVICE_ASSERT (num_rounds >= 1 && num_local_experts % num_rounds == 0 && round_id >= 0 && round_id < num_rounds); + } else { + round_id = 0; + num_rounds = 1; + } + + auto num_experts_per_round = num_experts / num_rounds; + auto num_experts_per_sm_per_round = (num_experts_per_round + num_sms - 1) / num_sms; + EP_DEVICE_ASSERT(num_experts_per_sm_per_round == num_warp_groups); + auto num_local_experts_per_round = num_local_experts / num_rounds; + // Sending phase if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV; @@ -211,52 +255,56 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, thread_id == 0 ? 
(*rdma_x_src_idx = token_idx) : 0; // FP8 cast - EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce"); - #pragma unroll - for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { - // Read - auto int4_value = __ldg(x_int4 + i); - - if constexpr (kUseFP8) { - // Calculate local amax - auto bf16_values = reinterpret_cast(&int4_value); - float fp32_values[kNumElemsPerRead]; - float amax = kFP8Margin, scale, scale_inv; - #pragma unroll - for (int j = 0; j < kNumElemsPerRead; ++j) { - fp32_values[j] = static_cast(bf16_values[j]); - amax = fmaxf(amax, fabsf(fp32_values[j])); - } - - // Reduce amax and scale - EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); - amax = warp_reduce_max<16>(amax); - calculate_fp8_scales(amax, scale, scale_inv, round_scale); - if (lane_id == 0 or lane_id == 16) - rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; + if (round_id == 0) { + EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce"); + #pragma unroll + for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { + // Read + auto int4_value = __ldg(x_int4 + i); + + if constexpr (kUseFP8) { + // Calculate local amax + auto bf16_values = reinterpret_cast(&int4_value); + float fp32_values[kNumElemsPerRead]; + float amax = kFP8Margin, scale, scale_inv; + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; ++j) { + fp32_values[j] = static_cast(bf16_values[j]); + amax = fmaxf(amax, fabsf(fp32_values[j])); + } - // Cast into send buffer - vec_t int2_value; - auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); - #pragma unroll - for (int j = 0; j < kNumElemsPerRead; j += 2) { - float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; - fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); + // Reduce amax and scale + EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); + amax = warp_reduce_max<16>(amax); + calculate_fp8_scales(amax, scale, scale_inv, round_scale); + if (lane_id == 0 or lane_id == 16) + rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; + + // Cast into send buffer + vec_t int2_value; + auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; j += 2) { + float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; + fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); + } + rdma_x_vec[i] = int2_value; + } else { + // Reinterpret-cast is for C++14 compatibility + rdma_x_vec[i] = *reinterpret_cast(&int4_value); } - rdma_x_vec[i] = int2_value; - } else { - // Reinterpret-cast is for C++14 compatibility - rdma_x_vec[i] = *reinterpret_cast(&int4_value); } } asm volatile("bar.sync 1, %0;" ::"r"(num_threads)); // Issue IBGDA sends if (dst_expert_idx >= 0) { - int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; - slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); const auto dst_rank = dst_expert_idx / num_local_experts; const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; + if (use_expert_overlap and round_id != dst_expert_local_idx / num_local_experts_per_round) + continue; + int slot_idx = lane_id == 0 ? 
atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; + slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); const auto src_ptr = reinterpret_cast(rdma_x_src_idx); const auto dst_ptr = reinterpret_cast(rdma_recv_x) + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + @@ -285,15 +333,19 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts); // The first SM is also responsible for cleaning the next buffer - #pragma unroll - for (int i = lane_id; i < num_next_clean_int; i += 32) - next_clean[i] = 0; + if (round_id == (num_rounds - 1)) { + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + } // Notify before executing `int_p` __syncwarp(); #pragma unroll - for (int i = lane_id; i < num_experts; i += 32) - atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); + for (int i = lane_id; i < num_experts_per_round; i += 32) { + auto expert_id = convert_expert_id_in_round_to_physical(num_local_experts, num_local_experts_per_round, round_id, i); + atomic_add_release_global(atomic_finish_counter_per_expert + expert_id, FINISHED_SUM_TAG); + } } // This SM should be responsible for some destination experts, read `topk_idx` for them @@ -305,27 +357,60 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, #pragma unroll 8 for (int i = lane_id; i < num_tokens * num_topk; i += 32) { auto idx = static_cast(__ldg(topk_idx + i)); - if (idx >= expert_begin_idx and idx < expert_end_idx) - expert_count[idx - expert_begin_idx]++; + if (use_expert_overlap) { + auto [send_round_id, expert_in_round] = convert_physical_to_expert_id_in_round(num_local_experts, num_local_experts_per_round, idx); + auto [actual_sm_id, expert_id_in_sm] = convert_expert_id_in_round_to_sm(num_experts_per_sm_per_round, expert_in_round); + if (actual_sm_id == sm_id && send_round_id == round_id) { + expert_count[expert_id_in_sm] ++; + } + } else { + if (idx >= expert_begin_idx and idx < expert_end_idx) + expert_count[idx - expert_begin_idx]++; + } } // Warp reduce - #pragma unroll - for (int i = expert_begin_idx; i < expert_end_idx; ++i) { - auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); - if (lane_id == 0) { - shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; - atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); + if (use_expert_overlap) { + #pragma unroll + for (int j = 0; j < num_experts_per_sm_per_round; j++) { + auto expert_in_round = sm_id * num_experts_per_sm_per_round + j; + auto [actual_sm_id, expert_id_in_sm] = convert_expert_id_in_round_to_sm(num_experts_per_sm_per_round, expert_in_round); + if (actual_sm_id != sm_id) continue; + auto physical_expert_id = convert_expert_id_in_round_to_physical(num_local_experts, num_local_experts_per_round, round_id, expert_in_round); + auto sum = warp_reduce_sum(expert_count[expert_id_in_sm]); + if (lane_id == 0) { + shared_num_tokens_sent_per_expert[expert_id_in_sm] = sum; + atomic_add_release_global(atomic_finish_counter_per_expert + physical_expert_id, FINISHED_SUM_TAG - sum); + } + } + } else { + #pragma unroll + for (int i = expert_begin_idx; i < expert_end_idx; ++i) { + auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); + if (lane_id == 0) { + shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - 
sum); + } } } } __syncthreads(); // Issue count sends - if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { + if (use_expert_overlap) { + responsible_expert_idx = convert_expert_id_in_round_to_physical(num_local_experts, num_local_experts_per_round, round_id, expert_id_in_round); + } + if (expert_id_in_round < num_experts_per_round and responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { const auto dst_rank = responsible_expert_idx / num_local_experts; const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; - const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups]; + int num_tokens_sent; + if (use_expert_overlap) { + auto [actual_sm_id, expert_id_in_sm] = convert_expert_id_in_round_to_sm(num_experts_per_sm_per_round, expert_id_in_round); + if (actual_sm_id != sm_id) goto SKIP_WARP; + num_tokens_sent = shared_num_tokens_sent_per_expert[expert_id_in_sm]; + } else { + num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups]; + } // Wait local sends issued and send expert counts while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2) @@ -348,6 +433,7 @@ __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, if (dst_rank == 0) packed_recv_count[dst_expert_local_idx] = 0; } + SKIP_WARP: __syncwarp(); // Receiving phase @@ -360,7 +446,10 @@ LOW_LATENCY_DISPATCH_RECV: cg::this_grid().sync(); // Receiving and packing - if (responsible_expert_idx < num_experts) { + if (use_expert_overlap) { + responsible_expert_idx = convert_expert_id_in_round_to_physical(num_local_experts, num_local_experts_per_round, round_id, expert_id_in_round); + } + if (expert_id_in_round < num_experts_per_round and responsible_expert_idx < num_experts) { const auto src_rank = responsible_expert_idx / num_local_experts; const auto local_expert_idx = responsible_expert_idx % num_local_experts; const auto rdma_recv_x_uint8 = static_cast(rdma_recv_x) + @@ -490,15 +579,21 @@ void dispatch(void* packed_recv_x, void* workspace, int num_device_sms, cudaStream_t stream, - int phases) { + int phases, + bool use_expert_overlap, + int num_rounds, + int round_id) { constexpr int kNumMaxTopK = 11; - const int num_warp_groups = ceil_div(num_experts, num_device_sms); + if (not use_expert_overlap) { + num_rounds = 1; + } + int num_warp_groups = ceil_div(num_experts / num_rounds, num_device_sms); const int num_warps_per_group = 32 / num_warp_groups; EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0); EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group); const auto num_warps = num_warp_groups * num_warps_per_group; - const auto num_sms = ceil_div(num_experts, num_warp_groups); + const auto num_sms = ceil_div(num_experts / num_rounds, num_warp_groups); EP_HOST_ASSERT(num_topk <= kNumMaxTopK); // Workspace checks @@ -545,7 +640,10 @@ void dispatch(void* packed_recv_x, num_warp_groups, \ num_warps_per_group, \ round_scale, \ - phases); \ + phases, \ + use_expert_overlap, \ + num_rounds, \ + round_id); \ } \ break @@ -737,7 +835,10 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x, int num_warp_groups, int num_warps_per_group, int phases, - bool zero_copy) { + bool zero_copy, + bool use_expert_overlap, + int num_rounds, + int send_round_id) { const auto sm_id = __shfl_sync(0xffffffff, static_cast(blockIdx.x), 0); const auto num_sms = 
__shfl_sync(0xffffffff, static_cast(gridDim.x), 0); const auto thread_id = static_cast(threadIdx.x); @@ -746,7 +847,8 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x, const auto num_local_experts = num_experts / num_ranks; const auto warp_group_id = warp_id / num_warps_per_group; const auto sub_warp_id = warp_id % num_warps_per_group; - const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + auto expert_id_in_round = sm_id * num_warp_groups + warp_group_id; extern __shared__ __align__(1024) uint8_t smem_buffer[]; @@ -770,24 +872,39 @@ __global__ __launch_bounds__(1024, 1) void combine(void* combined_x, constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes; EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); + if (use_expert_overlap) { + EP_DEVICE_ASSERT(num_local_experts % num_rounds == 0 && num_rounds >= 1 && send_round_id >= 0 && send_round_id < num_rounds); + } else { + num_rounds = 1; + send_round_id = 0; + } + + auto num_experts_per_round = num_experts / num_rounds; + auto num_local_experts_per_round = num_local_experts / num_rounds; + // Sending phase if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_COMBINE_RECV; // Clean up next buffer if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { - #pragma unroll - for (int i = lane_id; i < num_next_clean_int; i += 32) - next_clean[i] = 0; + if (send_round_id == 0) { + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + } // Notify before executing `int_p` __syncwarp(); if (lane_id == 0) - atomic_add_release_global(atomic_clean_flag, num_experts); + atomic_add_release_global(atomic_clean_flag, num_experts_per_round); } // Issue IBGDA sends - if (responsible_expert_idx < num_experts) { + if (use_expert_overlap) { + responsible_expert_idx = convert_expert_id_in_round_to_physical(num_local_experts, num_local_experts_per_round, send_round_id, responsible_expert_idx); + } + if (expert_id_in_round < num_experts_per_round and responsible_expert_idx < num_experts) { const auto dst_rank = responsible_expert_idx / num_local_experts; const auto local_expert_idx = responsible_expert_idx % num_local_experts; const auto global_expert_idx = rank * num_local_experts + local_expert_idx; @@ -1163,16 +1280,22 @@ void combine(void* combined_x, int num_device_sms, cudaStream_t stream, int phases, - bool zero_copy) { + bool zero_copy, + bool use_expert_overlap, + int num_rounds, + int send_round_id) { constexpr int kNumMaxTopk = 11; - const int num_warp_groups = ceil_div(num_experts, num_device_sms); + if (not use_expert_overlap) { + num_rounds = 1; + } + int num_warp_groups = ceil_div(num_experts / num_rounds, num_device_sms); const int num_warps_per_group = 32 / num_warp_groups; const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms); EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm >= 0); const auto num_warps = num_warp_groups * num_warps_per_group; const auto num_sms = - max(ceil_div(num_experts, num_warp_groups), num_recv_per_sm == 0 ? 1 : ceil_div(num_combined_tokens, num_recv_per_sm)); + max(ceil_div(num_experts / num_rounds, num_warp_groups), num_recv_per_sm == 0 ? 
1 : ceil_div(num_combined_tokens, num_recv_per_sm)); // Check workspace auto atomic_clean_flag = static_cast(workspace); @@ -1229,7 +1352,10 @@ void combine(void* combined_x, num_warp_groups, \ num_warps_per_group, \ phases, \ - zero_copy); \ + zero_copy, \ + use_expert_overlap, \ + num_rounds, \ + send_round_id); \ } \ break diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 37512ee9..f02f8625 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -549,6 +549,7 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int, cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None, + use_expert_overlap: bool = False, num_rounds: int = -1, round_id: int = -1, send_num_sms: int = -1, recv_num_sms: int = -1, hook_use_comm_stream: bool = False, use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, async_finish: bool = False, return_recv_hook: bool = False) -> \ Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: @@ -604,6 +605,7 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, self.runtime.low_latency_dispatch(x, topk_idx, cumulative_local_expert_recv_stats, dispatch_wait_recv_cost_stats, + use_expert_overlap, num_rounds, round_id, send_num_sms, recv_num_sms, hook_use_comm_stream, num_max_dispatch_tokens_per_rank, num_experts, use_fp8, round_scale, use_ue8m0, async_finish, return_recv_hook) @@ -617,7 +619,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, - combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ + combine_wait_recv_cost_stats: Optional[torch.Tensor] = None, + use_expert_overlap: bool = False, num_rounds: int = -1, round_id: int = -1, send_num_sms: int = -1, recv_num_sms: int = -1, hook_use_comm_stream: bool = False) -> \ Tuple[torch.Tensor, EventOverlap, Callable]: """ A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. 
@@ -656,7 +659,9 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2 combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, combine_wait_recv_cost_stats, num_max_dispatch_tokens_per_rank, - num_experts, use_logfmt, zero_copy, async_finish, return_recv_hook, out) + num_experts, use_logfmt, zero_copy, async_finish, return_recv_hook, + use_expert_overlap, num_rounds, round_id, send_num_sms, recv_num_sms, hook_use_comm_stream, + out) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index 456dcf27..2ba7f016 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -4,6 +4,7 @@ import torch.distributed as dist from functools import partial from typing import Literal, Set +import itertools import deep_ep from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back @@ -44,7 +45,8 @@ def test_main(num_tokens: int, buffer: deep_ep.Buffer, use_logfmt: bool = False, shrink_test: bool = False, - seed: int = 0): + seed: int = 0, + use_expert_overlap: bool = False, send_num_sms = -1, recv_num_sms = -1, num_rounds = 1): torch.manual_seed(seed + rank) random.seed(seed + rank) @@ -94,12 +96,17 @@ def test_main(num_tokens: int, num_times += 1 for _ in range((num_times % 2) + 1): cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda') - packed_recv_x, packed_recv_count, handle, event, hook = \ - buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, - use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, - cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, - async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) - hook() if return_recv_hook else event.current_stream_wait() + num_rounds = num_rounds if use_expert_overlap else 1 + for round_id in range(num_rounds): + tmp_packed_recv_x, tmp_packed_recv_count, tmp_handle, event, hook = \ + buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, + use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, + use_expert_overlap=use_expert_overlap, num_rounds=num_rounds, round_id=round_id, send_num_sms=send_num_sms, recv_num_sms=recv_num_sms, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, + async_finish=not use_expert_overlap and not return_recv_hook, return_recv_hook=use_expert_overlap or return_recv_hook) + hook() if return_recv_hook or use_expert_overlap else event.current_stream_wait() + if round_id == 0: + packed_recv_x, packed_recv_count, handle = tmp_packed_recv_x, tmp_packed_recv_count, tmp_handle if shrink_test: query_mask_buffer_and_check("dispatch", buffer, mask_status, expected_masked_ranks) packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x @@ -153,16 +160,18 @@ def test_main(num_tokens: int, if zero_copy: buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, + num_rounds = num_rounds if use_expert_overlap else 1 + for round_id in range(num_rounds): + 
combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - use_logfmt=use_logfmt, - async_finish=not return_recv_hook, - zero_copy=zero_copy, - return_recv_hook=return_recv_hook, + async_finish=not use_expert_overlap and not return_recv_hook, + zero_copy=zero_copy and not use_expert_overlap, + return_recv_hook=use_expert_overlap or return_recv_hook, + use_expert_overlap=use_expert_overlap, num_rounds=num_rounds, round_id=round_id, send_num_sms=send_num_sms, recv_num_sms=recv_num_sms, out=out) - hook() if return_recv_hook else event.current_stream_wait() + hook() if return_recv_hook or use_expert_overlap else event.current_stream_wait() if shrink_test: query_mask_buffer_and_check("combine", buffer, mask_status, expected_masked_ranks) if do_check: @@ -277,6 +286,27 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): use_logfmt=args.use_logfmt, shrink_test=args.shrink_test, seed=1) + + do_sbo_test = args.sbo_test + if do_sbo_test: + num_sms_options = (64, 32, 16) + num_rounds_options = (4,) + for (num_sms, num_rounds) in itertools.product(num_sms_options, num_rounds_options): + if rank == 0: + print(f"{num_sms=}, {num_rounds=}") + torch.distributed.barrier(group) + test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + shrink_test=args.shrink_test, + seed=1, + use_expert_overlap=True, send_num_sms=num_sms, recv_num_sms=num_sms, num_rounds=num_rounds) do_pressure_test = args.pressure_test for seed in range(int(1e9) if do_pressure_test else 0): @@ -324,6 +354,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine') parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test') parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode') + parser.add_argument("--sbo-test", action='store_true', help='Whether to do SBO test') args = parser.parse_args() num_processes = args.num_processes
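
Usage note (a sketch, not part of the patch): the per-expert-overlap (SBO) path splits the local experts into `num_rounds` groups and drives dispatch/combine once per round, so that per-round expert GEMMs can overlap the communication of later rounds. The loop below mirrors the pattern added to tests/test_low_latency.py; `expert_out`, `out` and `num_sms` are placeholder names, and the per-round GEMM that would normally run between the `hook()` calls is elided. SBO requires `return_recv_hook=True`, and `num_rounds` must evenly divide the number of local experts.

    # Minimal sketch, assuming `buffer`, `x`, `topk_idx`, `topk_weights`,
    # `num_tokens` and `num_experts` are set up as in tests/test_low_latency.py.
    num_rounds, num_sms = 4, 32

    for round_id in range(num_rounds):
        # Only round 0 allocates the packed receive tensors; later rounds write
        # into the tensors cached on the C++ side, so keep round 0's returns.
        recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
            x, topk_idx, num_tokens, num_experts,
            use_expert_overlap=True, num_rounds=num_rounds, round_id=round_id,
            send_num_sms=num_sms, recv_num_sms=num_sms,
            return_recv_hook=True)      # hook mode is mandatory for SBO
        hook()                          # receive this round's experts
        if round_id == 0:
            packed_recv_x, packed_recv_count, dispatch_handle = recv_x, recv_count, handle
        # ... the expert GEMM for `round_id` would overlap the next dispatch here ...

    for round_id in range(num_rounds):
        combined_x, event, hook = buffer.low_latency_combine(
            expert_out, topk_idx, topk_weights, dispatch_handle,
            use_expert_overlap=True, num_rounds=num_rounds, round_id=round_id,
            send_num_sms=num_sms, recv_num_sms=num_sms,
            return_recv_hook=True, out=out)
        hook()

On the C++ side the low-latency buffer index is flipped only on the last dispatch round (`low_latency_buffer_idx ^= 1` when `round_id == num_rounds - 1`), so all rounds of one layer share the same low-latency buffer and the round-0 handle remains valid for every combine round.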