diff --git a/csrc/ops/embedding_segment_reduce_kernel.cu b/csrc/ops/embedding_segment_reduce_kernel.cu
index 2c38a57..060914c 100644
--- a/csrc/ops/embedding_segment_reduce_kernel.cu
+++ b/csrc/ops/embedding_segment_reduce_kernel.cu
@@ -196,6 +196,120 @@ __global__ void segment_reduce_backward_kernel(
   }
 }
 
+#ifdef __HIP_PLATFORM_AMD__
+template <typename offset_t>
+__global__ void compute_segment_ids_kernel(
+    const offset_t* __restrict__ offsets, int32_t* __restrict__ segment_ids,
+    int64_t B, int64_t S) {
+  // Each block handles one segment and fills segment IDs
+  int64_t seg = blockIdx.x;
+  if (seg >= S - 1) return;
+
+  offset_t start = offsets[seg];
+  offset_t end = offsets[seg + 1];
+
+  for (offset_t i = start + threadIdx.x; i < end; i += blockDim.x) {
+    segment_ids[i] = seg;
+  }
+}
+
+template <typename scalar_t, typename offset_t, int BLOCK_SIZE>
+__global__ void compute_weight_sums_kernel(
+    const scalar_t* __restrict__ weight, const offset_t* __restrict__ offsets,
+    scalar_t* __restrict__ weight_sums, int64_t S) {
+  using BlockReduce = cub::BlockReduce<scalar_t, BLOCK_SIZE>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  // Each block handles one segment
+  int64_t seg = blockIdx.x;
+  if (seg >= S - 1) return;
+
+  offset_t start = offsets[seg];
+  offset_t end = offsets[seg + 1];
+  int64_t length = end - start;
+
+  scalar_t weight_sum = 0;
+  // The loop ending condition ensures all threads participate in BlockReduce
+  for (int64_t i = threadIdx.x; i / blockDim.x * blockDim.x < length;
+       i += blockDim.x) {
+    scalar_t w = 0;
+    if (i < length) {
+      w = weight[start + i];
+    }
+    scalar_t res = BlockReduce(temp_storage).Sum(w);
+    if (threadIdx.x == 0) {
+      weight_sum += res;
+    }
+    __syncthreads();
+  }
+
+  if (threadIdx.x == 0) {
+    weight_sums[seg] = weight_sum;
+  }
+}
+
+template <typename scalar_t, typename offset_t, ReduceMode mode,
+          bool USE_WEIGHT>
+__global__ void segment_reduce_backward_kernel_rocm(
+    const scalar_t* __restrict__ grad_output,
+    const scalar_t* __restrict__ weight,
+    const int64_t* __restrict__ reverse_indices,
+    const int32_t* __restrict__ segment_ids,
+    const scalar_t* __restrict__ weight_sums,
+    const offset_t* __restrict__ offsets, scalar_t* grad_unique_emb, int64_t B,
+    int64_t S, int64_t D) {
+  // Total work items: B * D / 4 (with PACK_SIZE=4)
+  const int64_t total_work = (B * D) >> 2;  // Divide by 4
+  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t stride = blockDim.x * gridDim.x;
+
+  // Process multiple elements per thread for better ILP
+  #pragma unroll 1  // Don't unroll outer loop to save registers
+  for (int64_t work_idx = tid; work_idx < total_work; work_idx += stride) {
+    // Decode flat index to (input_idx, dp)
+    const int64_t flat_idx = work_idx << 2;  // Multiply by 4
+    const int64_t input_idx = flat_idx / D;
+    const int64_t dp = flat_idx - input_idx * D;  // Faster than modulo
+
+    // Load segment and unique index (these are per-row, not per-element)
+    const int32_t seg = segment_ids[input_idx];
+    const int64_t raw_idx = reverse_indices[input_idx];
+
+    // Load gradient (4 floats vectorized)
+    float4 g_vec;
+    if constexpr (mode == ReduceMode::TILE) {
+      g_vec =
+          *reinterpret_cast<const float4*>(grad_output + input_idx * D + dp);
+    } else {
+      g_vec = *reinterpret_cast<const float4*>(grad_output + seg * D + dp);
+    }
+
+    // Compute weight factor
+    scalar_t w_base;
+    if constexpr (USE_WEIGHT) {
+      w_base = weight[input_idx];
+      if constexpr (mode == ReduceMode::MEAN) {
+        w_base /= weight_sums[seg];
+      }
+    } else {
+      if constexpr (mode == ReduceMode::MEAN) {
+        w_base = static_cast<scalar_t>(1) /
+                 static_cast<scalar_t>(offsets[seg + 1] - offsets[seg]);
+      } else {
+        w_base = static_cast<scalar_t>(1);
+      }
+    }
+
+    // Write results with atomics (4 consecutive floats)
+    scalar_t* out_ptr = grad_unique_emb + raw_idx * D + dp;
+    atomicAdd(out_ptr + 0, g_vec.x * w_base);
+    atomicAdd(out_ptr + 1, g_vec.y * w_base);
+    atomicAdd(out_ptr + 2, g_vec.z * w_base);
+    atomicAdd(out_ptr + 3, g_vec.w * w_base);
+  }
+}
+#endif
+
 #define FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, use_weight, vec_size) \
   segment_reduce_forward_kernel<scalar_t, offset_t, mode, use_weight, vec_size> \
