205 changes: 205 additions & 0 deletions csrc/ops/embedding_segment_reduce_kernel.cu
@@ -196,6 +196,120 @@ __global__ void segment_reduce_backward_kernel(
}
}

#ifdef __HIP_PLATFORM_AMD__
template <typename offset_t>
__global__ void compute_segment_ids_kernel(
const offset_t* __restrict__ offsets, int32_t* __restrict__ segment_ids,
int64_t B, int64_t S) {
// Each block handles one segment and fills segment IDs
int64_t seg = blockIdx.x;
if (seg >= S - 1) return;

offset_t start = offsets[seg];
offset_t end = offsets[seg + 1];

for (offset_t i = start + threadIdx.x; i < end; i += blockDim.x) {
segment_ids[i] = seg;
}
}
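
For reference, a minimal host-side sketch of the mapping this kernel materializes (illustrative only, not part of this change; the offsets are made up): offsets of length S describe S - 1 contiguous segments, and every element index in [offsets[seg], offsets[seg + 1]) receives segment id seg.

// Illustrative only: serial equivalent of compute_segment_ids_kernel.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t offsets[] = {0, 3, 5, 9};  // S = 4 -> 3 segments, B = 9 elements
  const int64_t S = 4;
  int32_t segment_ids[9];
  for (int64_t seg = 0; seg < S - 1; ++seg) {
    for (int64_t i = offsets[seg]; i < offsets[seg + 1]; ++i) {
      segment_ids[i] = static_cast<int32_t>(seg);
    }
  }
  for (int32_t id : segment_ids) printf("%d ", id);  // prints: 0 0 0 1 1 2 2 2 2
  printf("\n");
  return 0;
}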

template <typename scalar_t, typename offset_t>
__global__ void compute_weight_sums_kernel(
const scalar_t* __restrict__ weight, const offset_t* __restrict__ offsets,
scalar_t* __restrict__ weight_sums, int64_t S) {
using BlockReduce = cub::BlockReduce<scalar_t, 256>;
__shared__ typename BlockReduce::TempStorage temp_storage;

// Each block handles one segment
int64_t seg = blockIdx.x;
if (seg >= S - 1) return;

offset_t start = offsets[seg];
offset_t end = offsets[seg + 1];
int64_t length = end - start;

scalar_t weight_sum = 0;
// The loop bound rounds i down to the nearest multiple of blockDim.x, so every
// thread in the block runs the same number of iterations and all of them reach
// the collective BlockReduce call, even when length is not a multiple of
// blockDim.x.
for (int64_t i = threadIdx.x; i / blockDim.x * blockDim.x < length;
i += blockDim.x) {
scalar_t w = 0;
if (i < length) {
w = weight[start + i];
}
scalar_t res = BlockReduce(temp_storage).Sum(w);
if (threadIdx.x == 0) {
weight_sum += res;
}
__syncthreads();
}

if (threadIdx.x == 0) {
weight_sums[seg] = weight_sum;
}
}
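
For a concrete feel of the loop bound above: with blockDim.x = 256 and length = 300, i / 256 * 256 is 0 on the first pass and 256 on the second for every thread, so all 256 threads run exactly two iterations and only threads 0-43 read a real weight on the second one. A serial sketch of the result (illustrative only, float specialization, not part of this change):

// Illustrative only: serial equivalent of compute_weight_sums_kernel.
// weight_sums[seg] is the sum of the weights of all elements in segment seg.
#include <cstdint>

void weight_sums_reference(const float* weight, const int64_t* offsets,
                           float* weight_sums, int64_t S) {
  for (int64_t seg = 0; seg < S - 1; ++seg) {
    float sum = 0.0f;
    for (int64_t i = offsets[seg]; i < offsets[seg + 1]; ++i) {
      sum += weight[i];
    }
    weight_sums[seg] = sum;
  }
}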

template <typename scalar_t, typename offset_t, ReduceMode mode,
bool USE_WEIGHT>
__global__ void segment_reduce_backward_kernel_rocm(
const scalar_t* __restrict__ grad_output,
const scalar_t* __restrict__ weight,
const int64_t* __restrict__ reverse_indices,
const int32_t* __restrict__ segment_ids,
const scalar_t* __restrict__ weight_sums,
const offset_t* __restrict__ offsets, scalar_t* grad_unique_emb, int64_t B,
int64_t S, int64_t D) {
// Total work items: (B * D) / 4 -- each work item covers 4 consecutive floats
const int64_t total_work = (B * D) >> 2; // Divide by 4
const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride = blockDim.x * gridDim.x;

// Process multiple elements per thread for better ILP
#pragma unroll 1 // Don't unroll outer loop to save registers
for (int64_t work_idx = tid; work_idx < total_work; work_idx += stride) {
// Decode flat index to (input_idx, dp)
const int64_t flat_idx = work_idx << 2; // Multiply by 4
const int64_t input_idx = flat_idx / D;
const int64_t dp = flat_idx - input_idx * D; // Faster than modulo

// Load segment and unique index (these are per-row, not per-element)
const int32_t seg = segment_ids[input_idx];
const int64_t raw_idx = reverse_indices[input_idx];

// Load the gradient as one float4 (16-byte vectorized load); this assumes
// scalar_t is float, and the launcher only takes this path when D % 4 == 0
float4 g_vec;
if constexpr (mode == ReduceMode::TILE) {
g_vec =
*reinterpret_cast<const float4*>(grad_output + input_idx * D + dp);
} else {
g_vec = *reinterpret_cast<const float4*>(grad_output + seg * D + dp);
}

// Compute weight factor
scalar_t w_base;
if constexpr (USE_WEIGHT) {
w_base = weight[input_idx];
if constexpr (mode == ReduceMode::MEAN) {
w_base /= weight_sums[seg];
}
} else {
if constexpr (mode == ReduceMode::MEAN) {
w_base = static_cast<scalar_t>(1) /
static_cast<scalar_t>(offsets[seg + 1] - offsets[seg]);
} else {
w_base = static_cast<scalar_t>(1);
}
}
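
// Summary of the scaling factor applied below (b = input_idx):
//   non-MEAN (e.g. TILE), weighted : w_base = weight[b]
//   non-MEAN, unweighted           : w_base = 1
//   MEAN, weighted                 : w_base = weight[b] / weight_sums[seg]
//   MEAN, unweighted               : w_base = 1 / (segment length)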

// Write results with atomics (4 consecutive floats)
scalar_t* out_ptr = grad_unique_emb + raw_idx * D + dp;
atomicAdd(out_ptr + 0, g_vec.x * w_base);
atomicAdd(out_ptr + 1, g_vec.y * w_base);
atomicAdd(out_ptr + 2, g_vec.z * w_base);
atomicAdd(out_ptr + 3, g_vec.w * w_base);
}
}
#endif
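
For readers of the diff, a serial sketch of the scatter the ROCm kernel performs (illustrative only: float specialization, RefMode is a stand-in for the real ReduceMode, and parameter names simply mirror the kernel's). The kernel above does the same accumulation four floats at a time via atomicAdd.

// Illustrative only: what segment_reduce_backward_kernel_rocm accumulates.
#include <cstdint>

enum class RefMode { SUM, MEAN, TILE };  // stand-in for ReduceMode in this sketch

void backward_reference(const float* grad_output, const float* weight,
                        const int64_t* reverse_indices,
                        const int32_t* segment_ids, const float* weight_sums,
                        const int64_t* offsets, float* grad_unique_emb,
                        int64_t B, int64_t D, RefMode mode, bool use_weight) {
  for (int64_t b = 0; b < B; ++b) {
    const int32_t seg = segment_ids[b];
    const int64_t row = reverse_indices[b];
    // TILE reads a per-element gradient row; other modes read the segment's row.
    const float* g = (mode == RefMode::TILE) ? grad_output + b * D
                                             : grad_output + seg * D;
    float w = 1.0f;
    if (use_weight) {
      w = weight[b];
      if (mode == RefMode::MEAN) w /= weight_sums[seg];
    } else if (mode == RefMode::MEAN) {
      w = 1.0f / static_cast<float>(offsets[seg + 1] - offsets[seg]);
    }
    for (int64_t d = 0; d < D; ++d) {
      grad_unique_emb[row * D + d] += w * g[d];
    }
  }
}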

#define FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, use_weight, vec_size) \
segment_reduce_forward_kernel<scalar_t, offset_t, mode, use_weight, \
vec_size> \
@@ -243,11 +357,26 @@ void segment_reduce_forward_kernel_launcher(
N, S, D); \
C10_CUDA_KERNEL_LAUNCH_CHECK();

#ifdef __HIP_PLATFORM_AMD__
#define LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, \
use_weight) \
segment_reduce_backward_kernel_rocm<scalar_t, offset_t, mode, use_weight> \
<<<grid_size, block_size, 0, stream>>>( \
grad_output, weight, reverse_indices, segment_ids, weight_sums, \
offsets, grad_unique_emb, B, S, D); \
C10_CUDA_KERNEL_LAUNCH_CHECK();
#endif

template <typename scalar_t, typename offset_t, ReduceMode mode>
void segment_reduce_backward_kernel_launcher(
const scalar_t* grad_output, const scalar_t* weight, bool use_weight,
const int64_t* reverse_indices, const offset_t* offsets,
scalar_t* grad_unique_emb, int64_t B, int64_t N, int64_t S, int64_t D) {
auto stream = at::cuda::getCurrentCUDAStream();
int sm_count = get_sm_count();


#ifndef __HIP_PLATFORM_AMD__
int64_t block_size = 256;
int64_t block_num = get_sm_count() * 8;
block_num = std::min(block_num, S);
@@ -271,6 +400,82 @@ void segment_reduce_backward_kernel_launcher(
LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 1)
}
}
#else
// Early return if there's no work to do
// B == 0: no input elements
// S <= 1: no segments (offsets of length S describe S - 1 segments)
// D == 0: no embedding dimension
if (B == 0 || S <= 1 || D == 0) {
return;
}

// Use high-occupancy kernel for D % 4 == 0 (most common case)
// Fall back to original segment-parallel kernel for other cases
if (D % 4 == 0) {
int64_t block_size = 256;
int64_t total_work = (B * D) / 4;
// Use larger grid size for better latency hiding
int64_t grid_size =
std::min((total_work + block_size - 1) / block_size,
static_cast<int64_t>(sm_count * 32));
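// Worked example (hypothetical sizes): B = 1 << 20, D = 128 gives
// total_work = (B * D) / 4 = 33,554,432 float4 items, i.e. 131,072 blocks of
// 256 threads, which is then capped at sm_count * 32 so each CU keeps many
// blocks in flight and the grid-stride loop in the kernel covers the rest.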

// Allocate temporary buffers for segment_ids and weight_sums
int32_t* segment_ids = cuda_malloc<int32_t>(B * sizeof(int32_t), stream);
scalar_t* weight_sums = nullptr;
if constexpr (mode == ReduceMode::MEAN) {
if (use_weight) {
weight_sums = cuda_malloc<scalar_t>((S - 1) * sizeof(scalar_t), stream);
}
}

// Precompute segment IDs
compute_segment_ids_kernel<offset_t>
<<<S - 1, block_size, 0, stream>>>(offsets, segment_ids, B, S);
C10_CUDA_KERNEL_LAUNCH_CHECK();

// Precompute weight sums for MEAN mode with weights
if constexpr (mode == ReduceMode::MEAN) {
if (use_weight) {
// One block per segment for BlockReduce
compute_weight_sums_kernel<scalar_t, offset_t>
<<<S - 1, block_size, 0, stream>>>(weight, offsets, weight_sums, S);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}

// Launch high-occupancy kernel
if (use_weight) {
LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, true)
} else {
LAUNCH_BACKWARD_KERNEL_ROCM(scalar_t, offset_t, mode, false)
}

// Free temporary buffers
delete_cuda_ptr(segment_ids);
if (weight_sums != nullptr) {
delete_cuda_ptr(weight_sums);
}
} else {
// Fall back to original segment-parallel kernel for D % 4 != 0
int64_t block_size = 256;
int64_t block_num = sm_count * 8;
block_num = std::min(block_num, S);

if (D % 2 == 0) {
if (use_weight) {
LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 2)
} else {
LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 2)
}
} else {
if (use_weight) {
LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, true, 1)
} else {
LAUNCH_BACKWARD_KERNEL(scalar_t, offset_t, mode, false, 1)
}
}
}
#endif
}
at::Tensor segment_reduce_forward(at::Tensor unique_emb,
c10::optional<at::Tensor> weight,