Reduce sum implementation #25

Draft
vvvdwbvvv wants to merge 3 commits into Kernel-Heim:main from vvvdwbvvv:reduce_sum
Conversation


@vvvdwbvvv vvvdwbvvv commented Feb 1, 2026

Implement CuTe Reduction Kernel with Multi-Variant Support

Related to #20


What Changed

This PR implements a high-performance row/column sum reduction kernel using the NVIDIA CuTe DSL, with multiple optimization variants based on Mark Harris's parallel reduction techniques:

  1. Algorithm cascading: each thread sums multiple elements sequentially before the parallel reduction, reducing instruction overhead and enabling better latency hiding.
  2. Warp shuffle: `warp_reduction_sum` avoids shared memory for intra-warp reduction.
  3. Two-layer reduction: thread-local → warp shuffle → inter-warp via shared memory.
  4. Vectorized loads: 128-bit loads for memory bandwidth efficiency (row reduction).
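The cascading and two-layer scheme can be sketched on the CPU. This is a simulation of the data flow only, not the PR's actual kernel code; all names here are illustrative:

```python
import numpy as np

WARP_SIZE = 32

def simulated_row_sum(row, num_threads=128):
    """CPU simulation of the two-layer reduction:
    thread-local accumulation -> intra-warp tree reduction -> inter-warp sum."""
    # 1. Algorithm cascading: each "thread" strides through the row,
    #    accumulating several elements into a private partial sum.
    partials = np.zeros(num_threads, dtype=row.dtype)
    for t in range(num_threads):
        partials[t] = row[t::num_threads].sum()

    # 2. Intra-warp reduction: mimic a shfl_down-style tree,
    #    halving the number of active lanes each step.
    warps = partials.reshape(-1, WARP_SIZE)
    offset = WARP_SIZE // 2
    while offset > 0:
        warps[:, :offset] += warps[:, offset:2 * offset]
        offset //= 2
    warp_sums = warps[:, 0]

    # 3. Inter-warp reduction: on the GPU this goes through shared
    #    memory; here a plain sum stands in for it.
    return warp_sums.sum()

row = np.arange(4096, dtype=np.float32)
assert np.isclose(simulated_row_sum(row), row.sum())
```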

Features Added

| Feature | Description |
|---|---|
| Bidirectional reduction | Supports both dim=0 (column) and dim=1/-1 (row) |
| Multi-variant | naive, improved, shfl optimization levels |
| Dtype support | float16, bfloat16, float32 |
| Kernel caching | Compiled kernels cached by (dtype, dim, variant, M, N) |

Optimization Variants

| Variant | Vecsize | Threads/Row | Description |
|---|---|---|---|
| naive | 1 | 32 (1 warp) | Baseline, no vectorization |
| improved | 128-bit | 32 (1 warp) | Vectorized loads |
| shfl | 128-bit | up to 128 (4 warps) | Multi-warp + warp shuffle |
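For reference, the 128-bit vecsize in the table maps to a dtype-dependent element count per load. A quick illustrative check (not part of the PR's code; numpy has no bfloat16, but it has the same 2-byte width as float16):

```python
import numpy as np

VEC_BYTES = 16  # one 128-bit vectorized load

def elements_per_load(dtype):
    """How many elements a single 128-bit load moves for a given dtype."""
    return VEC_BYTES // np.dtype(dtype).itemsize

assert elements_per_load(np.float16) == 8  # also bfloat16 (2 bytes)
assert elements_per_load(np.float32) == 4
```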

Limitations

  • Only 2D inputs are supported
  • Only sum is currently supported
  • Input must be contiguous
  • dim=0 is implemented via transpose
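The dim=0-via-transpose strategy means only one row-reduction kernel is needed; column sums reuse it on the transposed (re-contiguized) input. A CPU sketch of the dispatch, with `row_kernel` standing in for the actual CuTe kernel:

```python
import numpy as np

def row_kernel(x):
    # Stand-in for the CuTe row-reduction kernel (input must be contiguous).
    assert x.flags["C_CONTIGUOUS"]
    return x.sum(axis=1)

def reduce_sum(x, dim):
    """dim=1/-1: reduce rows directly.
    dim=0: transpose (making the result contiguous) and reduce rows."""
    if dim in (1, -1):
        return row_kernel(x)
    return row_kernel(np.ascontiguousarray(x.T))
```

The transpose copy is why dim=0 costs extra bandwidth relative to dim=1 (visible as the `copy_transpose` line in the benchmarks below).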

Tests

Environment

OS:  Ubuntu 24.04 LTS
GPU: NVIDIA RTX 3090
CUDA: 13.0
Python: 3.13.11

Test Commands

# Run all reduce_sum tests (91 tests)
uv run pytest tests/test_reduce_sum.py -v

# Run pre-commit checks
uv run pre-commit run -a

Test Coverage

Test output
=================== 91 passed, 14 warnings in 17.26s ===================

Bench Results

Benchmark commands
# Run benchmark
uv run bench/benchmark_reduce_sum.py --m 4096 --n 4096 --dtype float16

# All dtypes
for dt in float16 bfloat16 float32; do
  echo "=== dtype=$dt ==="
  uv run bench/benchmark_reduce_sum.py --m 4096 --n 4096 --dtype $dt
done

Benchmark results
=== dtype=float16 ===
copy_transpose p50: 0.0501 ms, BW: 670.39 GB/s
reference p50: 0.2222 ms, BW: 151.04 GB/s
=== dtype=bfloat16 ===
copy_transpose p50: 0.0486 ms, BW: 690.02 GB/s
reference p50: 0.2222 ms, BW: 151.04 GB/s
=== dtype=float32 ===
copy_transpose p50: 0.0881 ms, BW: 762.23 GB/s
reference p50: 0.0932 ms, BW: 720.35 GB/s

