From 0015b9a236192004fbfa4cc5f3890f2428560961 Mon Sep 17 00:00:00 2001 From: jonah Date: Thu, 29 Jan 2026 13:20:19 -0800 Subject: [PATCH 1/6] reduce sum 1d works --- forge_cute_py/kernels/reduce_sum.py | 111 ++++++++++++++++++++++++++++ pyproject.toml | 6 +- 2 files changed, 114 insertions(+), 3 deletions(-) create mode 100644 forge_cute_py/kernels/reduce_sum.py diff --git a/forge_cute_py/kernels/reduce_sum.py b/forge_cute_py/kernels/reduce_sum.py new file mode 100644 index 0000000..9dd5db4 --- /dev/null +++ b/forge_cute_py/kernels/reduce_sum.py @@ -0,0 +1,111 @@ +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '0' +os.environ['CUTLASS_CUDA_ARCH'] = '86' + +import math +import torch +from cutlass.cute.runtime import from_dlpack + +import cutlass +import cutlass.cute as cute + +from cutlass import dsl_user_op +from cutlass.cute.arch import nvvm +from cutlass._mlir.dialects.nvvm import AtomicOpKind, MemOrderKind, MemScopeKind +from cutlass.base_dsl.typing import T + + +@cute.kernel +def reduce_sum_kernel_one(input: cute.Tensor, output: cute.Tensor, N:cute.Int32, num_warps: int, max_iters: int): + smem_alloc = cutlass.utils.SmemAllocator() + smem_layout = cute.make_layout((32,)) + shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + lane_idx = cute.arch.lane_idx() + warp_idx = cute.arch.warp_idx() + idx = bdim * bidx + tidx + + acc = cute.Float32(0) + for i in range(max_iters): + idx = idx + i * bdim + if idx < N: + acc = acc + input[idx] + acc = cute.arch.warp_reduction_sum(acc) + if lane_idx == 0: + shmem[warp_idx] = acc + cute.arch.sync_threads() + if warp_idx == 0: + acc = shmem[lane_idx] if lane_idx < num_warps else 0.0 + acc = cute.arch.warp_reduction_sum(acc) + if lane_idx == 0: + output[bidx] = acc + + +@dsl_user_op +def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> cute.Float32: + return nvvm.atomicrmw( + T.f32(), + AtomicOpKind.FADD, + dst_ptr.llvm_ptr, + val.ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.RELAXED, + syncscope=MemScopeKind.SYS, + loc=loc, + ip=ip, + ) + +@cute.kernel +def reduce_sum_kernel_atomic(input: cute.Tensor, output: cute.Tensor, N:cute.Int32, num_warps: int, coarsen: int): + smem_alloc = cutlass.utils.SmemAllocator() + smem_layout = cute.make_layout((32,)) + shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + lane_idx = cute.arch.lane_idx() + warp_idx = cute.arch.warp_idx() + base_idx = coarsen * bdim * bidx + tidx + + acc = cute.Float32(0) + for i in range(coarsen): + idx = base_idx + i * bdim + if idx < N: + acc = acc + input[idx] + acc = cute.arch.warp_reduction_sum(acc) + if lane_idx == 0: + shmem[warp_idx] = acc + cute.arch.sync_threads() + if warp_idx == 0: + acc = shmem[lane_idx] if lane_idx < num_warps else 0.0 + acc = cute.arch.warp_reduction_sum(acc) + if lane_idx == 0: + atomicAddF32(output.iterator, acc) + + +@cute.jit +def solve(input: cute.Tensor, output: cute.Tensor, N: cute.Int32): + # if N <= 128: + # reduce_sum_kernel_one( input, output, N, 4, 1 + # ).launch( grid=(1, 1, 1), block=(128, 1, 1)) + # elif N <= 10240: + # reduce_sum_kernel_one( input, output, N, 32, 10 + # ).launch( grid=(1, 1, 1), block=(1024, 1, 1)) + # else: + num_warps = 8 + threads_per_block = 32 * num_warps + coarsen = 8 + blocks = cute.ceil_div(N, threads_per_block * coarsen) + 
reduce_sum_kernel_atomic(input, output, N, num_warps, coarsen + ).launch( grid=(blocks, 1, 1), block=(threads_per_block, 1, 1)) + + +N = 100000 +a = torch.randn((N,), device="cuda", dtype=torch.float32) +b = torch.zeros((1,), device='cuda', dtype=torch.float32) +vadd_compiled = cute.compile(solve, from_dlpack(a), from_dlpack(b), N) +vadd_compiled(from_dlpack(a), from_dlpack(b), N) +print(f'b=={b}') \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a44eeec..e2d98fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,12 +18,12 @@ dependencies = [ [tool.uv.sources] torch = [ - { index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, + { index = "pytorch-cu128", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, ] [[tool.uv.index]] -name = "pytorch-cu130" -url = "https://download.pytorch.org/whl/cu130" +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" explicit = true From af5f7ec572033e1e939809fcc04ab199f533163f Mon Sep 17 00:00:00 2001 From: jonah Date: Thu, 29 Jan 2026 16:26:50 -0800 Subject: [PATCH 2/6] reduce sum 2d both dimensions --- forge_cute_py/kernels/reduce_sum.py | 137 +++++++++++++++++----------- 1 file changed, 83 insertions(+), 54 deletions(-) diff --git a/forge_cute_py/kernels/reduce_sum.py b/forge_cute_py/kernels/reduce_sum.py index 9dd5db4..c1bf242 100644 --- a/forge_cute_py/kernels/reduce_sum.py +++ b/forge_cute_py/kernels/reduce_sum.py @@ -1,3 +1,4 @@ +from concurrent.futures import thread import os os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUTLASS_CUDA_ARCH'] = '86' @@ -15,24 +16,40 @@ from cutlass.base_dsl.typing import T +@dsl_user_op +def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> cute.Float32: + return nvvm.atomicrmw( + T.f32(), + AtomicOpKind.FADD, + dst_ptr.llvm_ptr, + val.ir_value(loc=loc, ip=ip), + mem_order=MemOrderKind.RELAXED, + syncscope=MemScopeKind.SYS, + loc=loc, + ip=ip, + ) + + @cute.kernel -def reduce_sum_kernel_one(input: cute.Tensor, output: cute.Tensor, N:cute.Int32, num_warps: int, max_iters: int): +def reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, num_warps: int): smem_alloc = cutlass.utils.SmemAllocator() smem_layout = cute.make_layout((32,)) shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) + _, N = input.shape tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() - bdim, _, _ = cute.arch.block_dim() + bdimx, _, _ = cute.arch.block_dim() lane_idx = cute.arch.lane_idx() warp_idx = cute.arch.warp_idx() - idx = bdim * bidx + tidx + + max_iters = cute.ceil_div(N, bdimx) acc = cute.Float32(0) for i in range(max_iters): - idx = idx + i * bdim + idx = tidx + i * bdimx if idx < N: - acc = acc + input[idx] + acc = acc + input[bidx, idx] acc = cute.arch.warp_reduction_sum(acc) if lane_idx == 0: shmem[warp_idx] = acc @@ -44,68 +61,80 @@ def reduce_sum_kernel_one(input: cute.Tensor, output: cute.Tensor, N:cute.Int32, output[bidx] = acc -@dsl_user_op -def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> cute.Float32: - return nvvm.atomicrmw( - T.f32(), - AtomicOpKind.FADD, - dst_ptr.llvm_ptr, - val.ir_value(loc=loc, ip=ip), - mem_order=MemOrderKind.RELAXED, - syncscope=MemScopeKind.SYS, - loc=loc, - ip=ip, - ) - @cute.kernel -def reduce_sum_kernel_atomic(input: cute.Tensor, output: cute.Tensor, N:cute.Int32, num_warps: int, coarsen: int): +def reduce_sum_kernel_first(input: cute.Tensor, output: cute.Tensor, stride: int): smem_alloc = 
cutlass.utils.SmemAllocator() - smem_layout = cute.make_layout((32,)) + smem_layout = cute.make_layout((4, 32)) shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) + M, N = input.shape tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() - bdim, _, _ = cute.arch.block_dim() lane_idx = cute.arch.lane_idx() warp_idx = cute.arch.warp_idx() - base_idx = coarsen * bdim * bidx + tidx + max_iters = cute.ceil_div(M, stride) + col_offset = tidx % 4 + row_offset = tidx // 4 + col = 4 * bidx + col_offset acc = cute.Float32(0) - for i in range(coarsen): - idx = base_idx + i * bdim - if idx < N: - acc = acc + input[idx] + + row = row_offset + for _ in range(max_iters): + if row < M and col < N: + acc = acc + input[row, col] + row = row + 32 + + shmem[col_offset, row_offset] = acc + cute.arch.sync_threads() + acc = shmem[warp_idx, lane_idx] + acc = cute.arch.warp_reduction_sum(acc) if lane_idx == 0: - shmem[warp_idx] = acc - cute.arch.sync_threads() - if warp_idx == 0: - acc = shmem[lane_idx] if lane_idx < num_warps else 0.0 - acc = cute.arch.warp_reduction_sum(acc) - if lane_idx == 0: - atomicAddF32(output.iterator, acc) + output[bidx * 4 + warp_idx] = acc @cute.jit -def solve(input: cute.Tensor, output: cute.Tensor, N: cute.Int32): - # if N <= 128: - # reduce_sum_kernel_one( input, output, N, 4, 1 - # ).launch( grid=(1, 1, 1), block=(128, 1, 1)) - # elif N <= 10240: - # reduce_sum_kernel_one( input, output, N, 32, 10 - # ).launch( grid=(1, 1, 1), block=(1024, 1, 1)) - # else: - num_warps = 8 +def _reduce_sum_last(x, output): + num_warps = 4 threads_per_block = 32 * num_warps - coarsen = 8 - blocks = cute.ceil_div(N, threads_per_block * coarsen) - reduce_sum_kernel_atomic(input, output, N, num_warps, coarsen - ).launch( grid=(blocks, 1, 1), block=(threads_per_block, 1, 1)) - - -N = 100000 -a = torch.randn((N,), device="cuda", dtype=torch.float32) -b = torch.zeros((1,), device='cuda', dtype=torch.float32) -vadd_compiled = cute.compile(solve, from_dlpack(a), from_dlpack(b), N) -vadd_compiled(from_dlpack(a), from_dlpack(b), N) -print(f'b=={b}') \ No newline at end of file + m, _ = x.shape + reduce_sum_kernel_last(x, output, num_warps + ).launch( grid=(m, 1, 1), block=(threads_per_block, 1, 1)) + + +@cute.jit +def _reduce_sum_first(x, output): + num_warps = 4 + threads_per_block = num_warps * 32 + m, n = x.shape + yblocks = cute.ceil_div(n, 4) + reduce_sum_kernel_first(x, output, threads_per_block // 4 + ).launch( + grid=(yblocks, 1, 1), + block=(threads_per_block, 1, 1) + ) + + +def reduce_sum(x, dim=-1): + shape = list(x.shape) + shape.pop(dim) + output = torch.zeros(shape, device=x.device, dtype=x.dtype) + if dim == -1: + vadd_compiled = cute.compile(_reduce_sum_last, from_dlpack(x), from_dlpack(output)) + vadd_compiled(from_dlpack(x), from_dlpack(output)) + else: + vadd_compiled = cute.compile(_reduce_sum_first, from_dlpack(x), from_dlpack(output)) + vadd_compiled(from_dlpack(x), from_dlpack(output)) + + return output + + +def test(): + for dim in [-1, 0]: + M, N = 1100, 1200 + a = torch.randn((M, N), device="cuda", dtype=torch.float32) + output = reduce_sum(a, dim=dim) + close = torch.allclose(output, a.sum(dim), rtol=1e-3) + assert close, f"Error along dimension: {dim}" + print("tests pass") \ No newline at end of file From d3980ec2bfdddcb211d04309cf9d3387f226d526 Mon Sep 17 00:00:00 2001 From: jonah Date: Fri, 30 Jan 2026 16:35:34 -0800 Subject: [PATCH 3/6] close --- forge_cute_py/kernels/reduce_sum.py | 190 +++++++++++-- ncu_output/sum_reduce.txt | 411 
++++++++++++++++++++++++++++ 2 files changed, 583 insertions(+), 18 deletions(-) create mode 100644 ncu_output/sum_reduce.txt diff --git a/forge_cute_py/kernels/reduce_sum.py b/forge_cute_py/kernels/reduce_sum.py index c1bf242..2fc896a 100644 --- a/forge_cute_py/kernels/reduce_sum.py +++ b/forge_cute_py/kernels/reduce_sum.py @@ -5,6 +5,7 @@ import math import torch +import time from cutlass.cute.runtime import from_dlpack import cutlass @@ -15,6 +16,9 @@ from cutlass._mlir.dialects.nvvm import AtomicOpKind, MemOrderKind, MemScopeKind from cutlass.base_dsl.typing import T +_reduce_sum_last_cache = {} +_reduce_sum_first_cache = {} + @dsl_user_op def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> cute.Float32: @@ -31,7 +35,7 @@ def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> @cute.kernel -def reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, num_warps: int): +def og_reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, num_warps: int): smem_alloc = cutlass.utils.SmemAllocator() smem_layout = cute.make_layout((32,)) shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) @@ -61,6 +65,77 @@ def reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, num_warps: i output[bidx] = acc +@cute.kernel +def reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, num_warps: int): + ROWS_PER_BLOCK = 4 + WARPS_PER_ROW = num_warps // ROWS_PER_BLOCK + THREADS_PER_ROW = WARPS_PER_ROW * 32 + + smem_alloc = cutlass.utils.SmemAllocator() + smem_layout = cute.make_layout((ROWS_PER_BLOCK, 32)) + shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) + + M, N = input.shape + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + lane_idx = cute.arch.lane_idx() + warp_idx = cute.arch.warp_idx() + + block_row = warp_idx // WARPS_PER_ROW + warp_in_row = warp_idx % WARPS_PER_ROW + tid_in_row = tidx % THREADS_PER_ROW + + og_row = bidx * ROWS_PER_BLOCK + row = og_row + block_row + + max_iters = cute.ceil_div(N, THREADS_PER_ROW) + + acc = cute.Float32(0) + for i in range(max_iters): + col = tid_in_row + i * THREADS_PER_ROW + if col < N and row < M: + acc = acc + input[row, col] + + acc = cute.arch.warp_reduction_sum(acc) + + if lane_idx == 0: + shmem[block_row, warp_in_row] = acc + + cute.arch.sync_threads() + if warp_idx < ROWS_PER_BLOCK: + v = shmem[warp_idx, lane_idx] if lane_idx < WARPS_PER_ROW else 0.0 + v = cute.arch.warp_reduction_sum(v) + if lane_idx == 0: + out_row = og_row + warp_idx + if out_row < M: + output[out_row] = v + + +@cute.jit +def _og_reduce_sum_last(x, output): + num_warps = 4 + threads_per_block = 32 * num_warps + m, _ = x.shape + og_reduce_sum_kernel_last(x, output, num_warps + ).launch( grid=(m, 1, 1), block=(threads_per_block, 1, 1)) + + +@cute.jit +def _reduce_sum_last(x, output): + # num_warps = 4 + # threads_per_block = 32 * num_warps + # m, _ = x.shape + # reduce_sum_kernel_last(x, output, num_warps + # ).launch( grid=(m, 1, 1), block=(threads_per_block, 1, 1)) + num_warps = 32 + ROWS_PER_BLOCK = 4 + threads_per_block = 32 * num_warps + m, _ = x.shape + blocks = cute.ceil_div(m, ROWS_PER_BLOCK) + reduce_sum_kernel_last(x, output, num_warps + ).launch( grid=(blocks, 1, 1), block=(threads_per_block, 1, 1)) + + @cute.kernel def reduce_sum_kernel_first(input: cute.Tensor, output: cute.Tensor, stride: int): smem_alloc = cutlass.utils.SmemAllocator() @@ -94,15 +169,6 @@ def reduce_sum_kernel_first(input: cute.Tensor, output: cute.Tensor, stride: int output[bidx * 
4 + warp_idx] = acc -@cute.jit -def _reduce_sum_last(x, output): - num_warps = 4 - threads_per_block = 32 * num_warps - m, _ = x.shape - reduce_sum_kernel_last(x, output, num_warps - ).launch( grid=(m, 1, 1), block=(threads_per_block, 1, 1)) - - @cute.jit def _reduce_sum_first(x, output): num_warps = 4 @@ -117,15 +183,23 @@ def _reduce_sum_first(x, output): def reduce_sum(x, dim=-1): - shape = list(x.shape) - shape.pop(dim) - output = torch.zeros(shape, device=x.device, dtype=x.dtype) + cache_key = (x.dtype, x.shape) if dim == -1: - vadd_compiled = cute.compile(_reduce_sum_last, from_dlpack(x), from_dlpack(output)) - vadd_compiled(from_dlpack(x), from_dlpack(output)) + output = torch.empty((x.size(0),), device=x.device, dtype=x.dtype) + if cache_key not in _reduce_sum_last_cache: + print("compiling...") + _reduce_sum_last_cache[cache_key] = cute.compile( + _og_reduce_sum_last, from_dlpack(x), from_dlpack(output) + ) + _reduce_sum_last_cache[cache_key](from_dlpack(x), from_dlpack(output)) else: - vadd_compiled = cute.compile(_reduce_sum_first, from_dlpack(x), from_dlpack(output)) - vadd_compiled(from_dlpack(x), from_dlpack(output)) + output = torch.empty((x.size(1),), device=x.device, dtype=x.dtype) + if cache_key not in _reduce_sum_first_cache: + print("compiling...") + _reduce_sum_first_cache[cache_key] = cute.compile( + _reduce_sum_first, from_dlpack(x), from_dlpack(output) + ) + _reduce_sum_first_cache[cache_key](from_dlpack(x), from_dlpack(output)) return output @@ -137,4 +211,84 @@ def test(): output = reduce_sum(a, dim=dim) close = torch.allclose(output, a.sum(dim), rtol=1e-3) assert close, f"Error along dimension: {dim}" - print("tests pass") \ No newline at end of file + print("tests pass") + + +def benchmark(): + import time + M, N = 4096, 4096 + x = torch.randn(M, N, device='cuda', dtype=torch.float32) + + # Correctness checks + print("Correctness checks:") + for dim in [-1, 0]: + result = reduce_sum(x, dim=dim) + expected = x.sum(dim=dim) + is_close = torch.allclose(result, expected, rtol=1e-3, atol=1e-4) + print(f" dim={dim:2d}: {'✓ PASS' if is_close else '✗ FAIL'}") + if not is_close: + max_diff = (result - expected).abs().max().item() + print(f" max diff: {max_diff}") + + print("\nBenchmarks:") + + # Warmup + for _ in range(10): + _ = reduce_sum(x, dim=-1) + _ = reduce_sum(x, dim=0) + torch.cuda.synchronize() + + # Benchmark dim=-1 + del x + x = torch.randn(M, N, device='cuda', dtype=torch.float32) + start = time.perf_counter() + for _ in range(100): + _ = reduce_sum(x, dim=-1) + torch.cuda.synchronize() + print(f" reduce_sum dim=-1: {(time.perf_counter() - start) * 10:.3f} ms") + + # Benchmark dim=0 + del x + x = torch.randn(M, N, device='cuda', dtype=torch.float32) + start = time.perf_counter() + for _ in range(100): + _ = reduce_sum(x, dim=0) + torch.cuda.synchronize() + print(f" reduce_sum dim=0: {(time.perf_counter() - start) * 10:.3f} ms") + + # Compare to PyTorch + del x + x = torch.randn(M, N, device='cuda', dtype=torch.float32) + start = time.perf_counter() + for _ in range(100): + _ = x.sum(dim=-1) + torch.cuda.synchronize() + print(f" torch.sum dim=-1: {(time.perf_counter() - start) * 10:.3f} ms") + + del x + x = torch.randn(M, N, device='cuda', dtype=torch.float32) + start = time.perf_counter() + for _ in range(100): + _ = x.sum(dim=0) + torch.cuda.synchronize() + print(f" torch.sum dim=0: {(time.perf_counter() - start) * 10:.3f} ms") + + +''' +sudo systemctl stop dcgm +/usr/local/cuda-12.8/bin/ncu --set full -o reduce_sum_profile uv run python run.py 
+/usr/local/cuda-12.8/bin/ncu --import reduce_sum_profile.ncu-rep +''' +def ncu_test(): + x = torch.randn(4096, 4096, device='cuda', dtype=torch.float32) + + # Warmup (compiles the kernel) + _ = reduce_sum(x, dim=-1) + torch.cuda.synchronize() + + # Profile this run + y = reduce_sum(x, dim=-1) + torch.cuda.synchronize() + +benchmark() +# ncu_test() diff --git a/ncu_output/sum_reduce.txt b/ncu_output/sum_reduce.txt new file mode 100644 index 0000000..2ec3fa6 --- /dev/null +++ b/ncu_output/sum_reduce.txt @@ -0,0 +1,411 @@ + kernel_cutlass_reduce_sum_kernel_last_tensorptrf32gmemo4096409640961_tensorptrf32gmemo40961_4_0 (4096, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9 + Section: GPU Speed Of Light Throughput + ----------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ----------------------- ----------- ------------ + DRAM Frequency Ghz 8.99 + SM Frequency Ghz 1.06 + Elapsed Cycles cycle 100174 + Memory Throughput % 93.81 + DRAM Throughput % 93.81 + Duration us 94.11 + L1/TEX Cache Throughput % 15.18 + L2 Cache Throughput % 33.00 + SM Active Cycles cycle 97308.30 + Compute (SM) Throughput % 9.62 + ----------------------- ----------- ------------ + + INF This workload is utilizing greater than 80.0% of the available compute or memory performance of the device. + To further improve performance, work will likely need to be shifted from the most utilized to another unit. + Start by analyzing DRAM in the Memory Workload Analysis section. + + Section: GPU Speed Of Light Roofline Chart + INF The ratio of peak float (fp32) to double (fp64) performance on this device is 64:1. The workload achieved + close to 1% of this device's fp32 peak performance and 0% of its fp64 peak performance. See the Kernel + Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#roofline) for more details + on roofline analysis. + + Section: PM Sampling + ------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------- ----------- ------------ + Maximum Buffer Size Mbyte 50.33 + Dropped Samples sample 0 + Maximum Sampling Interval us 1 + # Pass Groups 2 + ------------------------- ----------- ------------ + + Section: Compute Workload Analysis + -------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + -------------------- ----------- ------------ + Executed Ipc Active inst/cycle 0.33 + Executed Ipc Elapsed inst/cycle 0.32 + Issue Slots Busy % 8.18 + Issued Ipc Active inst/cycle 0.33 + SM Busy % 8.18 + -------------------- ----------- ------------ + + OPT Est. Local Speedup: 93.37% + All compute pipelines are under-utilized. Either this workload is very small or it doesn't issue enough warps + per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. + + Section: Memory Workload Analysis + ---------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ---------------------------- ----------- ------------ + Memory Throughput Gbyte/s 809.32 + Mem Busy % 17.07 + Max Bandwidth % 93.81 + L1/TEX Hit Rate % 0.02 + L2 Compression Success Rate % 0 + L2 Compression Ratio 0 + L2 Compression Input Sectors sector 0 + L2 Hit Rate % 0.58 + Mem Pipes Busy % 9.62 + ---------------------------- ----------- ------------ + + Section: Memory Workload Analysis Tables + OPT Est. Speedup: 25.74% + The memory access pattern for global stores to L2 might not be optimal. 
On average, only 4.0 of the 32 bytes + transmitted per sector are utilized by each thread. This applies to the 89.2% of sectors missed in L1TEX. + This could possibly be caused by a stride between threads. Check the Source Counters section for uncoalesced + global stores. + + Section: Scheduler Statistics + ---------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ---------------------------- ----------- ------------ + One or More Eligible % 8.20 + Issued Warp Per Scheduler 0.08 + No Eligible % 91.80 + Active Warps Per Scheduler warp 9.36 + Eligible Warps Per Scheduler warp 0.10 + ---------------------------- ----------- ------------ + + OPT Est. Local Speedup: 6.188% + Every scheduler is capable of issuing one instruction per cycle, but for this workload each scheduler only + issues an instruction every 12.2 cycles. This might leave hardware resources underutilized and may lead to + less optimal performance. Out of the maximum of 12 warps per scheduler, this workload allocates an average + of 9.36 active warps per scheduler, but only an average of 0.10 warps were eligible per cycle. Eligible + warps are the subset of active warps that are ready to issue their next instruction. Every cycle with no + eligible warp results in no instruction being issued and the issue slot remains unused. To increase the + number of eligible warps, avoid possible load imbalances due to highly different execution durations per + warp. Reducing stalls indicated on the Warp State Statistics and Source Counters sections can help, too. + + Section: Warp State Statistics + ---------------------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ---------------------------------------- ----------- ------------ + Warp Cycles Per Issued Instruction cycle 114.14 + Warp Cycles Per Executed Instruction cycle 114.69 + Avg. Active Threads Per Warp 31.89 + Avg. Not Predicated Off Threads Per Warp 31.29 + ---------------------------------------- ----------- ------------ + + OPT Est. Speedup: 6.188% + On average, each warp of this workload spends 106.6 cycles being stalled waiting for a scoreboard dependency + on a L1TEX (local, global, surface, texture) operation. Find the instruction producing the data being waited + upon to identify the culprit. To reduce the number of cycles waiting on L1TEX data accesses verify the + memory access patterns are optimal for the target architecture, attempt to increase cache hit rates by + increasing data locality (coalescing), or by changing the cache configuration. Consider moving frequently + used data to shared memory. This stall type represents about 93.4% of the total average of 114.1 cycles + between issuing two instructions. + ----- -------------------------------------------------------------------------------------------------------------- + INF Check the Warp Stall Sampling (All Samples) table for the top stall locations in your source based on + sampling data. The Kernel Profiling Guide + (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details + on each stall reason. + + Section: Instruction Statistics + ---------------------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ---------------------------------------- ----------- ------------ + Avg. Executed Instructions Per Scheduler inst 7917.97 + Executed Instructions inst 4497408 + Avg. 
Issued Instructions Per Scheduler inst 7955.98 + Issued Instructions inst 4518996 + ---------------------------------------- ----------- ------------ + + OPT Est. Speedup: 2.056% + This kernel executes 0 fused and 675840 non-fused FP32 instructions. By converting pairs of non-fused + instructions to their fused (https://docs.nvidia.com/cuda/floating-point/#cuda-and-floating-point), + higher-throughput equivalent, the achieved FP32 performance could be increased by up to 50% (relative to its + current performance). Check the Source page to identify where this kernel executes FP32 instructions. + + Section: Launch Statistics + -------------------------------- --------------- --------------- + Metric Name Metric Unit Metric Value + -------------------------------- --------------- --------------- + Block Size 128 + Function Cache Configuration CachePreferNone + Grid Size 4096 + Registers Per Thread register/thread 42 + Shared Memory Configuration Size Kbyte 32.77 + Driver Shared Memory Per Block Kbyte/block 1.02 + Dynamic Shared Memory Per Block byte/block 128 + Static Shared Memory Per Block byte/block 0 + # SMs SM 142 + Stack Size 1024 + Threads thread 524288 + # TPCs 71 + Enabled TPC IDs all + Uses Green Context 0 + Waves Per SM 2.88 + -------------------------------- --------------- --------------- + + OPT Est. Speedup: 33.33% + A wave of thread blocks is defined as the maximum number of blocks that can be executed in parallel on the + target GPU. The number of blocks in a wave depends on the number of multiprocessors and the theoretical + occupancy of the kernel. This kernel launch results in 2 full waves and a partial wave of 1257 thread + blocks. Under the assumption of a uniform execution duration of all thread blocks, this partial wave may + account for up to 33.3% of the total runtime of this kernel. Try launching a grid with no partial wave. The + overall impact of this tail effect also lessens with the number of full waves executed for a grid. See the + Hardware Model (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-hw-model) + description for more details on launch configurations. 
+ + Section: Occupancy + ------------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------------- ----------- ------------ + Block Limit SM block 24 + Block Limit Registers block 10 + Block Limit Shared Mem block 28 + Block Limit Warps block 12 + Theoretical Active Warps per SM warp 40 + Theoretical Occupancy % 83.33 + Achieved Occupancy % 77.93 + Achieved Active Warps Per SM warp 37.41 + ------------------------------- ----------- ------------ + + Section: GPU and Memory Workload Distribution + -------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + -------------------------- ----------- ------------ + Average DRAM Active Cycles cycle 793401.33 + Total DRAM Elapsed Cycles cycle 10148864 + Average L1 Active Cycles cycle 97308.30 + Total L1 Elapsed Cycles cycle 14224546 + Average L2 Active Cycles cycle 127385.48 + Total L2 Elapsed Cycles cycle 6356352 + Average SM Active Cycles cycle 97308.30 + Total SM Elapsed Cycles cycle 14224546 + Average SMSP Active Cycles cycle 97001.06 + Total SMSP Elapsed Cycles cycle 56898184 + -------------------------- ----------- ------------ + + Section: Source Counters + ------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------- ----------- ------------ + Branch Instructions Ratio % 0.02 + Branch Instructions inst 106496 + Branch Efficiency % 100 + Avg. Divergent Branches 0 + ------------------------- ----------- ------------ + + kernel_cutlass_reduce_sum_kernel_last_tensorptrf32gmemo4096409640961_tensorptrf32gmemo40961_4_0 (4096, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.9 + Section: GPU Speed Of Light Throughput + ----------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ----------------------- ----------- ------------ + DRAM Frequency Ghz 8.99 + SM Frequency Ghz 1.06 + Elapsed Cycles cycle 100250 + Memory Throughput % 93.71 + DRAM Throughput % 93.71 + Duration us 94.18 + L1/TEX Cache Throughput % 15.20 + L2 Cache Throughput % 33.26 + SM Active Cycles cycle 97148.04 + Compute (SM) Throughput % 9.61 + ----------------------- ----------- ------------ + + INF This workload is utilizing greater than 80.0% of the available compute or memory performance of the device. + To further improve performance, work will likely need to be shifted from the most utilized to another unit. + Start by analyzing DRAM in the Memory Workload Analysis section. + + Section: GPU Speed Of Light Roofline Chart + INF The ratio of peak float (fp32) to double (fp64) performance on this device is 64:1. The workload achieved + close to 1% of this device's fp32 peak performance and 0% of its fp64 peak performance. See the Kernel + Profiling Guide (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#roofline) for more details + on roofline analysis. 
+ + Section: PM Sampling + ------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------- ----------- ------------ + Maximum Buffer Size Mbyte 50.33 + Dropped Samples sample 0 + Maximum Sampling Interval us 1 + # Pass Groups 2 + ------------------------- ----------- ------------ + + Section: Compute Workload Analysis + -------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + -------------------- ----------- ------------ + Executed Ipc Active inst/cycle 0.33 + Executed Ipc Elapsed inst/cycle 0.32 + Issue Slots Busy % 8.19 + Issued Ipc Active inst/cycle 0.33 + SM Busy % 8.19 + -------------------- ----------- ------------ + + OPT Est. Local Speedup: 93.36% + All compute pipelines are under-utilized. Either this workload is very small or it doesn't issue enough warps + per scheduler. Check the Launch Statistics and Scheduler Statistics sections for further details. + + Section: Memory Workload Analysis + ---------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ---------------------------- ----------- ------------ + Memory Throughput Gbyte/s 808.46 + Mem Busy % 17.15 + Max Bandwidth % 93.71 + L1/TEX Hit Rate % 0.02 + L2 Compression Success Rate % 0 + L2 Compression Ratio 0 + L2 Compression Input Sectors sector 0 + L2 Hit Rate % 0.58 + Mem Pipes Busy % 9.61 + ---------------------------- ----------- ------------ + + Section: Memory Workload Analysis Tables + OPT Est. Speedup: 25.87% + The memory access pattern for global stores to L2 might not be optimal. On average, only 4.0 of the 32 bytes + transmitted per sector are utilized by each thread. This applies to the 88.9% of sectors missed in L1TEX. + This could possibly be caused by a stride between threads. Check the Source Counters section for uncoalesced + global stores. + + Section: Scheduler Statistics + ---------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ---------------------------- ----------- ------------ + One or More Eligible % 8.20 + Issued Warp Per Scheduler 0.08 + No Eligible % 91.80 + Active Warps Per Scheduler warp 9.36 + Eligible Warps Per Scheduler warp 0.10 + ---------------------------- ----------- ------------ + + OPT Est. Local Speedup: 6.29% + Every scheduler is capable of issuing one instruction per cycle, but for this workload each scheduler only + issues an instruction every 12.2 cycles. This might leave hardware resources underutilized and may lead to + less optimal performance. Out of the maximum of 12 warps per scheduler, this workload allocates an average + of 9.36 active warps per scheduler, but only an average of 0.10 warps were eligible per cycle. Eligible + warps are the subset of active warps that are ready to issue their next instruction. Every cycle with no + eligible warp results in no instruction being issued and the issue slot remains unused. To increase the + number of eligible warps, avoid possible load imbalances due to highly different execution durations per + warp. Reducing stalls indicated on the Warp State Statistics and Source Counters sections can help, too. + + Section: Warp State Statistics + ---------------------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ---------------------------------------- ----------- ------------ + Warp Cycles Per Issued Instruction cycle 114.07 + Warp Cycles Per Executed Instruction cycle 114.62 + Avg. Active Threads Per Warp 31.89 + Avg. 
Not Predicated Off Threads Per Warp 31.29 + ---------------------------------------- ----------- ------------ + + OPT Est. Speedup: 6.29% + On average, each warp of this workload spends 106.6 cycles being stalled waiting for a scoreboard dependency + on a L1TEX (local, global, surface, texture) operation. Find the instruction producing the data being waited + upon to identify the culprit. To reduce the number of cycles waiting on L1TEX data accesses verify the + memory access patterns are optimal for the target architecture, attempt to increase cache hit rates by + increasing data locality (coalescing), or by changing the cache configuration. Consider moving frequently + used data to shared memory. This stall type represents about 93.5% of the total average of 114.1 cycles + between issuing two instructions. + ----- -------------------------------------------------------------------------------------------------------------- + INF Check the Warp Stall Sampling (All Samples) table for the top stall locations in your source based on + sampling data. The Kernel Profiling Guide + (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-reference) provides more details + on each stall reason. + + Section: Instruction Statistics + ---------------------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ---------------------------------------- ----------- ------------ + Avg. Executed Instructions Per Scheduler inst 7917.97 + Executed Instructions inst 4497408 + Avg. Issued Instructions Per Scheduler inst 7956.46 + Issued Instructions inst 4519271 + ---------------------------------------- ----------- ------------ + + OPT Est. Speedup: 2.06% + This kernel executes 0 fused and 675840 non-fused FP32 instructions. By converting pairs of non-fused + instructions to their fused (https://docs.nvidia.com/cuda/floating-point/#cuda-and-floating-point), + higher-throughput equivalent, the achieved FP32 performance could be increased by up to 50% (relative to its + current performance). Check the Source page to identify where this kernel executes FP32 instructions. + + Section: Launch Statistics + -------------------------------- --------------- --------------- + Metric Name Metric Unit Metric Value + -------------------------------- --------------- --------------- + Block Size 128 + Function Cache Configuration CachePreferNone + Grid Size 4096 + Registers Per Thread register/thread 42 + Shared Memory Configuration Size Kbyte 32.77 + Driver Shared Memory Per Block Kbyte/block 1.02 + Dynamic Shared Memory Per Block byte/block 128 + Static Shared Memory Per Block byte/block 0 + # SMs SM 142 + Stack Size 1024 + Threads thread 524288 + # TPCs 71 + Enabled TPC IDs all + Uses Green Context 0 + Waves Per SM 2.88 + -------------------------------- --------------- --------------- + + OPT Est. Speedup: 33.33% + A wave of thread blocks is defined as the maximum number of blocks that can be executed in parallel on the + target GPU. The number of blocks in a wave depends on the number of multiprocessors and the theoretical + occupancy of the kernel. This kernel launch results in 2 full waves and a partial wave of 1257 thread + blocks. Under the assumption of a uniform execution duration of all thread blocks, this partial wave may + account for up to 33.3% of the total runtime of this kernel. Try launching a grid with no partial wave. The + overall impact of this tail effect also lessens with the number of full waves executed for a grid. 
See the + Hardware Model (https://docs.nvidia.com/nsight-compute/ProfilingGuide/index.html#metrics-hw-model) + description for more details on launch configurations. + + Section: Occupancy + ------------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------------- ----------- ------------ + Block Limit SM block 24 + Block Limit Registers block 10 + Block Limit Shared Mem block 28 + Block Limit Warps block 12 + Theoretical Active Warps per SM warp 40 + Theoretical Occupancy % 83.33 + Achieved Occupancy % 78.04 + Achieved Active Warps Per SM warp 37.46 + ------------------------------- ----------- ------------ + + Section: GPU and Memory Workload Distribution + -------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + -------------------------- ----------- ------------ + Average DRAM Active Cycles cycle 793101.33 + Total DRAM Elapsed Cycles cycle 10156032 + Average L1 Active Cycles cycle 97148.04 + Total L1 Elapsed Cycles cycle 14235170 + Average L2 Active Cycles cycle 126206.75 + Total L2 Elapsed Cycles cycle 6306576 + Average SM Active Cycles cycle 97148.04 + Total SM Elapsed Cycles cycle 14235170 + Average SMSP Active Cycles cycle 97011.88 + Total SMSP Elapsed Cycles cycle 56940680 + -------------------------- ----------- ------------ + + Section: Source Counters + ------------------------- ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------- ----------- ------------ + Branch Instructions Ratio % 0.02 + Branch Instructions inst 106496 + Branch Efficiency % 100 + Avg. Divergent Branches 0 + ------------------------- ----------- ------------ From 3338ae1ff04f7151d9564440bfaa84b06b385d54 Mon Sep 17 00:00:00 2001 From: jonah Date: Sat, 31 Jan 2026 09:33:13 -0800 Subject: [PATCH 4/6] float4 load on last dimension --- forge_cute_py/kernels/reduce_sum.py | 146 ++++++++++++---------------- forge_cute_py/ops/reduce_sum.py | 46 +++++---- 2 files changed, 87 insertions(+), 105 deletions(-) diff --git a/forge_cute_py/kernels/reduce_sum.py b/forge_cute_py/kernels/reduce_sum.py index 2fc896a..bbea019 100644 --- a/forge_cute_py/kernels/reduce_sum.py +++ b/forge_cute_py/kernels/reduce_sum.py @@ -20,6 +20,7 @@ _reduce_sum_first_cache = {} +# old just for future reference @dsl_user_op def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> cute.Float32: return nvvm.atomicrmw( @@ -34,106 +35,69 @@ def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> ) + @cute.kernel -def og_reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, num_warps: int): +def reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, tv_layout, num_warps: int): smem_alloc = cutlass.utils.SmemAllocator() - smem_layout = cute.make_layout((32,)) - shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) - - _, N = input.shape + shmem = smem_alloc.allocate_tensor(cute.Float32, cute.make_layout((32,))) + tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() - bdimx, _, _ = cute.arch.block_dim() - lane_idx = cute.arch.lane_idx() - warp_idx = cute.arch.warp_idx() + lane = cute.arch.lane_idx() + warp = cute.arch.warp_idx() - max_iters = cute.ceil_div(N, bdimx) + # we want to load as float4 + op = cute.nvgpu.CopyUniversalOp() + atom = cute.make_copy_atom(op, cute.Float32, num_bits_per_copy=128) + acc = cute.Float32(0.0) - acc = cute.Float32(0) - for i in range(max_iters): - idx = tidx + i * bdimx - if idx < N: - acc = acc + 
input[bidx, idx] - acc = cute.arch.warp_reduction_sum(acc) - if lane_idx == 0: - shmem[warp_idx] = acc - cute.arch.sync_threads() - if warp_idx == 0: - acc = shmem[lane_idx] if lane_idx < num_warps else 0.0 - acc = cute.arch.warp_reduction_sum(acc) - if lane_idx == 0: - output[bidx] = acc + _, mode1 = input.shape + ntiles = mode1[1] + for tile_idx in range(ntiles): + blk_coord = ((0, None), (bidx, tile_idx)) # all values in this [bidx, tile_idx] tile + cta_tile = input[blk_coord] -@cute.kernel -def reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, num_warps: int): - ROWS_PER_BLOCK = 4 - WARPS_PER_ROW = num_warps // ROWS_PER_BLOCK - THREADS_PER_ROW = WARPS_PER_ROW * 32 - - smem_alloc = cutlass.utils.SmemAllocator() - smem_layout = cute.make_layout((ROWS_PER_BLOCK, 32)) - shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) - - M, N = input.shape - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - lane_idx = cute.arch.lane_idx() - warp_idx = cute.arch.warp_idx() - - block_row = warp_idx // WARPS_PER_ROW - warp_in_row = warp_idx % WARPS_PER_ROW - tid_in_row = tidx % THREADS_PER_ROW + thr_frag = cute.composition(cta_tile, tv_layout) + thr_src = thr_frag[(tidx, None)] - og_row = bidx * ROWS_PER_BLOCK - row = og_row + block_row + # register memory for float4 + r = cute.make_rmem_tensor(cute.make_layout((4,), stride=(1,)), cute.Float32) + # atom, src, dst + cute.copy(atom, thr_src, r) - max_iters = cute.ceil_div(N, THREADS_PER_ROW) - - acc = cute.Float32(0) - for i in range(max_iters): - col = tid_in_row + i * THREADS_PER_ROW - if col < N and row < M: - acc = acc + input[row, col] + acc += r[0] + r[1] + r[2] + r[3] acc = cute.arch.warp_reduction_sum(acc) - - if lane_idx == 0: - shmem[block_row, warp_in_row] = acc - + if lane == 0: + shmem[warp] = acc cute.arch.sync_threads() - if warp_idx < ROWS_PER_BLOCK: - v = shmem[warp_idx, lane_idx] if lane_idx < WARPS_PER_ROW else 0.0 - v = cute.arch.warp_reduction_sum(v) - if lane_idx == 0: - out_row = og_row + warp_idx - if out_row < M: - output[out_row] = v - -@cute.jit -def _og_reduce_sum_last(x, output): - num_warps = 4 - threads_per_block = 32 * num_warps - m, _ = x.shape - og_reduce_sum_kernel_last(x, output, num_warps - ).launch( grid=(m, 1, 1), block=(threads_per_block, 1, 1)) + if warp == 0: + acc2 = shmem[lane] if lane < num_warps else 0.0 + acc2 = cute.arch.warp_reduction_sum(acc2) + if lane == 0: + output[bidx] = acc2 @cute.jit def _reduce_sum_last(x, output): - # num_warps = 4 - # threads_per_block = 32 * num_warps - # m, _ = x.shape - # reduce_sum_kernel_last(x, output, num_warps - # ).launch( grid=(m, 1, 1), block=(threads_per_block, 1, 1)) - num_warps = 32 - ROWS_PER_BLOCK = 4 - threads_per_block = 32 * num_warps - m, _ = x.shape - blocks = cute.ceil_div(m, ROWS_PER_BLOCK) - reduce_sum_kernel_last(x, output, num_warps - ).launch( grid=(blocks, 1, 1), block=(threads_per_block, 1, 1)) + num_warps = 16 + threads_per_block = 512 + M, N = x.shape + + tiler_mn = (1, 2048) + gX = cute.zipped_divide(x, tiler_mn) + + thr_layout = cute.make_layout((threads_per_block,), stride=(1,)) + val_layout = cute.make_layout((4,), stride=(1,)) + _, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + reduce_sum_kernel_last(gX, output, tv_layout, num_warps).launch( + grid=(M, 1, 1), + block=(threads_per_block, 1, 1), + ) + @cute.kernel @@ -189,7 +153,7 @@ def reduce_sum(x, dim=-1): if cache_key not in _reduce_sum_last_cache: print("compiling...") _reduce_sum_last_cache[cache_key] = cute.compile( - 
_og_reduce_sum_last, from_dlpack(x), from_dlpack(output) + _reduce_sum_last, from_dlpack(x, assumed_align=16), from_dlpack(output) ) _reduce_sum_last_cache[cache_key](from_dlpack(x), from_dlpack(output)) else: @@ -290,5 +254,19 @@ def ncu_test(): y = reduce_sum(x, dim=-1) torch.cuda.synchronize() -benchmark() +''' +Correctness checks: against 4096x4096 +compiling... + dim=-1: ✓ PASS +compiling... + dim= 0: ✓ PASS + +Benchmarks: + reduce_sum dim=-1: 0.043 ms + reduce_sum dim=0: 0.042 ms + torch.sum dim=-1: 0.011 ms + torch.sum dim=0: 0.019 ms +''' + +# benchmark() # ncu_test() diff --git a/forge_cute_py/ops/reduce_sum.py b/forge_cute_py/ops/reduce_sum.py index d501eca..e2cd2a2 100644 --- a/forge_cute_py/ops/reduce_sum.py +++ b/forge_cute_py/ops/reduce_sum.py @@ -1,35 +1,39 @@ import torch +import cutlass.cute as cute +from cutlass import BFloat16, Float16, Float32 +from cutlass.cute.runtime import from_dlpack -@torch.library.custom_op("forge_cute_py::_reduce_sum", mutates_args={"out"}) -def _reduce_sum(x: torch.Tensor, out: torch.Tensor, dim: int = -1, variant: str = "shfl") -> None: - """Row/column sum reduction (reference implementation stub). +from forge_cute_py.kernels.reduce_sum import _reduce_sum_last, _reduce_sum_first - Args: - x: Input tensor of shape (M, N) - out: Output tensor (mutated in-place) - dim: Dimension to reduce over (-1, 0, or 1) - variant: Reduction variant (naive, improved, shfl) - currently unused - """ +_compile_cache = {} + +@torch.library.custom_op("forge_cute_py::_reduce_sum", mutates_args={"out"}) +def _reduce_sum(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None: + """Sum reduction using CuTe DSL.""" assert x.dim() == 2, "reduce_sum expects a 2D tensor" assert x.is_cuda, f"reduce_sum is CUDA-only, got device={x.device}" - assert dim in (-1, 0, 1), f"reduce_sum expects dim in {{-1, 0, 1}}, got {dim}" - assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], ( - f"Unsupported dtype: {x.dtype}" - ) - # Normalize dim to positive index dim = dim if dim >= 0 else x.ndim + dim - # For now, use reference implementation - # Future: call kernel implementation based on variant when available - from forge_cute_py.ref import reduce_sum as reduce_sum_ref - - result = reduce_sum_ref(x, dim=dim) - out.copy_(result) + dtype_map = { + torch.float16: Float16, + torch.float32: Float32, + torch.bfloat16: BFloat16, + } + cute_dtype = dtype_map[x.dtype] + + compile_key = (cute_dtype, dim, x.shape) + if compile_key not in _compile_cache: + jit_fn = _reduce_sum_last if dim == 1 else _reduce_sum_first + _compile_cache[compile_key] = cute.compile( + jit_fn, + from_dlpack(x, assumed_align=16), + from_dlpack(out), + ) -_reduce_sum.compile_cache = {} + _compile_cache[compile_key](from_dlpack(x), from_dlpack(out)) def reduce_sum(x: torch.Tensor, dim: int = -1, variant: str = "shfl") -> torch.Tensor: From d23dfd0c4dc5cbb1902e31dc0b794f1919b55f8c Mon Sep 17 00:00:00 2001 From: jonah Date: Mon, 2 Feb 2026 07:28:34 -0800 Subject: [PATCH 5/6] testing --- forge_cute_py/kernels/reduce_sum.py | 16 ++++++++------ forge_cute_py/ops/reduce_sum.py | 2 +- forge_cute_py/util/profile_launch.py | 32 ++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 8 deletions(-) create mode 100644 forge_cute_py/util/profile_launch.py diff --git a/forge_cute_py/kernels/reduce_sum.py b/forge_cute_py/kernels/reduce_sum.py index bbea019..6dab0dd 100644 --- a/forge_cute_py/kernels/reduce_sum.py +++ b/forge_cute_py/kernels/reduce_sum.py @@ -82,15 +82,15 @@ def reduce_sum_kernel_last(input: 
cute.Tensor, output: cute.Tensor, tv_layout, n @cute.jit def _reduce_sum_last(x, output): - num_warps = 16 - threads_per_block = 512 + num_warps = 8 + threads_per_block = num_warps * 32 M, N = x.shape - tiler_mn = (1, 2048) - gX = cute.zipped_divide(x, tiler_mn) + tiler_mn = (1, threads_per_block) + gX = cute.zipped_divide(x, tiler_mn) # [M, (x, threads_per_block)] - thr_layout = cute.make_layout((threads_per_block,), stride=(1,)) - val_layout = cute.make_layout((4,), stride=(1,)) + thr_layout = cute.make_layout((threads_per_block,), stride=(1,)) # single tile view + val_layout = cute.make_layout((4,), stride=(1,)) # vectorized float4 _, tv_layout = cute.make_layout_tv(thr_layout, val_layout) reduce_sum_kernel_last(gX, output, tv_layout, num_warps).launch( @@ -153,7 +153,9 @@ def reduce_sum(x, dim=-1): if cache_key not in _reduce_sum_last_cache: print("compiling...") _reduce_sum_last_cache[cache_key] = cute.compile( - _reduce_sum_last, from_dlpack(x, assumed_align=16), from_dlpack(output) + _reduce_sum_last, from_dlpack(x, assumed_align=16), from_dlpack(output), + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", ) _reduce_sum_last_cache[cache_key](from_dlpack(x), from_dlpack(output)) else: diff --git a/forge_cute_py/ops/reduce_sum.py b/forge_cute_py/ops/reduce_sum.py index e2cd2a2..7b5cd83 100644 --- a/forge_cute_py/ops/reduce_sum.py +++ b/forge_cute_py/ops/reduce_sum.py @@ -65,5 +65,5 @@ def reduce_sum(x: torch.Tensor, dim: int = -1, variant: str = "shfl") -> torch.T raise ValueError(f"Invalid dim={dim} for 2D tensor") out = torch.empty(out_shape, dtype=x.dtype, device=x.device) - _reduce_sum(x, out, dim, variant) + _reduce_sum(x, out, dim) return out diff --git a/forge_cute_py/util/profile_launch.py b/forge_cute_py/util/profile_launch.py new file mode 100644 index 0000000..cad20dd --- /dev/null +++ b/forge_cute_py/util/profile_launch.py @@ -0,0 +1,32 @@ +import torch + +from forge_cute_py.kernels.reduce_sum import reduce_sum + +def main(): + M, N = 4096, 4096 + x = torch.randn(M, N, device='cuda', dtype=torch.float32) + + # Warmup + print("Warming up...") + for _ in range(10): + _ = reduce_sum(x, dim=-1) + _ = x.sum(dim=-1) + torch.cuda.synchronize() + print("Warmup complete") + + # Profile cute + torch.cuda.nvtx.range_push("cute_reduce_sum") + for _ in range(100): + _ = reduce_sum(x, dim=-1) + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + # Profile torch + torch.cuda.nvtx.range_push("torch_sum") + for _ in range(100): + _ = x.sum(dim=-1) + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + +if __name__ == "__main__": + main() \ No newline at end of file From 8ecc59069820f60fb072b53338fdca8c9567192e Mon Sep 17 00:00:00 2001 From: jonah Date: Mon, 2 Feb 2026 08:22:24 -0800 Subject: [PATCH 6/6] testing more --- forge_cute_py/kernels/reduce_sum.py | 366 ++++++++++----------------- forge_cute_py/ops/reduce_sum.py | 63 ++--- forge_cute_py/util/profile_launch.py | 2 +- 3 files changed, 155 insertions(+), 276 deletions(-) diff --git a/forge_cute_py/kernels/reduce_sum.py b/forge_cute_py/kernels/reduce_sum.py index 6dab0dd..f02fc77 100644 --- a/forge_cute_py/kernels/reduce_sum.py +++ b/forge_cute_py/kernels/reduce_sum.py @@ -2,14 +2,16 @@ import os os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUTLASS_CUDA_ARCH'] = '86' +os.environ['CUTE_DSL_ENABLE_TVM_FFI'] = '1' import math import torch import time from cutlass.cute.runtime import from_dlpack - -import cutlass +import cuda.bindings.driver as cuda import cutlass.cute as cute 
+from cutlass import const_expr +import cutlass from cutlass import dsl_user_op from cutlass.cute.arch import nvvm @@ -35,240 +37,128 @@ def atomicAddF32(dst_ptr: cute.Pointer, val: cute.Float32, loc=None, ip=None) -> ) - -@cute.kernel -def reduce_sum_kernel_last(input: cute.Tensor, output: cute.Tensor, tv_layout, num_warps: int): - smem_alloc = cutlass.utils.SmemAllocator() - shmem = smem_alloc.allocate_tensor(cute.Float32, cute.make_layout((32,))) - - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - lane = cute.arch.lane_idx() - warp = cute.arch.warp_idx() - - # we want to load as float4 - op = cute.nvgpu.CopyUniversalOp() - atom = cute.make_copy_atom(op, cute.Float32, num_bits_per_copy=128) - acc = cute.Float32(0.0) - - _, mode1 = input.shape - ntiles = mode1[1] - - for tile_idx in range(ntiles): - blk_coord = ((0, None), (bidx, tile_idx)) # all values in this [bidx, tile_idx] tile - cta_tile = input[blk_coord] - - thr_frag = cute.composition(cta_tile, tv_layout) - thr_src = thr_frag[(tidx, None)] - - # register memory for float4 - r = cute.make_rmem_tensor(cute.make_layout((4,), stride=(1,)), cute.Float32) - # atom, src, dst - cute.copy(atom, thr_src, r) - - acc += r[0] + r[1] + r[2] + r[3] - - acc = cute.arch.warp_reduction_sum(acc) - if lane == 0: - shmem[warp] = acc - cute.arch.sync_threads() - - if warp == 0: - acc2 = shmem[lane] if lane < num_warps else 0.0 - acc2 = cute.arch.warp_reduction_sum(acc2) +class ReduceSumLast: + """Sum reduction along last dimension using CuTe DSL.""" + + def __init__(self, dtype: type): + self.dtype = dtype + self.num_warps = 8 + self.threads_per_block = self.num_warps * 32 + + @cute.jit + def __call__( + self, + input: cute.Tensor, + output: cute.Tensor, + # stream: cuda.CUstream = None, + ): + M, N = input.shape + tiler_mn = (1, self.threads_per_block * 4) + gX = cute.zipped_divide(input, tiler_mn) + + thr_layout = cute.make_layout((self.threads_per_block,), stride=(1,)) + val_layout = cute.make_layout((4,), stride=(1,)) + _, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + self.kernel(gX, output, tv_layout).launch( + grid=(M, 1, 1), + block=(self.threads_per_block, 1, 1), + # stream=stream, + ) + + @cute.kernel + def kernel(self, input: cute.Tensor, output: cute.Tensor, tv_layout): + num_warps = const_expr(self.num_warps) + + smem_alloc = cutlass.utils.SmemAllocator() + shmem = smem_alloc.allocate_tensor(cute.Float32, cute.make_layout((32,))) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + lane = cute.arch.lane_idx() + warp = cute.arch.warp_idx() + + # op = cute.nvgpu.CopyUniversalOp() + # atom = cute.make_copy_atom(op, cute.Float32, num_bits_per_copy=128) + acc = cute.Float32(0.0) + + _, mode1 = input.shape + ntiles = mode1[1] + + for tile_idx in range(ntiles): + blk_coord = ((0, None), (bidx, tile_idx)) + cta_tile = input[blk_coord] + thr_frag = cute.composition(cta_tile, tv_layout) + thr_src = thr_frag[(tidx, None)] + # r = cute.make_rmem_tensor(cute.make_layout((4,), stride=(1,)), cute.Float32) + # cute.copy(atom, thr_src, r) + # acc += r[0] + r[1] + r[2] + r[3] + acc += thr_src[0] + thr_src[1] + thr_src[2] + thr_src[3] + + + acc = cute.arch.warp_reduction_sum(acc) if lane == 0: - output[bidx] = acc2 - - -@cute.jit -def _reduce_sum_last(x, output): - num_warps = 8 - threads_per_block = num_warps * 32 - M, N = x.shape - - tiler_mn = (1, threads_per_block) - gX = cute.zipped_divide(x, tiler_mn) # [M, (x, threads_per_block)] - - thr_layout = cute.make_layout((threads_per_block,), 
stride=(1,)) # single tile view - val_layout = cute.make_layout((4,), stride=(1,)) # vectorized float4 - _, tv_layout = cute.make_layout_tv(thr_layout, val_layout) - - reduce_sum_kernel_last(gX, output, tv_layout, num_warps).launch( - grid=(M, 1, 1), - block=(threads_per_block, 1, 1), - ) - - - -@cute.kernel -def reduce_sum_kernel_first(input: cute.Tensor, output: cute.Tensor, stride: int): - smem_alloc = cutlass.utils.SmemAllocator() - smem_layout = cute.make_layout((4, 32)) - shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout) - - M, N = input.shape - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - lane_idx = cute.arch.lane_idx() - warp_idx = cute.arch.warp_idx() - - max_iters = cute.ceil_div(M, stride) - col_offset = tidx % 4 - row_offset = tidx // 4 - col = 4 * bidx + col_offset - acc = cute.Float32(0) - - row = row_offset - for _ in range(max_iters): - if row < M and col < N: - acc = acc + input[row, col] - row = row + 32 - - shmem[col_offset, row_offset] = acc - cute.arch.sync_threads() - acc = shmem[warp_idx, lane_idx] - - acc = cute.arch.warp_reduction_sum(acc) - if lane_idx == 0: - output[bidx * 4 + warp_idx] = acc - - -@cute.jit -def _reduce_sum_first(x, output): - num_warps = 4 - threads_per_block = num_warps * 32 - m, n = x.shape - yblocks = cute.ceil_div(n, 4) - reduce_sum_kernel_first(x, output, threads_per_block // 4 - ).launch( - grid=(yblocks, 1, 1), - block=(threads_per_block, 1, 1) - ) - - -def reduce_sum(x, dim=-1): - cache_key = (x.dtype, x.shape) - if dim == -1: - output = torch.empty((x.size(0),), device=x.device, dtype=x.dtype) - if cache_key not in _reduce_sum_last_cache: - print("compiling...") - _reduce_sum_last_cache[cache_key] = cute.compile( - _reduce_sum_last, from_dlpack(x, assumed_align=16), from_dlpack(output), - cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), - options="--enable-tvm-ffi", - ) - _reduce_sum_last_cache[cache_key](from_dlpack(x), from_dlpack(output)) - else: - output = torch.empty((x.size(1),), device=x.device, dtype=x.dtype) - if cache_key not in _reduce_sum_first_cache: - print("compiling...") - _reduce_sum_first_cache[cache_key] = cute.compile( - _reduce_sum_first, from_dlpack(x), from_dlpack(output) - ) - _reduce_sum_first_cache[cache_key](from_dlpack(x), from_dlpack(output)) - - return output - - -def test(): - for dim in [-1, 0]: - M, N = 1100, 1200 - a = torch.randn((M, N), device="cuda", dtype=torch.float32) - output = reduce_sum(a, dim=dim) - close = torch.allclose(output, a.sum(dim), rtol=1e-3) - assert close, f"Error along dimension: {dim}" - print("tests pass") - - -def benchmark(): - import time - M, N = 4096, 4096 - x = torch.randn(M, N, device='cuda', dtype=torch.float32) - - # Correctness checks - print("Correctness checks:") - for dim in [-1, 0]: - result = reduce_sum(x, dim=dim) - expected = x.sum(dim=dim) - is_close = torch.allclose(result, expected, rtol=1e-3, atol=1e-4) - print(f" dim={dim:2d}: {'✓ PASS' if is_close else '✗ FAIL'}") - if not is_close: - max_diff = (result - expected).abs().max().item() - print(f" max diff: {max_diff}") - - print("\nBenchmarks:") - - # Warmup - for _ in range(10): - _ = reduce_sum(x, dim=-1) - _ = reduce_sum(x, dim=0) - torch.cuda.synchronize() - - # Benchmark dim=-1 - del x - x = torch.randn(M, N, device='cuda', dtype=torch.float32) - start = time.perf_counter() - for _ in range(100): - _ = reduce_sum(x, dim=-1) - torch.cuda.synchronize() - print(f" reduce_sum dim=-1: {(time.perf_counter() - start) * 10:.3f} ms") - - # Benchmark dim=0 - 
-
-
-def benchmark():
-    import time
-    M, N = 4096, 4096
-    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-
-    # Correctness checks
-    print("Correctness checks:")
-    for dim in [-1, 0]:
-        result = reduce_sum(x, dim=dim)
-        expected = x.sum(dim=dim)
-        is_close = torch.allclose(result, expected, rtol=1e-3, atol=1e-4)
-        print(f"  dim={dim:2d}: {'✓ PASS' if is_close else '✗ FAIL'}")
-        if not is_close:
-            max_diff = (result - expected).abs().max().item()
-            print(f"    max diff: {max_diff}")
-
-    print("\nBenchmarks:")
-
-    # Warmup
-    for _ in range(10):
-        _ = reduce_sum(x, dim=-1)
-        _ = reduce_sum(x, dim=0)
-    torch.cuda.synchronize()
-
-    # Benchmark dim=-1
-    del x
-    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    start = time.perf_counter()
-    for _ in range(100):
-        _ = reduce_sum(x, dim=-1)
-    torch.cuda.synchronize()
-    print(f"  reduce_sum dim=-1: {(time.perf_counter() - start) * 10:.3f} ms")
-
-    # Benchmark dim=0
-    del x
-    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    start = time.perf_counter()
-    for _ in range(100):
-        _ = reduce_sum(x, dim=0)
-    torch.cuda.synchronize()
-    print(f"  reduce_sum dim=0: {(time.perf_counter() - start) * 10:.3f} ms")
-
-    # Compare to PyTorch
-    del x
-    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    start = time.perf_counter()
-    for _ in range(100):
-        _ = x.sum(dim=-1)
-    torch.cuda.synchronize()
-    print(f"  torch.sum dim=-1: {(time.perf_counter() - start) * 10:.3f} ms")
-
-    del x
-    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    start = time.perf_counter()
-    for _ in range(100):
-        _ = x.sum(dim=0)
-    torch.cuda.synchronize()
-    print(f"  torch.sum dim=0: {(time.perf_counter() - start) * 10:.3f} ms")
-
-
-'''
-sudo systemctl stop dcgm
-/usr/local/cuda-12.8/bin/ncu --set full -o reduce_sum_profile uv run python run.py
-/usr/local/cuda-12.8/bin/ncu --import reduce_sum_profile.ncu-rep
-'''
-def ncu_test():
-    x = torch.randn(4096, 4096, device='cuda', dtype=torch.float32)
-
-    # Warmup (compiles the kernel)
-    _ = reduce_sum(x, dim=-1)
-    torch.cuda.synchronize()
-
-    # Profile this run
-    y = reduce_sum(x, dim=-1)
-    torch.cuda.synchronize()
-
-'''
-Correctness checks: against 4096x4096
-compiling...
-  dim=-1: ✓ PASS
-compiling...
-  dim= 0: ✓ PASS
-
-Benchmarks:
-  reduce_sum dim=-1: 0.043 ms
-  reduce_sum dim=0: 0.042 ms
-  torch.sum dim=-1: 0.011 ms
-  torch.sum dim=0: 0.019 ms
-'''
-
-# benchmark()
-# ncu_test()
+            shmem[warp] = acc
+        cute.arch.sync_threads()
+
+        if warp == 0:
+            acc2 = shmem[lane] if lane < num_warps else 0.0
+            acc2 = cute.arch.warp_reduction_sum(acc2)
+            if lane == 0:
+                output[bidx] = acc2
+
+
+class ReduceSumFirst:
+    """Sum reduction along first dimension using CuTe DSL."""
+
+    def __init__(self, dtype: type):
+        self.dtype = dtype
+        self.num_warps = 4
+        self.threads_per_block = self.num_warps * 32
+
+    @cute.jit
+    def __call__(
+        self,
+        input: cute.Tensor,
+        output: cute.Tensor,
+        # stream: cuda.CUstream = None,
+    ):
+        M, N = input.shape
+        yblocks = cute.ceil_div(N, 4)
+        self.kernel(input, output, self.threads_per_block // 4).launch(
+            grid=(yblocks, 1, 1),
+            block=(self.threads_per_block, 1, 1),
+            # stream=stream,
+        )
+
+    @cute.kernel
+    def kernel(self, input: cute.Tensor, output: cute.Tensor, stride: int):
+        smem_alloc = cutlass.utils.SmemAllocator()
+        smem_layout = cute.make_layout((4, 32))
+        shmem = smem_alloc.allocate_tensor(cute.Float32, smem_layout)
+
+        M, N = input.shape
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        lane_idx = cute.arch.lane_idx()
+        warp_idx = cute.arch.warp_idx()
+
+        max_iters = cute.ceil_div(M, stride)
+        col_offset = tidx % 4
+        row_offset = tidx // 4
+        col = 4 * bidx + col_offset
+        acc = cute.Float32(0)
+
+        row = row_offset
+        for _ in range(max_iters):
+            if row < M and col < N:
+                acc = acc + input[row, col]
+            row = row + 32
+
+        shmem[col_offset, row_offset] = acc
+        cute.arch.sync_threads()
+        acc = shmem[warp_idx, lane_idx]
+
+        acc = cute.arch.warp_reduction_sum(acc)
+        if lane_idx == 0:
+            output[bidx * 4 + warp_idx] = acc
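
ReduceSumFirst.kernel above assigns four adjacent columns to each block and
32 threads to each column; after the shared-memory transpose, warp w holds
the 32 partial sums for its column and finishes with warp shuffles. The
thread-to-data mapping, restated as a plain-Python sketch (helper name is
illustrative only):

    # Which (column, rows) a given thread accumulates; stride is
    # threads_per_block // 4 == 32 for the 128-thread launch above.
    def thread_assignment(bidx, tidx, M):
        col = 4 * bidx + tidx % 4
        rows = list(range(tidx // 4, M, 32))
        return col, rows
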
diff --git a/forge_cute_py/ops/reduce_sum.py b/forge_cute_py/ops/reduce_sum.py
index 7b5cd83..e7252c9 100644
--- a/forge_cute_py/ops/reduce_sum.py
+++ b/forge_cute_py/ops/reduce_sum.py
@@ -1,13 +1,13 @@
-import torch
 import cutlass.cute as cute
-
+import torch
 from cutlass import BFloat16, Float16, Float32
 from cutlass.cute.runtime import from_dlpack
-from forge_cute_py.kernels.reduce_sum import _reduce_sum_last, _reduce_sum_first
+from forge_cute_py.kernels.reduce_sum import ReduceSumLast, ReduceSumFirst
 
 _compile_cache = {}
 
+
 @torch.library.custom_op("forge_cute_py::_reduce_sum", mutates_args={"out"})
 def _reduce_sum(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None:
     """Sum reduction using CuTe DSL."""
@@ -22,48 +22,37 @@ def _reduce_sum(x: torch.Tensor, out: torch.Tensor, dim: int = -1) -> None:
         torch.bfloat16: BFloat16,
     }
     cute_dtype = dtype_map[x.dtype]
-
-    compile_key = (cute_dtype, dim, x.shape)
+    compile_key = (cute_dtype, dim)
     if compile_key not in _compile_cache:
-        jit_fn = _reduce_sum_last if dim == 1 else _reduce_sum_first
+        m = cute.sym_int()
+        n = cute.sym_int()
+        input_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m, n), stride_order=(1, 0))
+
+        if dim == 1:  # Reduce last dim
+            output_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (m,))
+            kernel_class = ReduceSumLast(cute_dtype)
+        else:  # dim == 0
+            output_cute = cute.runtime.make_fake_compact_tensor(cute_dtype, (n,))
+            kernel_class = ReduceSumFirst(cute_dtype)
+
         _compile_cache[compile_key] = cute.compile(
-            jit_fn,
-            from_dlpack(x, assumed_align=16),
-            from_dlpack(out),
+            kernel_class,
+            input_cute,
+            output_cute,
+            # cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=False),
+            options="--enable-tvm-ffi",
         )
-    _compile_cache[compile_key](from_dlpack(x), from_dlpack(out))
+    x_cute = from_dlpack(x, assumed_align=16)
+    out_cute = from_dlpack(out, assumed_align=16)
+    _compile_cache[compile_key](x_cute, out_cute)
 
 
-def reduce_sum(x: torch.Tensor, dim: int = -1, variant: str = "shfl") -> torch.Tensor:
-    """Row/column sum reduction.
-
-    Args:
-        x: Input tensor of shape (M, N)
-        dim: Dimension to reduce over (-1 for last dim, 0 or 1)
-        variant: Reduction variant (naive, improved, shfl) - currently unused
-
-    Returns:
-        Reduced tensor of shape (M,) if dim=1 or (N,) if dim=0
-
-    Examples:
-        >>> x = torch.randn(32, 128, device='cuda', dtype=torch.float16)
-        >>> y = reduce_sum(x, dim=-1)  # Sum over columns, result shape: (32,)
-        >>> y.shape
-        torch.Size([32])
-    """
-    # Normalize dim to positive index
+def reduce_sum(x: torch.Tensor, dim: int = -1, variant='') -> torch.Tensor:
+    """Sum reduction with CuTe DSL kernel."""
     dim = dim if dim >= 0 else x.ndim + dim
-
-    # Determine output shape
-    if dim == 0:
-        out_shape = (x.shape[1],)
-    elif dim == 1:
-        out_shape = (x.shape[0],)
-    else:
-        raise ValueError(f"Invalid dim={dim} for 2D tensor")
-
+    out_shape = (x.shape[1],) if dim == 0 else (x.shape[0],)
     out = torch.empty(out_shape, dtype=x.dtype, device=x.device)
     _reduce_sum(x, out, dim)
     return out
diff --git a/forge_cute_py/util/profile_launch.py b/forge_cute_py/util/profile_launch.py
index cad20dd..0098364 100644
--- a/forge_cute_py/util/profile_launch.py
+++ b/forge_cute_py/util/profile_launch.py
@@ -1,6 +1,6 @@
 import torch
 
-from forge_cute_py.kernels.reduce_sum import reduce_sum
+from forge_cute_py.ops.reduce_sum import reduce_sum
 
 
 def main():
     M, N = 4096, 4096