From be034122494ef6e5fb8da5a9abb571aa1135e9d9 Mon Sep 17 00:00:00 2001
From: bryan
Date: Mon, 1 Dec 2025 21:38:35 +0800
Subject: [PATCH] embedding segment reduce forward opt

---
 csrc/ops/embedding_segment_reduce_kernel.cu | 216 ++++++++++++++++++++
 tests/ops/emb_segment_reduce_test.py        |   2 +-
 2 files changed, 217 insertions(+), 1 deletion(-)

diff --git a/csrc/ops/embedding_segment_reduce_kernel.cu b/csrc/ops/embedding_segment_reduce_kernel.cu
index 2c38a57..eef4984 100644
--- a/csrc/ops/embedding_segment_reduce_kernel.cu
+++ b/csrc/ops/embedding_segment_reduce_kernel.cu
@@ -113,6 +113,200 @@ __global__ void segment_reduce_forward_kernel(
  }
}

+#ifdef __HIP_PLATFORM_AMD__
+#define LDS_LIMIT_BYTES (64u * 1024u)
+
__device__ __forceinline__ float warp_reduce_sum(float val, int width) {
#pragma unroll
  for (int offset = width >> 1; offset > 0; offset >>= 1) {
    val += __shfl_down(val, offset, width);
  }
  return val;
}

template <typename scalar_t, typename offset_t, ReduceMode mode,
          bool USE_WEIGHT, int PACK_SIZE>
__global__ void segment_reduce_forward_kernel_rocm(
    const scalar_t* __restrict__ unique_emb,
    const scalar_t* __restrict__ weight,
    const int64_t* __restrict__ reverse_indices,
    const offset_t* __restrict__ offsets, scalar_t* output, int64_t B,
    int64_t N, int64_t S, int64_t D) {
  using AP = Packer<scalar_t, PACK_SIZE>;
  // NOTE: the BlockReduce thread count must match the block_size used at
  // launch; 256 is assumed here.
  using BlockReduce = hipcub::BlockReduce<scalar_t, 256>;

  extern __shared__ __align__(sizeof(scalar_t)) unsigned char shared_raw[];
  scalar_t* __restrict__ smem = reinterpret_cast<scalar_t*>(shared_raw);

  // These shared variables will be optimized out by the compiler
  // if we are not computing the weighted mean.
  typename BlockReduce::TempStorage* temp_storage =
      reinterpret_cast<typename BlockReduce::TempStorage*>(smem);
  scalar_t* s_weight_sum = reinterpret_cast<scalar_t*>(
      reinterpret_cast<unsigned char*>(temp_storage) + sizeof(*temp_storage));

  const int warp = warpSize;
  const int waves = blockDim.x / warp;
  const int lane = threadIdx.x & (warp - 1);
  const int wave_id = threadIdx.x / warp;

  // Precompute stride increments to reduce div/mod inside the loop
  const int64_t stride_items =
      static_cast<int64_t>(blockDim.x) * static_cast<int64_t>(PACK_SIZE);
  const int64_t delta_q = (D > 0) ? (stride_items / D) : 0;
  const int64_t delta_dp = (D > 0) ? (stride_items % D) : 0;

  // Initial dp/q per thread
  int64_t q0 = ((int64_t)threadIdx.x * (int64_t)PACK_SIZE) / D;
  int64_t dp0 = ((int64_t)threadIdx.x * (int64_t)PACK_SIZE) % D;

  // Small chunk to improve ILP
  constexpr int kChunk = 2;

  for (int s = blockIdx.x; s < S - 1; s += gridDim.x) {
    offset_t start = offsets[s];
    offset_t end = offsets[s + 1];
    int64_t length = end - start;
    int64_t total_size = length * D;

    if constexpr (mode == ReduceMode::TILE) {
      // TILE mode: direct scaled write per idx, vectorized
      int64_t q = q0;
      int64_t dp = dp0;
      for (int64_t i_base = threadIdx.x; i_base * PACK_SIZE < total_size;
           i_base += blockDim.x * kChunk) {
#pragma unroll
        for (int c = 0; c < kChunk; ++c) {
          int64_t it = i_base + c * blockDim.x;
          if (!(it * PACK_SIZE < total_size)) break;
          int64_t idx = start + q;
          int64_t raw_idx = reverse_indices[idx];
          scalar_t w = scalar_t(1);
          if constexpr (USE_WEIGHT) { w = weight[idx]; }

          typename AP::type a_vec;
          typename AP::type b_vec;
          AP::load(unique_emb + raw_idx * D + dp, a_vec);
#pragma unroll
          for (int j = 0; j < PACK_SIZE; j++) {
            auto a_val = AP::get_element(a_vec, j);
            AP::set_element(b_vec, j, a_val * w);
          }
          AP::store(output + idx * D + dp, b_vec);

          dp += delta_dp;
          q += delta_q;
          if (dp >= D) { dp -= D; ++q; }
        }
      }
    } else {
      scalar_t weight_sum = 0;
      if constexpr (USE_WEIGHT && mode == ReduceMode::MEAN) {
        // Every thread must reach the collective BlockReduce call, so the
        // loop bound is rounded up to a full block and the tail is padded
        // with zeros.
        for (int64_t i = threadIdx.x; i / blockDim.x * blockDim.x < length;
             i += blockDim.x) {
          scalar_t w = 0;
          if (i < length) {
            w = weight[start + i];
          }
          scalar_t res = BlockReduce(*temp_storage).Sum(w);
          if (threadIdx.x == 0) {
            weight_sum += res;
          }
          __syncthreads();
        }
        if (threadIdx.x == 0) {
          *s_weight_sum = weight_sum;
        }
        __syncthreads();
        weight_sum = *s_weight_sum;
      }

      const int max_cols =
          static_cast<int>(LDS_LIMIT_BYTES / sizeof(scalar_t) / blockDim.x);

      for (int d_base = 0; d_base < D; d_base += max_cols) {
        int this_cols = max_cols;
        if (d_base + this_cols > D) this_cols = D - d_base;
        if (this_cols <= 0) break;

        // Initialize LDS tile: layout [col, thread]
        for (int idx = threadIdx.x; idx < this_cols * blockDim.x;
             idx += blockDim.x) {
          smem[idx] = scalar_t(0);
        }
        __syncthreads();

        // Reset per-thread q/dp for this tile
        int64_t q = q0;
        int64_t dp = dp0;

        // Walk the segment in global order but only accumulate columns that
        // fall inside this tile
        for (int64_t i_base = threadIdx.x; i_base * PACK_SIZE < total_size;
             i_base += blockDim.x * kChunk) {
#pragma unroll
          for (int c = 0; c < kChunk; ++c) {
            int64_t it = i_base + c * blockDim.x;
            if (!(it * PACK_SIZE < total_size)) break;

            int64_t idx = start + q;
            int64_t raw_idx = reverse_indices[idx];
            scalar_t w = scalar_t(1);
            if constexpr (USE_WEIGHT) {
              w = weight[idx];
            } else {
              weight_sum = static_cast<scalar_t>(length);
            }
            if constexpr (mode == ReduceMode::MEAN) {
              w = w / weight_sum;
            }

            typename AP::type a_vec;
            AP::load(unique_emb + raw_idx * D + dp, a_vec);
#pragma unroll
            for (int j = 0; j < PACK_SIZE; j++) {
              int64_t gcol = dp + j;
              auto a_val = AP::get_element(a_vec, j);
              auto res = a_val * w;
              if (gcol >= d_base && gcol < d_base + this_cols) {
                int lcol = static_cast<int>(gcol - d_base);
                smem[lcol * blockDim.x + threadIdx.x] += res;
              }
            }

            dp += delta_dp;
            q += delta_q;
            if (dp >= D) { dp -= D; ++q; }
          }
        }

        __syncthreads();
        for (int lc = 0; lc < this_cols; ++lc) {
          scalar_t* row = smem + lc * blockDim.x;
          scalar_t val = row[threadIdx.x];
          // wavefront reduction
          val = warp_reduce_sum(val, warp);
          if (lane == 0) {
            row[wave_id] = val;
          }
          __syncthreads();
          if (threadIdx.x == 0) {
            scalar_t total = scalar_t(0);
#pragma unroll
            for (int w = 0; w < waves; ++w) {
              total += row[w];
            }
            output[s * D + (d_base + lc)] = total;
          }
          __syncthreads();
        }
      }
    }
  }
}
+#endif
+
template <typename scalar_t, typename offset_t, ReduceMode mode,
          bool USE_WEIGHT, int PACK_SIZE>
__global__ void segment_reduce_backward_kernel(
@@ -196,12 +196,22 @@ __global__ void segment_reduce_backward_kernel(
  }
}

+#ifndef __HIP_PLATFORM_AMD__
#define FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, use_weight, vec_size) \
  segment_reduce_forward_kernel<scalar_t, offset_t, mode, use_weight, vec_size> \
      <<<block_num, block_size>>>( \
          unique_emb, weight, reverse_indices, offsets, output, B, N, S, D); \
  C10_CUDA_KERNEL_LAUNCH_CHECK();
+#else
+#define FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, use_weight, vec_size) \
+  segment_reduce_forward_kernel_rocm<scalar_t, offset_t, mode, use_weight, vec_size> \
+      <<<block_num, block_size, shared_memory_size>>>( \
+          unique_emb, weight, reverse_indices, offsets, output, B, N, S, D); \
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+#endif

template <typename scalar_t, typename offset_t, ReduceMode mode>
void segment_reduce_forward_kernel_launcher(
@@ -212,6 +212,18 @@ void segment_reduce_forward_kernel_launcher(
  int64_t block_num = 65536;
  block_num = std::min(block_num, S);

+#ifdef __HIP_PLATFORM_AMD__
+  size_t shared_memory_size = 0;
+  if (mode != ReduceMode::TILE) {
+    size_t max_cols =
+        LDS_LIMIT_BYTES / sizeof(scalar_t) / static_cast<size_t>(block_size);
+    if (max_cols < 1) max_cols = 1;
+    size_t cols = std::min(static_cast<size_t>(D), max_cols);
+    shared_memory_size =
+        cols * static_cast<size_t>(block_size) * sizeof(scalar_t);
+  }
+#endif
+
  if (D % 4 == 0) {
    if (use_weight) {
      FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, true, 4)
diff --git a/tests/ops/emb_segment_reduce_test.py b/tests/ops/emb_segment_reduce_test.py
index 1d64078..b5bc4a8 100644
--- a/tests/ops/emb_segment_reduce_test.py
+++ b/tests/ops/emb_segment_reduce_test.py
@@ -1,6 +1,6 @@
 import unittest
-
 import torch
+import recis


def equal(a: torch.Tensor, b: torch.Tensor):
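
Note (not part of the patch): the tail of the non-TILE path reduces each LDS column in two steps, a wavefront shuffle reduction followed by thread 0 summing the per-wave partials. The sketch below is a minimal, self-contained HIP program showing that pattern in isolation; the kernel name block_sum, the single-block launch, and the 256-thread block size are illustrative assumptions, not code from this repository.

#include <hip/hip_runtime.h>
#include <cstdio>

__device__ __forceinline__ float warp_reduce_sum(float val, int width) {
  // Same shuffle reduction as the patch: halve the active span each step.
  for (int offset = width >> 1; offset > 0; offset >>= 1) {
    val += __shfl_down(val, offset, width);
  }
  return val;
}

// One block sums n floats: each wavefront reduces its lanes with shuffles,
// lane 0 of each wavefront parks its partial in LDS, and thread 0 adds the
// per-wave partials, mirroring the per-column reduction in the ROCm kernel.
__global__ void block_sum(const float* in, float* out, int n) {
  __shared__ float partials[32];  // one slot per wavefront
  const int warp = warpSize;
  const int lane = threadIdx.x & (warp - 1);
  const int wave_id = threadIdx.x / warp;
  const int waves = blockDim.x / warp;

  float val = (threadIdx.x < n) ? in[threadIdx.x] : 0.0f;
  val = warp_reduce_sum(val, warp);
  if (lane == 0) partials[wave_id] = val;
  __syncthreads();
  if (threadIdx.x == 0) {
    float total = 0.0f;
    for (int w = 0; w < waves; ++w) total += partials[w];
    *out = total;
  }
}

int main() {
  const int n = 256;
  float h_in[n], h_out = 0.0f;
  for (int i = 0; i < n; ++i) h_in[i] = 1.0f;
  float *d_in = nullptr, *d_out = nullptr;
  hipMalloc((void**)&d_in, n * sizeof(float));
  hipMalloc((void**)&d_out, sizeof(float));
  hipMemcpy(d_in, h_in, n * sizeof(float), hipMemcpyHostToDevice);
  block_sum<<<1, n>>>(d_in, d_out, n);
  hipMemcpy(&h_out, d_out, sizeof(float), hipMemcpyDeviceToHost);
  printf("sum = %f (expected %d)\n", h_out, n);
  hipFree(d_in);
  hipFree(d_out);
  return 0;
}

With every input set to 1.0f the program should print sum = 256, regardless of whether the wavefront width is 32 or 64.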