Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions forge_cute_py/kernels/reduce_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from concurrent.futures import thread
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['CUTLASS_CUDA_ARCH'] = '86'
os.environ['CUTE_DSL_ENABLE_TVM_FFI'] = '1'

import math
import torch
import time
from cutlass.cute.runtime import from_dlpack
import cuda.bindings.driver as cuda
import cutlass.cute as cute
from cutlass import const_expr
import cutlass

from cutlass import dsl_user_op
from cutlass.cute.arch import nvvm
from cutlass._mlir.dialects.nvvm import AtomicOpKind, MemOrderKind, MemScopeKind
from cutlass.base_dsl.typing import T

# Per-kernel compile caches.
# NOTE(review): neither cache is referenced elsewhere in this file, and the
# ops layer keeps its own _compile_cache — confirm these are still needed.
_reduce_sum_last_cache = {}
_reduce_sum_first_cache = {}


# old just for future reference
@dsl_user_op
def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> cute.Float32:
    """Atomic f32 add of `val` into the value at `dst_ptr`.

    Thin wrapper over NVVM's atomicrmw FADD with relaxed ordering and
    system scope (presumably returns the prior value per LLVM atomicrmw
    semantics — confirm). `loc`/`ip` are the usual DSL source-location /
    insertion-point plumbing injected by `dsl_user_op`.
    """
    return nvvm.atomicrmw(
        T.f32(),
        AtomicOpKind.FADD,
        dst_ptr.llvm_ptr,
        val.ir_value(loc=loc, ip=ip),
        mem_order=MemOrderKind.RELAXED,
        syncscope=MemScopeKind.SYS,
        loc=loc,
        ip=ip,
    )


class ReduceSumLast:
    """Sum reduction along last dimension using CuTe DSL.

    Launches one 256-thread block (8 warps) per row: grid = (M, 1, 1).
    Each block reduces its length-N row to a single scalar in `output`.
    """

    def __init__(self, dtype: type):
        # Element dtype of the input/output tensors (a cutlass numeric type).
        self.dtype = dtype
        self.num_warps = 8
        # 8 warps * 32 lanes = 256 threads per block.
        self.threads_per_block = self.num_warps * 32

    @cute.jit
    def __call__(
        self,
        input: cute.Tensor,
        output: cute.Tensor,
        # stream: cuda.CUstream = None,
    ):
        """Host-side launcher: tile `input` and launch the kernel.

        Args:
            input: 2D (M, N) tensor to reduce.
            output: 1D length-M tensor receiving the per-row sums.
        """
        M, N = input.shape
        # Each CTA tile covers 1 row x (256 threads * 4 values) elements.
        # NOTE(review): no tail handling below — this appears to assume N is
        # a multiple of threads_per_block * 4; confirm with callers.
        tiler_mn = (1, self.threads_per_block * 4)
        gX = cute.zipped_divide(input, tiler_mn)

        # Thread/value layout: 256 contiguous threads, 4 contiguous values each.
        thr_layout = cute.make_layout((self.threads_per_block,), stride=(1,))
        val_layout = cute.make_layout((4,), stride=(1,))
        _, tv_layout = cute.make_layout_tv(thr_layout, val_layout)

        self.kernel(gX, output, tv_layout).launch(
            grid=(M, 1, 1),  # one block per row
            block=(self.threads_per_block, 1, 1),
            # stream=stream,
        )

    @cute.kernel
    def kernel(self, input: cute.Tensor, output: cute.Tensor, tv_layout):
        """Device kernel: block-wide sum of one row.

        Each thread accumulates 4 elements per tile, warps reduce via
        shuffle, per-warp partials are staged in shared memory, and warp 0
        produces the final row sum.
        """
        num_warps = const_expr(self.num_warps)

        # 32 floats of shared memory: one partial-sum slot per possible warp.
        smem_alloc = cutlass.utils.SmemAllocator()
        shmem = smem_alloc.allocate_tensor(cute.Float32, cute.make_layout((32,)))

        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        lane = cute.arch.lane_idx()
        warp = cute.arch.warp_idx()

        # op = cute.nvgpu.CopyUniversalOp()
        # atom = cute.make_copy_atom(op, cute.Float32, num_bits_per_copy=128)
        acc = cute.Float32(0.0)

        # Number of tiles the row was split into by zipped_divide.
        _, mode1 = input.shape
        ntiles = mode1[1]

        for tile_idx in range(ntiles):
            # Select this block's (row = bidx, tile = tile_idx) CTA tile.
            blk_coord = ((0, None), (bidx, tile_idx))
            cta_tile = input[blk_coord]
            thr_frag = cute.composition(cta_tile, tv_layout)
            thr_src = thr_frag[(tidx, None)]
            # r = cute.make_rmem_tensor(cute.make_layout((4,), stride=(1,)), cute.Float32)
            # cute.copy(atom, thr_src, r)
            # acc += r[0] + r[1] + r[2] + r[3]
            acc += thr_src[0] + thr_src[1] + thr_src[2] + thr_src[3]

        # Reduce within each warp, then stage one partial per warp.
        acc = cute.arch.warp_reduction_sum(acc)
        if lane == 0:
            shmem[warp] = acc
        cute.arch.sync_threads()

        # Warp 0 reduces the per-warp partials; lanes >= num_warps contribute 0
        # (their shmem slots were never written).
        if warp == 0:
            acc2 = shmem[lane] if lane < num_warps else 0.0
            acc2 = cute.arch.warp_reduction_sum(acc2)
            if lane == 0:
                output[bidx] = acc2


