216 changes: 216 additions & 0 deletions csrc/ops/embedding_segment_reduce_kernel.cu
@@ -113,6 +113,200 @@ __global__ void segment_reduce_forward_kernel(
}
}

#ifdef __HIP_PLATFORM_AMD__
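// Per-workgroup LDS (shared memory) budget used to size the accumulation
// tile below; 64 KiB is the usual per-workgroup limit on AMD GPUs.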
#define LDS_LIMIT_BYTES (64u * 1024u)

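// Wavefront-wide sum via register shuffles; `width` is the wavefront size
// (typically 64 on AMD GPUs), passed in as warpSize by the caller.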
__device__ __forceinline__ float warp_reduce_sum(float val, int width) {
#pragma unroll
for (int offset = width >> 1; offset > 0; offset >>= 1) {
val += __shfl_down(val, offset, width);
}
return val;
}

template <typename scalar_t, typename offset_t, ReduceMode mode,
bool USE_WEIGHT, int PACK_SIZE>
__global__ void segment_reduce_forward_kernel_rocm(
const scalar_t* __restrict__ unique_emb,
const scalar_t* __restrict__ weight,
const int64_t* __restrict__ reverse_indices,
const offset_t* __restrict__ offsets, scalar_t* output, int64_t B,
int64_t N, int64_t S, int64_t D) {
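  // ROCm-specific forward pass. Each block strides over segments: TILE mode
  // scales and writes each gathered row directly, while SUM/MEAN accumulate
  // into an LDS tile laid out as [column, thread] and reduce per column.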
using AP = Packer<scalar_t, PACK_SIZE>;
using BlockReduce = hipcub::BlockReduce<scalar_t, 256>;

extern __shared__ __align__(sizeof(scalar_t)) unsigned char shared_raw[];
scalar_t* __restrict__ smem = reinterpret_cast<scalar_t*>(shared_raw);

  // temp_storage and s_weight_sum are only needed for the weighted-mean
  // path; the compiler optimizes them out otherwise.
typename BlockReduce::TempStorage* temp_storage =
reinterpret_cast<typename BlockReduce::TempStorage*>(smem);

scalar_t* s_weight_sum = reinterpret_cast<scalar_t*>(
reinterpret_cast<char*>(temp_storage) + sizeof(*temp_storage));

const int warp = warpSize;
const int waves = blockDim.x / warp;
const int lane = threadIdx.x & (warp - 1);
const int wave_id = threadIdx.x / warp;

// Precompute stride increments to reduce div/mod inside the loop
const int64_t stride_items =
static_cast<int64_t>(blockDim.x) * static_cast<int64_t>(PACK_SIZE);
const int64_t delta_q = (D > 0) ? (stride_items / D) : 0;
const int64_t delta_dp = (D > 0) ? (stride_items % D) : 0;

// Initial dp/q per thread
int64_t q0 = ((int64_t)threadIdx.x * (int64_t)PACK_SIZE) / D;
int64_t dp0 = ((int64_t)threadIdx.x * (int64_t)PACK_SIZE) % D;
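  // Invariant: a thread's flat element index i satisfies
  // i * PACK_SIZE == q * D + dp with 0 <= dp < D. Each stride of
  // stride_items advances (q, dp) by (delta_q, delta_dp) plus at most one
  // carry, so the hot loop needs no division or modulo.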

// Small chunk to improve ILP
constexpr int kChunk = 2;

for (int s = blockIdx.x; s < S - 1; s += gridDim.x) {
offset_t start = offsets[s];
offset_t end = offsets[s + 1];
int64_t length = end - start;
int64_t total_size = length * D;

if constexpr (mode == ReduceMode::TILE) {
// TILE mode: direct scaled write per idx, vectorized
int64_t q = q0;
int64_t dp = dp0;
for (int64_t i_base = threadIdx.x; i_base * PACK_SIZE < total_size;
i_base += blockDim.x * kChunk) {
#pragma unroll
for (int c = 0; c < kChunk; ++c) {
int64_t it = i_base + c * blockDim.x;
if (!(it * PACK_SIZE < total_size)) break;
int64_t idx = start + q;
int64_t raw_idx = reverse_indices[idx];
scalar_t w = scalar_t(1);
if constexpr (USE_WEIGHT) { w = weight[idx]; }

typename AP::type a_vec;
typename AP::type b_vec;
AP::load(unique_emb + raw_idx * D + dp, a_vec);
#pragma unroll
for (int j = 0; j < PACK_SIZE; j++) {
auto a_val = AP::get_element(a_vec, j);
AP::set_element(b_vec, j, a_val * w);
}
AP::store(output + idx * D + dp, b_vec);

dp += delta_dp;
q += delta_q;
if (dp >= D) { dp -= D; ++q; }
}
}
} else {
scalar_t weight_sum = 0;
if constexpr (USE_WEIGHT && mode == ReduceMode::MEAN) {
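        // Block-wide sum of the bag's weights. The loop bound rounds i down
        // to a multiple of blockDim.x so every thread stays in the loop and
        // BlockReduce always runs with a full block per round.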
for (int64_t i = threadIdx.x; i / blockDim.x * blockDim.x < length;
i += blockDim.x) {
scalar_t w = 0;
if (i < length) {
w = weight[start + i];
}
scalar_t res = BlockReduce(*temp_storage).Sum(w);
if (threadIdx.x == 0) {
weight_sum += res;
}
__syncthreads();
}
if (threadIdx.x == 0) {
*s_weight_sum = weight_sum;
}
        __syncthreads();
        weight_sum = *s_weight_sum;
        // s_weight_sum lives at the start of the same LDS buffer that the
        // tile loop below zero-initializes, so every thread must read it
        // before any thread starts overwriting the tile.
        __syncthreads();
      }

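      // Number of embedding columns that fit in one LDS tile when each
      // (column, thread) pair holds a single partial sum.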
const int max_cols =
static_cast<int>(LDS_LIMIT_BYTES / sizeof(scalar_t) / blockDim.x);

for (int d_base = 0; d_base < D; d_base += max_cols) {
int this_cols = max_cols;
if (d_base + this_cols > D) this_cols = D - d_base;
if (this_cols <= 0) break;

// Initialize LDS tile: layout [col, thread]
for (int idx = threadIdx.x; idx < this_cols * blockDim.x;
idx += blockDim.x) {
smem[idx] = scalar_t(0);
}
__syncthreads();

// Reset per-thread q/dp for this tile
int64_t q = q0;
int64_t dp = dp0;

        // Walk the bag in the same packed order as TILE mode, but only
        // accumulate columns that fall inside the current LDS tile.
for (int64_t i_base = threadIdx.x; i_base * PACK_SIZE < total_size;
i_base += blockDim.x * kChunk) {
#pragma unroll
for (int c = 0; c < kChunk; ++c) {
int64_t it = i_base + c * blockDim.x;
if (!(it * PACK_SIZE < total_size)) break;

int64_t idx = start + q;
int64_t raw_idx = reverse_indices[idx];
scalar_t w = scalar_t(1);
if constexpr (USE_WEIGHT) {
w = weight[idx];
} else {
weight_sum = static_cast<scalar_t>(length);
}
if constexpr (mode == ReduceMode::MEAN) {
w = w / weight_sum;
}

typename AP::type a_vec;
AP::load(unique_emb + raw_idx * D + dp, a_vec);
#pragma unroll
for (int j = 0; j < PACK_SIZE; j++) {
int64_t gcol = dp + j;
auto a_val = AP::get_element(a_vec, j);
auto res = a_val * w;
if (gcol >= d_base && gcol < d_base + this_cols) {
int lcol = static_cast<int>(gcol - d_base);
smem[lcol * blockDim.x + threadIdx.x] += res;
}
}

dp += delta_dp;
q += delta_q;
if (dp >= D) { dp -= D; ++q; }
}
}

__syncthreads();
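        // Reduce each tile column: shuffle within each wavefront, park the
        // per-wavefront partials in the first `waves` slots of the row, then
        // thread 0 adds them and writes the segment's output element.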
for (int lc = 0; lc < this_cols; ++lc) {
scalar_t* row = smem + lc * blockDim.x;
scalar_t val = row[threadIdx.x];
// wavefront reduction
val = warp_reduce_sum(val, warp);
          // Every thread must finish reading row[threadIdx.x] above before
          // lane 0 overwrites the first `waves` slots with wavefront partials.
          __syncthreads();
          if (lane == 0) {
            row[wave_id] = val;
          }
__syncthreads();
if (threadIdx.x == 0) {
scalar_t total = scalar_t(0);
#pragma unroll
for (int w = 0; w < waves; ++w) {
total += row[w];
}
output[s * D + (d_base + lc)] = total;
}
__syncthreads();
}
}
}
}
}
#endif

template <typename scalar_t, typename offset_t, ReduceMode mode,
bool USE_WEIGHT, int PACK_SIZE>
__global__ void segment_reduce_backward_kernel(
@@ -196,12 +390,22 @@ __global__ void segment_reduce_backward_kernel(
}
}

#ifndef __HIP_PLATFORM_AMD__
#define FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, use_weight, vec_size) \
segment_reduce_forward_kernel<scalar_t, offset_t, mode, use_weight, \
vec_size> \
<<<block_num, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( \
unique_emb, weight, reverse_indices, offsets, output, B, N, S, D); \
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
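// ROCm path: launch the LDS-tiled kernel and pass the dynamically sized
// shared memory (shared_memory_size) computed in the launcher below.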
#define FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, use_weight, vec_size) \
segment_reduce_forward_kernel_rocm<scalar_t, offset_t, mode, use_weight, \
vec_size> \
<<<block_num, block_size, shared_memory_size, \
at::cuda::getCurrentCUDAStream()>>>( \
unique_emb, weight, reverse_indices, offsets, output, B, N, S, D); \
C10_CUDA_KERNEL_LAUNCH_CHECK();
#endif

template <typename scalar_t, typename offset_t, ReduceMode mode>
void segment_reduce_forward_kernel_launcher(
@@ -212,6 +416,18 @@ void segment_reduce_forward_kernel_launcher(
int64_t block_num = 65536;
block_num = std::min(block_num, S);

#ifdef __HIP_PLATFORM_AMD__
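  // SUM/MEAN accumulate into an LDS tile of [cols x block_size] scalars;
  // TILE mode writes straight to global memory and needs no dynamic LDS.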
size_t shared_memory_size = 0;
if (mode != ReduceMode::TILE) {
size_t max_cols =
LDS_LIMIT_BYTES / sizeof(scalar_t) / static_cast<size_t>(block_size);
if (max_cols < 1) max_cols = 1;
size_t cols = std::min(static_cast<size_t>(D), max_cols);
shared_memory_size =
cols * static_cast<size_t>(block_size) * sizeof(scalar_t);
}
#endif

if (D % 4 == 0) {
if (use_weight) {
FORWARD_LAUNCH_KERNEL(scalar_t, offset_t, mode, true, 4)
2 changes: 1 addition & 1 deletion tests/ops/emb_segment_reduce_test.py
@@ -1,6 +1,6 @@
import unittest

import torch
import recis


def equal(a: torch.Tensor, b: torch.Tensor):