From 05627796f67a1d9a535c48028e549e514c7b0f18 Mon Sep 17 00:00:00 2001 From: bryan Date: Tue, 9 Dec 2025 16:20:25 +0800 Subject: [PATCH 1/4] emb_segment_reduce backward kernel opt --- csrc/ops/embedding_segment_reduce_kernel.cu | 189 ++++++++++++++++++-- 1 file changed, 176 insertions(+), 13 deletions(-) diff --git a/csrc/ops/embedding_segment_reduce_kernel.cu b/csrc/ops/embedding_segment_reduce_kernel.cu index 2c38a57..e836fd9 100644 --- a/csrc/ops/embedding_segment_reduce_kernel.cu +++ b/csrc/ops/embedding_segment_reduce_kernel.cu @@ -113,6 +113,48 @@ __global__ void segment_reduce_forward_kernel( } } +//============================================================================= +// Helper kernel: Compute segment IDs for each input element +//============================================================================= +template +__global__ void compute_segment_ids_kernel( + const offset_t* __restrict__ offsets, int32_t* __restrict__ segment_ids, + int64_t B, int64_t S) { + // Each block handles one segment and fills segment IDs + int64_t seg = blockIdx.x; + if (seg >= S - 1) return; + + offset_t start = offsets[seg]; + offset_t end = offsets[seg + 1]; + + for (offset_t i = start + threadIdx.x; i < end; i += blockDim.x) { + segment_ids[i] = seg; + } +} + +//============================================================================= +// Helper kernel: Compute weight sums for each segment (for MEAN mode) +//============================================================================= +template +__global__ void compute_weight_sums_kernel( + const scalar_t* __restrict__ weight, const offset_t* __restrict__ offsets, + scalar_t* __restrict__ weight_sums, int64_t S) { + int64_t seg = blockIdx.x * blockDim.x + threadIdx.x; + if (seg >= S - 1) return; + + offset_t start = offsets[seg]; + offset_t end = offsets[seg + 1]; + + scalar_t sum = 0; + for (offset_t i = start; i < end; ++i) { + sum += weight[i]; + } + weight_sums[seg] = sum; +} + +//============================================================================= +// Original: Segment-parallel backward kernel (kept for reference/fallback) +//============================================================================= template __global__ void segment_reduce_backward_kernel( @@ -196,6 +238,73 @@ __global__ void segment_reduce_backward_kernel( } } +//============================================================================= +// Optimized: High-occupancy element-parallel backward kernel +// Key optimizations: +// 1. Element-parallel instead of segment-parallel for better load balancing +// 2. Precomputed segment IDs avoid per-element binary search +// 3. Larger grid size for better latency hiding on modern GPUs +// 4. 
Vectorized loads (float4) for better memory throughput +//============================================================================= +template +__global__ void high_occupancy_backward_kernel( + const scalar_t* __restrict__ grad_output, + const scalar_t* __restrict__ weight, + const int64_t* __restrict__ reverse_indices, + const int32_t* __restrict__ segment_ids, + const scalar_t* __restrict__ weight_sums, + const offset_t* __restrict__ offsets, scalar_t* grad_unique_emb, int64_t B, + int64_t S, int64_t D) { + // Total work items: B * D / 4 (with PACK_SIZE=4) + const int64_t total_work = (B * D) >> 2; // Divide by 4 + const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t stride = blockDim.x * gridDim.x; + + // Process multiple elements per thread for better ILP + #pragma unroll 1 // Don't unroll outer loop to save registers + for (int64_t work_idx = tid; work_idx < total_work; work_idx += stride) { + // Decode flat index to (input_idx, dp) + const int64_t flat_idx = work_idx << 2; // Multiply by 4 + const int64_t input_idx = flat_idx / D; + const int64_t dp = flat_idx - input_idx * D; // Faster than modulo + + // Load segment and unique index (these are per-row, not per-element) + const int32_t seg = segment_ids[input_idx]; + const int64_t raw_idx = reverse_indices[input_idx]; + + // Load gradient (4 floats vectorized) + float4 g_vec; + if constexpr (mode == ReduceMode::TILE) { + g_vec = *reinterpret_cast(grad_output + input_idx * D + dp); + } else { + g_vec = *reinterpret_cast(grad_output + seg * D + dp); + } + + // Compute weight factor + scalar_t w_base; + if constexpr (USE_WEIGHT) { + w_base = weight[input_idx]; + if constexpr (mode == ReduceMode::MEAN) { + w_base /= weight_sums[seg]; + } + } else { + if constexpr (mode == ReduceMode::MEAN) { + w_base = static_cast(1) / static_cast(offsets[seg + 1] - offsets[seg]); + } else { + w_base = static_cast(1); + } + } + + // Write results with atomics (4 consecutive floats) + scalar_t* out_ptr = grad_unique_emb + raw_idx * D + dp; + atomicAdd(out_ptr + 0, g_vec.x * w_base); + atomicAdd(out_ptr + 1, g_vec.y * w_base); + atomicAdd(out_ptr + 2, g_vec.z * w_base); + atomicAdd(out_ptr + 3, g_vec.w * w_base); + } +} + #define FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, use_weight, vec_size) \ segment_reduce_forward_kernel \ @@ -243,32 +352,86 @@ void segment_reduce_forward_kernel_launcher( N, S, D); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); +#define LAUNCH_HIGH_OCCUPANCY_BACKWARD_KERNEL(scalar_t, offset_t, mode, \ + use_weight) \ + high_occupancy_backward_kernel \ + <<>>( \ + grad_output, weight, reverse_indices, segment_ids, weight_sums, \ + offsets, grad_unique_emb, B, S, D); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); + template void segment_reduce_backward_kernel_launcher( const scalar_t* grad_output, const scalar_t* weight, bool use_weight, const int64_t* reverse_indices, const offset_t* offsets, scalar_t* grad_unique_emb, int64_t B, int64_t N, int64_t S, int64_t D) { - int64_t block_size = 256; - int64_t block_num = get_sm_count() * 8; - block_num = std::min(block_num, S); + auto stream = at::cuda::getCurrentCUDAStream(); + int sm_count = get_sm_count(); + // Use high-occupancy kernel for D % 4 == 0 (most common case) + // Fall back to original segment-parallel kernel for other cases if (D % 4 == 0) { - if (use_weight) { - LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 4) - } else { - LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 4) + int64_t block_size = 256; + int64_t total_work = (B * D) / 4; + // Use larger 
grid size for better latency hiding + int64_t grid_size = + std::min((total_work + block_size - 1) / block_size, + static_cast(sm_count * 32)); + + // Allocate temporary buffers for segment_ids and weight_sums + int32_t* segment_ids = cuda_malloc(B * sizeof(int32_t), stream); + scalar_t* weight_sums = nullptr; + if constexpr (mode == ReduceMode::MEAN) { + if (use_weight) { + weight_sums = cuda_malloc((S - 1) * sizeof(scalar_t), stream); + } } - } else if (D % 2 == 0) { + + // Precompute segment IDs + compute_segment_ids_kernel + <<>>(offsets, segment_ids, B, S); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // Precompute weight sums for MEAN mode with weights + if constexpr (mode == ReduceMode::MEAN) { + if (use_weight) { + int64_t ws_grid = (S - 1 + 255) / 256; + compute_weight_sums_kernel + <<>>(weight, offsets, weight_sums, S); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + } + + // Launch high-occupancy kernel if (use_weight) { - LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 2) + LAUNCH_HIGH_OCCUPANCY_BACKWARD_KERNEL(scalar_t, offset_t, mode, true) } else { - LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 2) + LAUNCH_HIGH_OCCUPANCY_BACKWARD_KERNEL(scalar_t, offset_t, mode, false) + } + + // Free temporary buffers + delete_cuda_ptr(segment_ids); + if (weight_sums != nullptr) { + delete_cuda_ptr(weight_sums); } } else { - if (use_weight) { - LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 1) + // Fall back to original segment-parallel kernel for D % 4 != 0 + int64_t block_size = 256; + int64_t block_num = sm_count * 8; + block_num = std::min(block_num, S); + + if (D % 2 == 0) { + if (use_weight) { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 2) + } else { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 2) + } } else { - LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 1) + if (use_weight) { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 1) + } else { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 1) + } } } } From ae80f42f1c716ef98245b14e841806b4d7840caf Mon Sep 17 00:00:00 2001 From: bryan Date: Tue, 9 Dec 2025 17:18:55 +0800 Subject: [PATCH 2/4] add ROCm macro --- csrc/ops/embedding_segment_reduce_kernel.cu | 123 +++++++++++--------- 1 file changed, 69 insertions(+), 54 deletions(-) diff --git a/csrc/ops/embedding_segment_reduce_kernel.cu b/csrc/ops/embedding_segment_reduce_kernel.cu index e836fd9..025a072 100644 --- a/csrc/ops/embedding_segment_reduce_kernel.cu +++ b/csrc/ops/embedding_segment_reduce_kernel.cu @@ -113,48 +113,6 @@ __global__ void segment_reduce_forward_kernel( } } -//============================================================================= -// Helper kernel: Compute segment IDs for each input element -//============================================================================= -template -__global__ void compute_segment_ids_kernel( - const offset_t* __restrict__ offsets, int32_t* __restrict__ segment_ids, - int64_t B, int64_t S) { - // Each block handles one segment and fills segment IDs - int64_t seg = blockIdx.x; - if (seg >= S - 1) return; - - offset_t start = offsets[seg]; - offset_t end = offsets[seg + 1]; - - for (offset_t i = start + threadIdx.x; i < end; i += blockDim.x) { - segment_ids[i] = seg; - } -} - -//============================================================================= -// Helper kernel: Compute weight sums for each segment (for MEAN mode) -//============================================================================= -template -__global__ void compute_weight_sums_kernel( - 
const scalar_t* __restrict__ weight, const offset_t* __restrict__ offsets, - scalar_t* __restrict__ weight_sums, int64_t S) { - int64_t seg = blockIdx.x * blockDim.x + threadIdx.x; - if (seg >= S - 1) return; - - offset_t start = offsets[seg]; - offset_t end = offsets[seg + 1]; - - scalar_t sum = 0; - for (offset_t i = start; i < end; ++i) { - sum += weight[i]; - } - weight_sums[seg] = sum; -} - -//============================================================================= -// Original: Segment-parallel backward kernel (kept for reference/fallback) -//============================================================================= template __global__ void segment_reduce_backward_kernel( @@ -238,17 +196,43 @@ __global__ void segment_reduce_backward_kernel( } } -//============================================================================= -// Optimized: High-occupancy element-parallel backward kernel -// Key optimizations: -// 1. Element-parallel instead of segment-parallel for better load balancing -// 2. Precomputed segment IDs avoid per-element binary search -// 3. Larger grid size for better latency hiding on modern GPUs -// 4. Vectorized loads (float4) for better memory throughput -//============================================================================= +#ifdef __HIP_PLATFORM_AMD__ +template +__global__ void compute_segment_ids_kernel( + const offset_t* __restrict__ offsets, int32_t* __restrict__ segment_ids, + int64_t B, int64_t S) { + // Each block handles one segment and fills segment IDs + int64_t seg = blockIdx.x; + if (seg >= S - 1) return; + + offset_t start = offsets[seg]; + offset_t end = offsets[seg + 1]; + + for (offset_t i = start + threadIdx.x; i < end; i += blockDim.x) { + segment_ids[i] = seg; + } +} + +template +__global__ void compute_weight_sums_kernel( + const scalar_t* __restrict__ weight, const offset_t* __restrict__ offsets, + scalar_t* __restrict__ weight_sums, int64_t S) { + int64_t seg = blockIdx.x * blockDim.x + threadIdx.x; + if (seg >= S - 1) return; + + offset_t start = offsets[seg]; + offset_t end = offsets[seg + 1]; + + scalar_t sum = 0; + for (offset_t i = start; i < end; ++i) { + sum += weight[i]; + } + weight_sums[seg] = sum; +} + template -__global__ void high_occupancy_backward_kernel( +__global__ void segment_reduce_backward_kernel_rocm( const scalar_t* __restrict__ grad_output, const scalar_t* __restrict__ weight, const int64_t* __restrict__ reverse_indices, @@ -290,7 +274,8 @@ __global__ void high_occupancy_backward_kernel( } } else { if constexpr (mode == ReduceMode::MEAN) { - w_base = static_cast(1) / static_cast(offsets[seg + 1] - offsets[seg]); + w_base = static_cast(1) / + static_cast(offsets[seg + 1] - offsets[seg]); } else { w_base = static_cast(1); } @@ -304,6 +289,7 @@ __global__ void high_occupancy_backward_kernel( atomicAdd(out_ptr + 3, g_vec.w * w_base); } } +#endif #define FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, use_weight, vec_size) \ segment_reduce_forward_kernel \ + segment_reduce_backward_kernel_rocm \ <<>>( \ grad_output, weight, reverse_indices, segment_ids, weight_sums, \ offsets, grad_unique_emb, B, S, D); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); +#endif template void segment_reduce_backward_kernel_launcher( @@ -368,6 +356,32 @@ void segment_reduce_backward_kernel_launcher( auto stream = at::cuda::getCurrentCUDAStream(); int sm_count = get_sm_count(); + +#ifndef __HIP_PLATFORM_AMD__ + int64_t block_size = 256; + int64_t block_num = get_sm_count() * 8; + block_num = std::min(block_num, S); + + if (D % 4 == 0) { + if 
(use_weight) { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 4) + } else { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 4) + } + } else if (D % 2 == 0) { + if (use_weight) { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 2) + } else { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 2) + } + } else { + if (use_weight) { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 1) + } else { + LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 1) + } + } +#else // Use high-occupancy kernel for D % 4 == 0 (most common case) // Fall back to original segment-parallel kernel for other cases if (D % 4 == 0) { @@ -434,6 +448,7 @@ void segment_reduce_backward_kernel_launcher( } } } +#endif } at::Tensor segment_reduce_forward(at::Tensor unique_emb, c10::optional weight, From fdb6474dfd3d2b4862d48dca117b1afed52cf6ee Mon Sep 17 00:00:00 2001 From: bryan Date: Tue, 9 Dec 2025 18:43:41 +0800 Subject: [PATCH 3/4] fix macro issue and code format --- csrc/ops/embedding_segment_reduce_kernel.cu | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/csrc/ops/embedding_segment_reduce_kernel.cu b/csrc/ops/embedding_segment_reduce_kernel.cu index 025a072..4ed21d0 100644 --- a/csrc/ops/embedding_segment_reduce_kernel.cu +++ b/csrc/ops/embedding_segment_reduce_kernel.cu @@ -260,7 +260,8 @@ __global__ void segment_reduce_backward_kernel_rocm( // Load gradient (4 floats vectorized) float4 g_vec; if constexpr (mode == ReduceMode::TILE) { - g_vec = *reinterpret_cast(grad_output + input_idx * D + dp); + g_vec = + *reinterpret_cast(grad_output + input_idx * D + dp); } else { g_vec = *reinterpret_cast(grad_output + seg * D + dp); } @@ -339,9 +340,9 @@ void segment_reduce_forward_kernel_launcher( C10_CUDA_KERNEL_LAUNCH_CHECK(); #ifdef __HIP_PLATFORM_AMD__ -#define LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, \ +#define LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, \ use_weight) \ - segment_reduce_backward_kernel_rocm \ + segment_reduce_backward_kernel_rocm \ <<>>( \ grad_output, weight, reverse_indices, segment_ids, weight_sums, \ offsets, grad_unique_emb, B, S, D); \ @@ -403,24 +404,24 @@ void segment_reduce_backward_kernel_launcher( // Precompute segment IDs compute_segment_ids_kernel - <<>>(offsets, segment_ids, B, S); + <<>>(offsets, segment_ids, B, S); C10_CUDA_KERNEL_LAUNCH_CHECK(); // Precompute weight sums for MEAN mode with weights if constexpr (mode == ReduceMode::MEAN) { if (use_weight) { - int64_t ws_grid = (S - 1 + 255) / 256; + int64_t ws_grid = (S - 1 + block_size - 1) / block_size; compute_weight_sums_kernel - <<>>(weight, offsets, weight_sums, S); + <<>>(weight, offsets, weight_sums, S); C10_CUDA_KERNEL_LAUNCH_CHECK(); } } // Launch high-occupancy kernel if (use_weight) { - LAUNCH_HIGH_OCCUPANCY_BACKWARD_KERNEL(scalar_t, offset_t, mode, true) + LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, true) } else { - LAUNCH_HIGH_OCCUPANCY_BACKWARD_KERNEL(scalar_t, offset_t, mode, false) + LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, false) } // Free temporary buffers From 7f5c1d4c02edab7d92e730cc3e1de974e4b0924e Mon Sep 17 00:00:00 2001 From: bryan Date: Mon, 15 Dec 2025 11:33:02 +0800 Subject: [PATCH 4/4] fix empty tensor issue --- csrc/ops/embedding_segment_reduce_kernel.cu | 40 +++++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/csrc/ops/embedding_segment_reduce_kernel.cu b/csrc/ops/embedding_segment_reduce_kernel.cu index 4ed21d0..060914c 100644 --- 
a/csrc/ops/embedding_segment_reduce_kernel.cu +++ b/csrc/ops/embedding_segment_reduce_kernel.cu @@ -217,17 +217,35 @@ template __global__ void compute_weight_sums_kernel( const scalar_t* __restrict__ weight, const offset_t* __restrict__ offsets, scalar_t* __restrict__ weight_sums, int64_t S) { - int64_t seg = blockIdx.x * blockDim.x + threadIdx.x; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // Each block handles one segment + int64_t seg = blockIdx.x; if (seg >= S - 1) return; offset_t start = offsets[seg]; offset_t end = offsets[seg + 1]; + int64_t length = end - start; + + scalar_t weight_sum = 0; + // The loop ending condition ensures all threads participate in BlockReduce + for (int64_t i = threadIdx.x; i / blockDim.x * blockDim.x < length; + i += blockDim.x) { + scalar_t w = 0; + if (i < length) { + w = weight[start + i]; + } + scalar_t res = BlockReduce(temp_storage).Sum(w); + if (threadIdx.x == 0) { + weight_sum += res; + } + __syncthreads(); + } - scalar_t sum = 0; - for (offset_t i = start; i < end; ++i) { - sum += weight[i]; + if (threadIdx.x == 0) { + weight_sums[seg] = weight_sum; } - weight_sums[seg] = sum; } template - <<>>(weight, offsets, weight_sums, S); + <<>>(weight, offsets, weight_sums, S); C10_CUDA_KERNEL_LAUNCH_CHECK(); } }
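
Note on the per-row gradient scale in the element-parallel backward kernel (PATCH 1/4): each input row i of a segment receives an upstream gradient row (the segment's row for SUM/MEAN, its own row for TILE) multiplied by a scalar w_base — w_i when weights are used with SUM/TILE, w_i / sum_j w_j for weighted MEAN, 1/(segment length) for unweighted MEAN, and 1 otherwise. A minimal host-side sketch of that factor, usable as a CPU cross-check; the function name and the ReduceMode mirror below are illustrative, not part of the patch:

#include <cassert>
#include <cstdint>
#include <vector>

enum class ReduceMode { SUM, MEAN, TILE };  // illustrative mirror of the kernel's enum

// Scale applied to the upstream gradient row for input row `i` of the
// segment [start, end). Mirrors the w_base computation in the backward
// kernel (assumption here: float weights).
float row_grad_scale(ReduceMode mode, bool use_weight,
                     const std::vector<float>& weight,
                     int64_t start, int64_t end, int64_t i) {
  assert(i >= start && i < end);
  if (use_weight) {
    float w = weight[i];
    if (mode == ReduceMode::MEAN) {
      float sum = 0.f;
      for (int64_t j = start; j < end; ++j) sum += weight[j];
      return w / sum;  // weighted mean: w_i / sum_j w_j
    }
    return w;  // SUM / TILE with weights: just w_i
  }
  if (mode == ReduceMode::MEAN) {
    return 1.f / static_cast<float>(end - start);  // unweighted mean: 1 / segment length
  }
  return 1.f;  // SUM / TILE without weights
}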
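
Note on the cub::BlockReduce loop introduced in PATCH 4/4: BlockReduce is a block-wide collective, so every thread of the block must call Sum() the same number of times. The loop bound i / blockDim.x * blockDim.x < length rounds the trip count up to a multiple of the block size, and threads whose index falls past the segment end simply contribute zero, which is what makes the kernel safe for empty or short segments. A standalone sketch of the same pattern, assuming a compile-time block size equal to blockDim.x; the kernel name and buffer names are illustrative:

#include <cub/block/block_reduce.cuh>
#include <cstdint>

// One block sums `length` floats from `data` into `*out`, padding the loop so
// that all threads participate in every BlockReduce::Sum() call.
template <int BLOCK_SIZE>
__global__ void padded_block_sum_kernel(const float* data, int64_t length,
                                        float* out) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  float total = 0.f;
  // For thread t, i = t + k * BLOCK_SIZE, so i / BLOCK_SIZE * BLOCK_SIZE
  // equals k * BLOCK_SIZE for every thread: the loop condition is identical
  // across the block and no thread exits the collective early.
  for (int64_t i = threadIdx.x; i / BLOCK_SIZE * BLOCK_SIZE < length;
       i += BLOCK_SIZE) {
    float v = (i < length) ? data[i] : 0.f;
    float partial = BlockReduce(temp_storage).Sum(v);  // block-wide collective
    if (threadIdx.x == 0) total += partial;  // Sum() result is valid on thread 0 only
    __syncthreads();  // temp_storage is reused on the next iteration
  }
  if (threadIdx.x == 0) *out = total;
}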