class ReduceSumFirst:
    """Sum reduction along first dimension (column sums) using CuTe DSL.

    Each 128-thread block handles 4 adjacent columns; threads are arranged
    as 4 columns x 32 rows and stride down the rows.
    """

    def __init__(self, dtype: type):
        # Element dtype of the input/output tensors (a cutlass numeric type).
        self.dtype = dtype
        self.num_warps = 4
        # 4 warps * 32 lanes = 128 threads per block.
        self.threads_per_block = self.num_warps * 32

    @cute.jit
    def __call__(
        self,
        input: cute.Tensor,
        output: cute.Tensor,
        # stream: cuda.CUstream = None,
    ):
        """Host-side launcher.

        Args:
            input: 2D (M, N) tensor to reduce.
            output: 1D length-N tensor receiving the per-column sums.
        """
        M, N = input.shape
        # One block per group of 4 columns.
        yblocks = cute.ceil_div(N, 4)
        # stride = rows advanced per loop step = 128 // 4 = 32.
        self.kernel(input, output, self.threads_per_block // 4).launch(
            grid=(yblocks, 1, 1),
            block=(self.threads_per_block, 1, 1),
            # stream=stream,
        )

    @cute.kernel
    def kernel(self, input: cute.Tensor, output: cute.Tensor, stride: int):
        """Device kernel: column sums for this block's 4 columns.

        Threads accumulate strided partials down their column, transpose
        them through a 4x32 shared tile so each warp owns one column, then
        finish with a warp shuffle reduction.
        """
        smem_alloc = cutlass.utils.SmemAllocator()
        # (column-within-block, row-thread) staging tile.
        smem_layout = cute.make_layout((4, 32))
        shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout)

        M, N = input.shape
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        lane_idx = cute.arch.lane_idx()
        warp_idx = cute.arch.warp_idx()

        max_iters = cute.ceil_div(M, stride)
        col_offset = tidx % 4   # which of the block's 4 columns
        row_offset = tidx // 4  # starting row for this thread (0..31)
        col = 4 * bidx + col_offset
        acc = cute.Float32(0)

        # Strided accumulation down the column.
        # NOTE(review): the loop advances by the literal 32, not `stride`;
        # this is only correct while stride == 32 (threads_per_block // 4).
        row = row_offset
        for _ in range(max_iters):
            if row < M and col < N:
                acc = acc + input[row, col]
            row = row + 32

        # Transpose partials: warp w then reads the 32 partials of column w.
        shmem[col_offset, row_offset] = acc
        cute.arch.sync_threads()
        acc = shmem[warp_idx, lane_idx]

        acc = cute.arch.warp_reduction_sum(acc)
        if lane_idx == 0:
            # NOTE(review): unguarded write — when N % 4 != 0 the last block
            # has warps with bidx * 4 + warp_idx >= N; verify output sizing
            # or add a bounds check.
            output[bidx * 4 + warp_idx] = acc
99 changes: 46 additions & 53 deletions forge_cute_py/ops/reduce_sum.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,58 @@
import cutlass.cute as cute
import torch
from cutlass import BFloat16, Float16, Float32
from cutlass.cute.runtime import from_dlpack

from forge_cute_py.kernels.reduce_sum import ReduceSumLast, ReduceSumFirst

_compile_cache = {}

