diff --git a/csrc/ops/fused_bucketized_kernel.cu b/csrc/ops/fused_bucketized_kernel.cu
index 99547eb..0a5c7a5 100644
--- a/csrc/ops/fused_bucketized_kernel.cu
+++ b/csrc/ops/fused_bucketized_kernel.cu
@@ -16,6 +16,7 @@ struct BucketizeData {
       : boundaries(boundaries), len(len) {}
 };
 
+#ifndef __HIP_PLATFORM_AMD__
 struct BucketizeFactory {
   __device__ int operator()(const float value, const BucketizeData& data) {
     int bucket = 0;
@@ -35,6 +36,129 @@ struct BucketizeFactory {
     return bucket;
   }
 };
+#else
+#define KBLOCK_SIZE 256
+struct BucketizeFactory {
+  __device__ __forceinline__ int operator()(const float value,
+                                            const float* boundaries,
+                                            int count) {
+    int c = count;
+    if (c <= 0) return 0;
+    int bucket = 0;
+    // fixed iteration count: floor(log2(c)) + 1 steps always suffice
+    int iters = 32 - __builtin_clz(c);
+#pragma unroll
+    for (int i = 0; i < iters; ++i) {
+      int step = c >> 1;
+      int left = bucket + step;
+      // (c > 0) turns leftover iterations into no-ops once the search range
+      // has collapsed, so `left` never reads past the staged boundaries
+      int cond = (c > 0) && (value >= boundaries[left]);
+      bucket = cond ? (left + 1) : bucket;
+      c = cond ? (c - step - 1) : step;
+    }
+    return bucket;
+  }
+};
+
+template <typename A, typename B, typename C>
+__global__ __launch_bounds__(KBLOCK_SIZE, 2)
+void fused_bucketize_kernel(const A** a, const B** b, C** c,
+                            int64_t N, int64_t* sizes, int64_t* b_sizes,
+                            BucketizeFactory factory) {
+  int vec_id = blockIdx.y;
+  if (vec_id >= N) return;
+  // use 32-bit indices locally to reduce register pressure
+  int size_local = static_cast<int>(sizes[vec_id]);
+  int b_size_local = static_cast<int>(b_sizes[vec_id]);
+  int threads_num = blockDim.x * gridDim.x;
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  extern __shared__ unsigned char shmem[];
+  B* sm_boundaries = reinterpret_cast<B*>(shmem);
+
+  // stage boundaries in shared memory (strided, so counts above blockDim.x work)
+  const B* __restrict__ b_vec = b[vec_id];
+  for (int j = threadIdx.x; j < b_size_local; j += blockDim.x) {
+    sm_boundaries[j] = b_vec[j];
+  }
+  __syncthreads();
+
+  const A* __restrict__ a_vec = a[vec_id];
+  C* __restrict__ c_vec = c[vec_id];
+
+  // process several elements per thread to improve ILP and amortize loop overhead
+  const int UNROLL = 8;
+  for (int index = tid; index < size_local; index += threads_num * UNROLL) {
+    int i0 = index;
+    // preload the first four elements of the unroll group
+    A v0 = (i0 < size_local) ? a_vec[i0] : A(0);
+    A v1 = ((i0 + threads_num) < size_local) ? a_vec[i0 + threads_num] : A(0);
+    A v2 = ((i0 + 2 * threads_num) < size_local) ? a_vec[i0 + 2 * threads_num] : A(0);
+    A v3 = ((i0 + 3 * threads_num) < size_local) ? a_vec[i0 + 3 * threads_num] : A(0);
+#pragma unroll
+    for (int u = 0; u < UNROLL; ++u) {
+      int ii = i0 + u * threads_num;
+      if (ii < size_local) {
+        A val = (u == 0) ? v0 : (u == 1) ? v1 : (u == 2) ? v2 : (u == 3) ? v3 : a_vec[ii];
+        c_vec[ii] = factory(val, sm_boundaries, b_size_local);
+      }
+    }
+  }
+}
+
+template <typename A, typename B, typename C>
+void fused_bucketize_launcher(const A** a, const B** b, C** c, int64_t* sizes,
+                              int64_t* b_sizes, int64_t N,
+                              BucketizeFactory factor,
+                              bool with_pack,
+                              hipStream_t stream) {
+  using namespace recis::cuda;
+  int64_t sm_count = get_sm_count();
+  int64_t max_size = 0;
+  for (int64_t i = 0; i < N; ++i) {
+    max_size = std::max(max_size, sizes[i]);
+  }
+  if (max_size == 0) return;
+
+  // copy per-tensor sizes to the device
+  int64_t* d_sizes = cuda_malloc_and_copy(sizes, N, stream);
+  int64_t* d_b_sizes = cuda_malloc_and_copy(b_sizes, N, stream);
+
+  // shared-memory size from the maximum boundary count across tensors, plus
+  // one padding slot so reads at the range boundary stay inside the allocation
+  int64_t max_b_size = 0;
+  for (int64_t i = 0; i < N; ++i) {
+    max_b_size = std::max(max_b_size, b_sizes[i]);
+  }
+  size_t shared_memory_size = static_cast<size_t>(max_b_size + 1) * sizeof(B);
+
+  int maxActiveBlocksPerSM = 0;
+  C10_HIP_CHECK(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+      &maxActiveBlocksPerSM,
+      fused_bucketize_kernel<A, B, C>,
+      KBLOCK_SIZE,
+      shared_memory_size));
+  if (maxActiveBlocksPerSM <= 0) {
+    maxActiveBlocksPerSM = 8;  // fallback consistent with the original heuristic
+  }
+
+  int64_t block_num = std::min(sm_count * maxActiveBlocksPerSM,
+                               (max_size + KBLOCK_SIZE - 1) / KBLOCK_SIZE);
+  if (block_num <= 0) block_num = sm_count * maxActiveBlocksPerSM;
+  dim3 grid(block_num, N);
+  dim3 block(KBLOCK_SIZE);
+
+  fused_bucketize_kernel<A, B, C>
+      <<<grid, block, shared_memory_size, stream>>>(a, b, c, N, d_sizes,
+                                                    d_b_sizes, factor);
+
+  C10_HIP_CHECK(hipGetLastError());
+  C10_HIP_CHECK(hipStreamSynchronize(stream));
+  delete_cuda_ptr(d_sizes);
+  delete_cuda_ptr(d_b_sizes);
+}
+#endif
 
 void fused_bucketized_cuda(std::vector<at::Tensor>& inputs,
                            std::vector<at::Tensor>& outputs,
@@ -45,6 +169,29 @@ void fused_bucketized_cuda(std::vector<at::Tensor>& inputs,
   std::vector<int64_t> sizes(N);
   CudaVecParam<float*> inputs_ptrs(N, stream);
   CudaVecParam<int64_t*> outputs_ptrs(N, stream);
+
+#ifdef __HIP_PLATFORM_AMD__
+  CudaVecParam<float*> boundaries_ptrs(N, stream);
+  std::vector<int64_t> b_sizes(N);
+  for (int64_t i = 0; i < N; ++i) {
+    sizes[i] = inputs[i].numel();
+    inputs_ptrs[i] = inputs[i].data_ptr<float>();
+    outputs_ptrs[i] = outputs[i].data_ptr<int64_t>();
+    boundaries_ptrs[i] = boundaries[i].data_ptr<float>();
+    b_sizes[i] = static_cast<int64_t>(boundaries[i].numel());
+  }
+
+  fused_bucketize_launcher(
+      const_cast<const float**>(inputs_ptrs.data()),
+      const_cast<const float**>(boundaries_ptrs.data()),
+      outputs_ptrs.data(),
+      sizes.data(),
+      b_sizes.data(),
+      N,
+      BucketizeFactory(),
+      false,
+      stream);
+#else
   CudaVecParam<BucketizeData> bucketize_datas(N, stream);
   for (int64_t i = 0; i < N; ++i) {
     sizes[i] = inputs[i].numel();
@@ -57,6 +204,7 @@ void fused_bucketized_cuda(std::vector<at::Tensor>& inputs,
   fused_element_wise_launcher(
       const_cast<const float**>(inputs_ptrs.data()), bucketize_datas.data(),
       outputs_ptrs.data(), sizes.data(), N, BucketizeFactory(), false, stream);
+#endif
 }
 }  // namespace functional
 }  // namespace recis
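
Two notes for reviewers, with sketches. First, the HIP `BucketizeFactory` is a fixed-iteration, branchless `upper_bound`: it returns the number of boundaries that are <= value. A minimal host-side check of that loop against `std::upper_bound`, assuming a GCC/Clang toolchain for `__builtin_clz`; the function name `bucketize_branchless` is illustrative and not part of the patch:

// Host mirror of the device loop (minus the __device__ qualifiers).
#include <algorithm>
#include <cassert>
#include <vector>

int bucketize_branchless(float value, const float* boundaries, int count) {
  int c = count;
  if (c <= 0) return 0;
  int bucket = 0;
  int iters = 32 - __builtin_clz(c);  // floor(log2(c)) + 1 steps always suffice
  for (int i = 0; i < iters; ++i) {
    int step = c >> 1;
    int left = bucket + step;
    // guard makes leftover iterations no-ops once the range has collapsed
    int cond = (c > 0) && (value >= boundaries[left]);
    bucket = cond ? (left + 1) : bucket;
    c = cond ? (c - step - 1) : step;
  }
  return bucket;
}

int main() {
  std::vector<float> bd = {1.f, 2.f, 3.f, 5.f, 8.f};
  for (float v : {-1.f, 0.f, 1.f, 1.5f, 3.f, 7.9f, 8.f, 100.f}) {
    int expected = (int)(std::upper_bound(bd.begin(), bd.end(), v) - bd.begin());
    assert(bucketize_branchless(v, bd.data(), (int)bd.size()) == expected);
  }
  return 0;
}

The `(c > 0)` guard matters for values >= the last boundary (e.g. v = 8.f above): without it, the converged search would compare against the uninitialized slot one past the staged boundaries and could over-count by one.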
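
Second, the kernel's `a`/`b`/`c` parameters are device-resident arrays of per-tensor device pointers (the layout `CudaVecParam` maintains in `fused_bucketized_cuda`), which is easy to get wrong from a call site. A hypothetical standalone HIP driver for a single tensor, assuming `fused_bucketize_kernel`, `BucketizeFactory`, and `KBLOCK_SIZE` from the patch are visible in the translation unit; all buffer names and sample values are illustrative:

// build: hipcc -x hip bucketize_demo.cpp  (kernel/factory above assumed in scope)
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

#define HIP_CHECK(x) do { hipError_t e = (x); if (e != hipSuccess) { \
  printf("HIP error: %s\n", hipGetErrorString(e)); return 1; } } while (0)

int main() {
  std::vector<float> h_in = {-1.f, 0.5f, 1.f, 2.5f, 7.f, 100.f};
  std::vector<float> h_bd = {1.f, 2.f, 3.f, 5.f, 8.f};  // 5 boundaries -> 6 buckets
  int64_t n = (int64_t)h_in.size(), nb = (int64_t)h_bd.size();

  // per-tensor device buffers
  float *d_in, *d_bd;
  int64_t* d_out;
  HIP_CHECK(hipMalloc((void**)&d_in, n * sizeof(float)));
  HIP_CHECK(hipMalloc((void**)&d_bd, nb * sizeof(float)));
  HIP_CHECK(hipMalloc((void**)&d_out, n * sizeof(int64_t)));
  HIP_CHECK(hipMemcpy(d_in, h_in.data(), n * sizeof(float), hipMemcpyHostToDevice));
  HIP_CHECK(hipMemcpy(d_bd, h_bd.data(), nb * sizeof(float), hipMemcpyHostToDevice));

  // device-side arrays of per-tensor pointers and sizes (N = 1 tensor here)
  const float* h_a[1] = {d_in};
  const float* h_b[1] = {d_bd};
  int64_t* h_c[1] = {d_out};
  const float **d_a, **d_b;
  int64_t** d_c;
  int64_t *d_sizes, *d_b_sizes;
  HIP_CHECK(hipMalloc((void**)&d_a, sizeof(h_a)));
  HIP_CHECK(hipMalloc((void**)&d_b, sizeof(h_b)));
  HIP_CHECK(hipMalloc((void**)&d_c, sizeof(h_c)));
  HIP_CHECK(hipMalloc((void**)&d_sizes, sizeof(int64_t)));
  HIP_CHECK(hipMalloc((void**)&d_b_sizes, sizeof(int64_t)));
  HIP_CHECK(hipMemcpy(d_a, h_a, sizeof(h_a), hipMemcpyHostToDevice));
  HIP_CHECK(hipMemcpy(d_b, h_b, sizeof(h_b), hipMemcpyHostToDevice));
  HIP_CHECK(hipMemcpy(d_c, h_c, sizeof(h_c), hipMemcpyHostToDevice));
  HIP_CHECK(hipMemcpy(d_sizes, &n, sizeof(int64_t), hipMemcpyHostToDevice));
  HIP_CHECK(hipMemcpy(d_b_sizes, &nb, sizeof(int64_t), hipMemcpyHostToDevice));

  // one block in x, one tensor in y; +1 padding slot as in the launcher
  dim3 grid(1, 1), block(KBLOCK_SIZE);
  size_t shmem = (size_t)(nb + 1) * sizeof(float);
  fused_bucketize_kernel<float, float, int64_t>
      <<<grid, block, shmem, 0>>>(d_a, d_b, d_c, 1, d_sizes, d_b_sizes,
                                  BucketizeFactory());
  HIP_CHECK(hipGetLastError());
  HIP_CHECK(hipDeviceSynchronize());

  std::vector<int64_t> h_out(n);
  HIP_CHECK(hipMemcpy(h_out.data(), d_out, n * sizeof(int64_t), hipMemcpyDeviceToHost));
  for (int64_t i = 0; i < n; ++i)
    printf("%6.2f -> bucket %lld\n", h_in[i], (long long)h_out[i]);
  return 0;
}

With the sample data this should print bucket indices 0, 0, 1, 2, 4, 5, matching `std::upper_bound` on the boundary list.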