@@ -243,11 +357,26 @@ void segment_reduce_forward_kernel_launcher(
       N, S, D); \
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 
+#ifdef __HIP_PLATFORM_AMD__
+#define LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, \
+                                    use_weight) \
+  segment_reduce_backward_kernel_rocm<scalar_t, offset_t, mode, use_weight> \
+      <<<grid_size, block_size, 0, stream>>>( \
+          grad_output, weight, reverse_indices, segment_ids, weight_sums, \
+          offsets, grad_unique_emb, B, S, D); \
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+#endif
+
 template <typename scalar_t, typename offset_t, ReduceMode mode>
 void segment_reduce_backward_kernel_launcher(
     const scalar_t* grad_output, const scalar_t* weight, bool use_weight,
     const int64_t* reverse_indices, const offset_t* offsets,
     scalar_t* grad_unique_emb, int64_t B, int64_t N, int64_t S, int64_t D) {
+  auto stream = at::cuda::getCurrentCUDAStream();
+  int sm_count = get_sm_count();
+
+
+#ifndef __HIP_PLATFORM_AMD__
   int64_t block_size = 256;
   int64_t block_num = get_sm_count() * 8;
   block_num = std::min(block_num, S);
@@ -271,6 +400,82 @@ void segment_reduce_backward_kernel_launcher(
       LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 1)
     }
   }
+#else
+  // Early return if there's no work to do
+  // B == 0: no input elements
+  // S <= 1: no segments (S - 1 segments, need at least 1)
+  // D == 0: no embedding dimension
+  if (B == 0 || S <= 1 || D == 0) {
+    return;
+  }
+
+  // Use high-occupancy kernel for D % 4 == 0 (most common case)
+  // Fall back to original segment-parallel kernel for other cases
+  if (D % 4 == 0) {
+    int64_t block_size = 256;
+    int64_t total_work = (B * D) / 4;
+    // Use larger grid size for better latency hiding
+    int64_t grid_size =
+        std::min((total_work + block_size - 1) / block_size,
+                 static_cast<int64_t>(sm_count * 32));
+
+    // Allocate temporary buffers for segment_ids and weight_sums
+    int32_t* segment_ids = cuda_malloc(B * sizeof(int32_t), stream);
+    scalar_t* weight_sums = nullptr;
+    if constexpr (mode == ReduceMode::MEAN) {
+      if (use_weight) {
+        weight_sums = cuda_malloc((S - 1) * sizeof(scalar_t), stream);
+      }
+    }
+
+    // Precompute segment IDs
+    compute_segment_ids_kernel<offset_t>
+        <<<S - 1, block_size, 0, stream>>>(offsets, segment_ids, B, S);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+    // Precompute weight sums for MEAN mode with weights
+    if constexpr (mode == ReduceMode::MEAN) {
+      if (use_weight) {
+        // One block per segment for BlockReduce
+        compute_weight_sums_kernel<scalar_t, offset_t, 256>
+            <<<S - 1, 256, 0, stream>>>(weight, offsets, weight_sums, S);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      }
+    }
+
+    // Launch high-occupancy kernel
+    if (use_weight) {
+      LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, true)
+    } else {
+      LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, false)
+    }
+
+    // Free temporary buffers
+    delete_cuda_ptr(segment_ids);
+    if (weight_sums != nullptr) {
+      delete_cuda_ptr(weight_sums);
+    }
+  } else {
+    // Fall back to original segment-parallel kernel for D % 4 != 0
+    int64_t block_size = 256;
+    int64_t block_num = sm_count * 8;
+    block_num = std::min(block_num, S);
+
+    if (D % 2 == 0) {
+      if (use_weight) {
+        LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 2)
+      } else {
+        LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 2)
+      }
+    } else {
+      if (use_weight) {
+        LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 1)
+      } else {
+        LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 1)
+      }
+    }
+  }
+#endif
 }
 
 at::Tensor segment_reduce_forward(at::Tensor unique_emb, c10::optional<at::Tensor> weight,