diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 0789cd58..fdaca8d4 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -1092,8 +1092,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
-                             bool use_fp8, bool round_scale, bool use_ue8m0,
-                             bool async, bool return_recv_hook) {
+                             bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool async, bool return_recv_hook,
+                             bool use_per_tensor_quantization,
+                             const std::optional<torch::Tensor>& static_scale) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
 
@@ -1148,7 +1150,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     void* packed_recv_x_scales_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
 
-    if (use_fp8) {
+    if (use_fp8 and not use_per_tensor_quantization) {
         // TODO: support unaligned cases
         EP_HOST_ASSERT(hidden % 512 == 0);
         if (not use_ue8m0) {
@@ -1163,6 +1165,19 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
     }
 
+    // Check the static quantization parameters
+    if (static_scale.has_value()) {
+        EP_HOST_ASSERT(use_fp8 && "Static scale requires FP8 quantization");
+        auto scale_tensor = static_scale.value();
+        EP_HOST_ASSERT(scale_tensor.is_contiguous());
+        EP_HOST_ASSERT(scale_tensor.scalar_type() == torch::kFloat32);
+        if (use_per_tensor_quantization) {
+            EP_HOST_ASSERT(scale_tensor.numel() == 1);
+        } else {
+            EP_HOST_ASSERT(scale_tensor.numel() == hidden / 128);
+        }
+    }
+
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
@@ -1177,7 +1192,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                next_clean_meta.first, next_clean_meta.second,
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
-                               use_fp8, round_scale, use_ue8m0,
+                               use_fp8, round_scale, use_ue8m0, use_per_tensor_quantization,
+                               static_scale.has_value() ? static_scale->data_ptr<float>() : nullptr,  // pass the static quantization scale
                                workspace, num_device_sms,
                                launch_stream, phases);
     };
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index aa62ccb0..4cf93095 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -149,7 +149,9 @@ struct Buffer {
                          const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool round_scale, bool use_ue8m0,
-                         bool async, bool return_recv_hook);
+                         bool async, bool return_recv_hook,
+                         bool use_per_tensor_quantization,
+                         const std::optional<torch::Tensor>& static_scale);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index d34775fd..a02fd420 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -153,6 +153,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases);
 
+void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
+              int* packed_recv_src_info, int64_t* packed_recv_layout_range,
+              int* packed_recv_count,
+              int* cumulative_local_expert_recv_stats,
+              int64_t* dispatch_wait_recv_cost_stats,
+              void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
+              const void* x, const int64_t* topk_idx,
+              int* next_clean, int num_next_clean_int,
+              int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
+              int num_topk, int num_experts, int rank, int num_ranks,
+              bool use_fp8, bool round_scale, bool use_ue8m0, bool use_per_tensor_quantization, const float* static_scale,
+              void* workspace, int num_device_sms,
+              cudaStream_t stream, int phases);
+
 void combine(void* combined_x,
              void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
              const void* x, const int64_t* topk_idx, const float* topk_weights,
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 12beb7e3..f8feae37 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -36,7 +36,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                   clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, bool kUseUE8M0, int kHidden>
+template <bool kUseFP8, bool kUseUE8M0, bool kUsePerTensorStaticQuantization, int kHidden>
 __global__ __launch_bounds__(1024, 1) void
 dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -50,7 +50,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int num_tokens, int num_max_dispatch_tokens_per_rank,
          int num_topk, int num_experts, int rank, int num_ranks,
          int num_warp_groups, int num_warps_per_group,
-         bool round_scale, int phases) {
+         bool round_scale, const float* static_scale, int phases) {
     const auto sm_id = static_cast<int>(blockIdx.x);
     const auto thread_id = static_cast<int>(threadIdx.x);
     const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -68,14 +68,15 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 
     // FP8 staffs
     constexpr int kNumPerChannels = 128;
-    const int num_scales = kHidden / kNumPerChannels;
+    const int num_scales = kUsePerTensorStaticQuantization ? 1 : kHidden / kNumPerChannels;
     const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
     const size_t hidden_int4 = hidden_bytes / sizeof(int4);
 
     // Message package: hidden data, FP8 scales, index at source
     // NOTES: currently we have 3 reserved int fields for future use
-    using vec_t = std::conditional_t<kUseFP8, int2, int4>;
-    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
+    using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
+    const size_t base_bytes = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
+    const size_t num_bytes_per_msg = (base_bytes + sizeof(int4) - 1) / sizeof(int4) * sizeof(int4);  // Align to a 16-byte boundary
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
 
@@ -108,31 +109,39 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
             thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
 
             // FP8 cast
-            EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
             #pragma unroll
             for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
                 // Read
                 auto int4_value = __ldg(x_int4 + i);
 
                 if constexpr (kUseFP8) {
-                    // Calculate local amax
+                    float scale, scale_inv;
                     auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
                     float fp32_values[kNumElemsPerRead];
-                    float amax = kFP8Margin, scale, scale_inv;
-                    #pragma unroll
-                    for (int j = 0; j < kNumElemsPerRead; ++ j) {
-                        fp32_values[j] = static_cast<float>(bf16_values[j]);
-                        amax = fmaxf(amax, fabsf(fp32_values[j]));
-                    }
-
-                    // Reduce amax and scale
-                    EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
-                    amax = warp_reduce_max<16>(amax);
-                    calculate_fp8_scales(amax, scale, scale_inv, round_scale);
-                    if (lane_id == 0 or lane_id == 16)
-                        rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
+                    if constexpr (kUsePerTensorStaticQuantization) {
+                        EP_DEVICE_ASSERT(static_scale != nullptr);
+                        scale_inv = static_scale[0];
+                        scale = 1.0f / scale_inv;
+                        for (int j = 0; j < kNumElemsPerRead; ++ j) {
+                            fp32_values[j] = static_cast<float>(bf16_values[j]);
+                        }
 
-                    // Cast into send buffer
+                    } else {
+                        float amax = kFP8Margin;
+                        #pragma unroll
+                        for (int j = 0; j < kNumElemsPerRead; ++ j) {
+                            fp32_values[j] = static_cast<float>(bf16_values[j]);
+                            amax = fmaxf(amax, fabsf(fp32_values[j]));
+                        }
+                        // Reduce amax and scale
+                        EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
+                        amax = warp_reduce_max<16>(amax);
+                        calculate_fp8_scales(amax, scale, scale_inv, round_scale);
+                        if (lane_id == 0 or lane_id == 16)
+                            rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
+                    }
+
+                    // Cast into send buffer using per-tensor quantization logic
                     vec_t int2_value;
                     auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
                     #pragma unroll
@@ -309,7 +318,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
 
         // Copy scales
-        if constexpr (kUseFP8) {
+        if constexpr (kUseFP8 and not kUsePerTensorStaticQuantization) {
             // Equivalent CuTe layout:
             //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
             const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
@@ -344,7 +353,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
-              bool use_fp8, bool round_scale, bool use_ue8m0,
+              bool use_fp8, bool round_scale, bool use_ue8m0, bool use_per_tensor_quantization, const float* static_scale,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
@@ -367,11 +376,15 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
 #define DISPATCH_LAUNCH_CASE(hidden) { \
-auto dispatch_func = dispatch
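Note on the quantization change above: with kUsePerTensorStaticQuantization, every element is cast with one static scale (scale = 1.0f / static_scale[0]), so no per-128-channel scales are produced and the receive-side scale copy is skipped. A minimal standalone sketch of that cast is below; the kernel name quantize_per_tensor_fp8_e4m3, its launch shape, and the E4M3 interpretation are illustrative assumptions, not part of the patch.

#include <cuda_bf16.h>
#include <cuda_fp8.h>

// Illustrative sketch only (not from the patch): quantize a BF16 buffer to FP8 with a
// single per-tensor scale. `static_scale` holds one scale_inv value (numel == 1),
// matching the host-side assertion above; the E4M3 interpretation is assumed here.
__global__ void quantize_per_tensor_fp8_e4m3(const nv_bfloat16* __restrict__ x,
                                             __nv_fp8_storage_t* __restrict__ out,
                                             const float* __restrict__ static_scale,
                                             int num_elems) {
    const float scale_inv = static_scale[0];  // dequantization factor supplied by the caller
    const float scale = 1.0f / scale_inv;     // quantization factor, as in the per-tensor branch
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems; i += gridDim.x * blockDim.x) {
        const float v = __bfloat162float(x[i]) * scale;                // BF16 -> FP32, then scale
        out[i] = __nv_cvt_float_to_fp8(v, __NV_SATFINITE, __NV_E4M3);  // saturating FP32 -> FP8 cast
    }
}

A launch of the form quantize_per_tensor_fp8_e4m3<<<num_blocks, 1024, 0, stream>>>(x, out, static_scale, num_elems) would exercise it; the real dispatch kernel instead fuses this cast into the RDMA send path shown in the diff.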