@torch.library.custom_op("forge_cute_py::_reduce_sum", mutates_args={"out"})
def _reduce_sum(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None:
    """Sum reduction over one dimension of a 2D CUDA tensor using CuTe DSL.

    Compiles (and memoizes in `_compile_cache`) one CuTe kernel per
    (dtype, dim) pair against symbolic shapes, then invokes it, writing the
    result into `out` in place.

    Args:
        x: Input tensor of shape (M, N); must be CUDA and fp16/bf16/fp32.
        out: Output tensor mutated in place — shape (M,) when reducing the
            last dim, (N,) when reducing dim 0.
        dim: Dimension to reduce over (-1, 0, or 1).
    """
    assert x.dim() == 2, "reduce_sum expects a 2D tensor"
    assert x.is_cuda, f"reduce_sum is CUDA-only, got device={x.device}"
    assert dim in (-1, 0, 1), f"reduce_sum expects dim in {{-1, 0, 1}}, got {dim}"
    assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], (
        f"Unsupported dtype: {x.dtype}"
    )

    # Normalize dim to positive index
    dim = dim if dim >= 0 else x.ndim + dim

    dtype_map = {
        torch.float16: Float16,
        torch.float32: Float32,
        torch.bfloat16: BFloat16,
    }
    cute_dtype = dtype_map[x.dtype]
    compile_key = (cute_dtype, dim)

    if compile_key not in _compile_cache:
        # Compile against symbolic (m, n) so one kernel serves all sizes.
        m = cute.sym_int()
        n = cute.sym_int()
        input_cute = cute.runtime.make_fake_compact_tensor(
            cute_dtype, (m, n), stride_order=(1, 0)
        )

        if dim == 1:  # Reduce last dim -> one sum per row
            output_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m,))
            kernel_class = ReduceSumLast(cute_dtype)
        else:  # dim == 0 -> one sum per column
            output_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (n,))
            kernel_class = ReduceSumFirst(cute_dtype)

        _compile_cache[compile_key] = cute.compile(
            kernel_class,
            input_cute,
            output_cute,
            # cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=False),
            options="--enable-tvm-ffi",
        )

    x_cute = from_dlpack(x, assumed_align=16)
    out_cute = from_dlpack(out, assumed_align=16)
    _compile_cache[compile_key](x_cute, out_cute)


def reduce_sum(x: torch.Tensor, dim: int = -1, variant: str = '') -> torch.Tensor:
    """Sum reduction with CuTe DSL kernel.

    Args:
        x: Input tensor of shape (M, N) on CUDA.
        dim: Dimension to reduce over (-1 for last dim, 0 or 1).
        variant: Unused; kept for backward compatibility with older callers.

    Returns:
        Tensor of shape (M,) when reducing dim 1/-1, or (N,) when reducing
        dim 0.

    Raises:
        ValueError: If `dim` does not name one of the two axes.
    """
    # Normalize negative dims (-1 -> 1 for a 2D tensor).
    dim = dim if dim >= 0 else x.ndim + dim

    # Determine the output shape; the redundant re-assignment of out_shape
    # that previously followed this branch (a leftover of the old one-liner)
    # has been removed — the validated if/elif result is used directly.
    if dim == 0:
        out_shape = (x.shape[1],)
    elif dim == 1:
        out_shape = (x.shape[0],)
    else:
        raise ValueError(f"Invalid dim={dim} for 2D tensor")

    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
    _reduce_sum(x, out, dim)
    return out
32 changes: 32 additions & 0 deletions forge_cute_py/util/profile_launch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch

from forge_cute_py.ops.reduce_sum import reduce_sum

def _profile_range(label, fn, iters=100):
    """Run `fn` `iters` times inside an NVTX range, syncing before the pop."""
    torch.cuda.nvtx.range_push(label)
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    torch.cuda.nvtx.range_pop()


def main():
    """Warm up, then profile CuTe reduce_sum vs. torch's sum under NVTX ranges."""
    M, N = 4096, 4096
    data = torch.randn(M, N, device='cuda', dtype=torch.float32)

    # Warm both implementations so compilation/caching doesn't skew the ranges.
    print("Warming up...")
    for _ in range(10):
        reduce_sum(data, dim=-1)
        data.sum(dim=-1)
    torch.cuda.synchronize()
    print("Warmup complete")

    _profile_range("cute_reduce_sum", lambda: reduce_sum(data, dim=-1))
    _profile_range("torch_sum", lambda: data.sum(dim=-1))


if __name__ == "__main__":
    main()
Loading