From 8cc2eb761303a7f4eb38c5d49944faa9efc429d2 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Sun, 8 Feb 2026 18:18:20 +0800 Subject: [PATCH 01/23] Init CUDA radix select --- benchmarks/python/benchmark_radix_select.py | 185 ++++ mlx/backend/cuda/device/radix_select.cuh | 922 ++++++++++++++++++++ mlx/backend/cuda/sort.cu | 278 +++++- 3 files changed, 1383 insertions(+), 2 deletions(-) create mode 100644 benchmarks/python/benchmark_radix_select.py create mode 100644 mlx/backend/cuda/device/radix_select.cuh diff --git a/benchmarks/python/benchmark_radix_select.py b/benchmarks/python/benchmark_radix_select.py new file mode 100644 index 0000000000..078d484cb1 --- /dev/null +++ b/benchmarks/python/benchmark_radix_select.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +""" +Benchmark script for MLX argpartition/partition operations. +Compares radix select implementation against full argsort. +""" + +import time + +import mlx.core as mx +import numpy as np + +GREEN = "\033[92m" +YELLOW = "\033[33m" +RED = "\033[91m" +RESET = "\033[0m" + + +def color_speedup(speedup): + s = f"{speedup:>5.2f}x" + if 0.9 <= speedup <= 1.1: + return f"{YELLOW}{s}{RESET}" + elif speedup > 1.1: + return f"{GREEN}{s}{RESET}" + else: + return f"{RED}{s}{RESET}" + + +def benchmark_argpartition(b, v, k, dtype=mx.bfloat16, warmup=5, iterations=100): + x = mx.random.uniform(shape=(b, v)).astype(dtype) + mx.eval(x) + for _ in range(warmup): + mx.eval(mx.argpartition(x, kth=k, axis=-1)) + start = time.perf_counter() + for _ in range(iterations): + mx.eval(mx.argpartition(x, kth=k, axis=-1)) + return (time.perf_counter() - start) / iterations * 1000 + + +def benchmark_argsort(b, v, dtype=mx.bfloat16, warmup=5, iterations=100): + x = mx.random.uniform(shape=(b, v)).astype(dtype) + mx.eval(x) + for _ in range(warmup): + mx.eval(mx.argsort(x, axis=-1)) + start = time.perf_counter() + for _ in range(iterations): + mx.eval(mx.argsort(x, axis=-1)) + return (time.perf_counter() - start) / iterations * 1000 + + +def 
verify_correctness(b, v, k): + x = mx.random.uniform(shape=(b, v)).astype(mx.float32) + mx.eval(x) + indices = mx.argpartition(x, kth=k, axis=-1) + mx.eval(indices) + x_np = np.array(x) + indices_np = np.array(indices) + for i in range(b): + pv = x_np[i, indices_np[i]] + assert np.all(pv[:k] <= pv[k]), f"Row {i}: elements before k not all <= kth" + assert np.all(pv[k + 1 :] >= pv[k]), f"Row {i}: elements after k not all >= kth" + return True + + +def sweep_boundary(dtype=mx.bfloat16, k_ratio=0.004, warmup=10, iterations=50): + dtype_name = str(dtype).split(".")[-1] + print(f"\nDtype={dtype_name} k=vocab*{k_ratio:.3f}") + print() + + batch_sizes = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 256, 512, 1024, 2048] + vocab_sizes = [512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + + col_w = 10 + print(f"{'':>8}", end="") + for v in vocab_sizes: + label = f"v={v}" + print(f" {label:^{col_w}}", end="") + print() + + for b in batch_sizes: + print(f"b={b:<6}", end="") + for v in vocab_sizes: + k = max(1, int(v * k_ratio)) + try: + x = mx.random.uniform(shape=(b, v)).astype(dtype) + mx.eval(x) + for _ in range(warmup): + mx.eval(mx.argpartition(x, kth=k, axis=-1)) + start = time.perf_counter() + for _ in range(iterations): + mx.eval(mx.argpartition(x, kth=k, axis=-1)) + radix_ms = (time.perf_counter() - start) / iterations * 1000 + for _ in range(warmup): + mx.eval(mx.argsort(x, axis=-1)) + start = time.perf_counter() + for _ in range(iterations): + mx.eval(mx.argsort(x, axis=-1)) + argsort_ms = (time.perf_counter() - start) / iterations * 1000 + + speedup = argsort_ms / radix_ms + cell = color_speedup(speedup) + # pad accounting for invisible ANSI codes + print(f" {cell:^{col_w + len(GREEN) + len(RESET)}}", end="") + except Exception: + print(f" {'ERR':^{col_w}}", end="") + print() + + +def main(): + print("=" * 70) + print("MLX Radix Select Benchmark") + print("=" * 70) + + configs = [ + (2048, 8192, 32), + (2048, 4096, 32), + (1024, 4096, 16), + (512, 2048, 64), + 
(256, 1024, 32), + (128, 512, 16), + (1, 128000, 64), + (1, 512, 32), + (16, 8192, 32), + (32, 8192, 32), + (64, 8192, 32), + ] + + dtypes = [(mx.bfloat16, "bfloat16"), (mx.float32, "float32")] + + print("\n1. Correctness Verification") + print("-" * 40) + for b, v, k in configs: + try: + verify_correctness(b, v, k) + print(f" {GREEN}[PASS]{RESET} b={b}, v={v}, k={k}") + except AssertionError as e: + print(f" {RED}[FAIL]{RESET} b={b}, v={v}, k={k}: {e}") + + print("\n2. Performance Benchmarks") + print("-" * 70) + + for dtype, dtype_name in dtypes: + print(f"\nDtype: {dtype_name}") + print( + f"{'Config':<25} {'ArgPartition':>14} {'ArgSort':>12} {'Speedup':>10}" + ) + print("-" * 80) + + for b, v, k in configs: + try: + argpart_ms = benchmark_argpartition(b, v, k, dtype, warmup=3, iterations=50) + argsort_ms = benchmark_argsort( + b, v, dtype, warmup=3, iterations=50 + ) + speedup = argsort_ms / argpart_ms + config_str = f"b={b}, v={v}, k={k}" + print( + f"{config_str:<25} {argpart_ms:>12.3f}ms" + f" {argsort_ms:>10.3f}ms {color_speedup(speedup)}" + ) + except Exception as e: + print(f"b={b}, v={v}, k={k}: Error - {e}") + + print("\n3. Boundary Sweep") + print("-" * 70) + # sweep_boundary(mx.bool_) + sweep_boundary(mx.bfloat16) + # sweep_boundary(mx.float16) + sweep_boundary(mx.float32) + # sweep_boundary(mx.float64) + # sweep_boundary(mx.int8) + # sweep_boundary(mx.int16) + # sweep_boundary(mx.int32) + # sweep_boundary(mx.int64) + # sweep_boundary(mx.uint8) + # sweep_boundary(mx.uint16) + # sweep_boundary(mx.uint32) + # sweep_boundary(mx.uint64) + + print("\n" + "=" * 70) + print("Benchmark Complete") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh new file mode 100644 index 0000000000..c87e2e3b2c --- /dev/null +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -0,0 +1,922 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/cuda/device/utils.cuh" +#include +#include +#include +#include + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// Radix Select Implementation for CUDA +// +// Multi-pass radix-based selection algorithm for partition operations. +// Uses IEEE 754 bit manipulation for correct floating-point ordering. +/////////////////////////////////////////////////////////////////////////////// + +// Radix configuration +constexpr int RADIX_BITS = 8; +constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 256 bins + +/////////////////////////////////////////////////////////////////////////////// +// Bit manipulation for radix sorting +/////////////////////////////////////////////////////////////////////////////// + +template +struct RadixTraits; + +template <> +struct RadixTraits { + using UnsignedT = uint32_t; + static constexpr int BITS = 32; + + __device__ __forceinline__ static UnsignedT to_radix(float val) { + uint32_t bits = __float_as_uint(val); + uint32_t mask = -int32_t(bits >> 31) | 0x80000000u; + return bits ^ mask; + } + + __device__ __forceinline__ static float from_radix(UnsignedT bits) { + uint32_t mask = ((bits >> 31) - 1) | 0x80000000u; + return __uint_as_float(bits ^ mask); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint64_t; + static constexpr int BITS = 64; + + __device__ __forceinline__ static UnsignedT to_radix(double val) { + uint64_t bits = __double_as_longlong(val); + uint64_t mask = -int64_t(bits >> 63) | 0x8000000000000000ull; + return bits ^ mask; + } + + __device__ __forceinline__ static double from_radix(UnsignedT bits) { + uint64_t mask = ((bits >> 63) - 1) | 0x8000000000000000ull; + return __longlong_as_double(bits ^ mask); + } +}; + +template <> +struct RadixTraits<__half> { + using UnsignedT = uint16_t; + static constexpr int BITS = 16; + + __device__ __forceinline__ static UnsignedT to_radix(__half val) { + uint16_t bits = 
__half_as_ushort(val); + uint16_t mask = -int16_t(bits >> 15) | 0x8000u; + return bits ^ mask; + } + + __device__ __forceinline__ static __half from_radix(UnsignedT bits) { + uint16_t mask = ((bits >> 15) - 1) | 0x8000u; + return __ushort_as_half(bits ^ mask); + } +}; + +template <> +struct RadixTraits<__nv_bfloat16> { + using UnsignedT = uint16_t; + static constexpr int BITS = 16; + + __device__ __forceinline__ static UnsignedT to_radix(__nv_bfloat16 val) { + uint16_t bits = __bfloat16_as_ushort(val); + uint16_t mask = -int16_t(bits >> 15) | 0x8000u; + return bits ^ mask; + } + + __device__ __forceinline__ static __nv_bfloat16 from_radix(UnsignedT bits) { + uint16_t mask = ((bits >> 15) - 1) | 0x8000u; + return __ushort_as_bfloat16(bits ^ mask); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint8_t; + static constexpr int BITS = 8; + + __device__ __forceinline__ static UnsignedT to_radix(int8_t val) { + return static_cast(val) ^ 0x80u; + } + + __device__ __forceinline__ static int8_t from_radix(UnsignedT bits) { + return static_cast(bits ^ 0x80u); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint16_t; + static constexpr int BITS = 16; + + __device__ __forceinline__ static UnsignedT to_radix(int16_t val) { + return static_cast(val) ^ 0x8000u; + } + + __device__ __forceinline__ static int16_t from_radix(UnsignedT bits) { + return static_cast(bits ^ 0x8000u); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint32_t; + static constexpr int BITS = 32; + + __device__ __forceinline__ static UnsignedT to_radix(int32_t val) { + return static_cast(val) ^ 0x80000000u; + } + + __device__ __forceinline__ static int32_t from_radix(UnsignedT bits) { + return static_cast(bits ^ 0x80000000u); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint64_t; + static constexpr int BITS = 64; + + __device__ __forceinline__ static UnsignedT to_radix(int64_t val) { + return static_cast(val) ^ 0x8000000000000000ull; + } + 
+ __device__ __forceinline__ static int64_t from_radix(UnsignedT bits) { + return static_cast(bits ^ 0x8000000000000000ull); + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint8_t; + static constexpr int BITS = 8; + + __device__ __forceinline__ static UnsignedT to_radix(bool val) { + return static_cast(val); + } + + __device__ __forceinline__ static bool from_radix(UnsignedT bits) { + return bits != 0; + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint8_t; + static constexpr int BITS = 8; + + __device__ __forceinline__ static UnsignedT to_radix(uint8_t val) { + return val; + } + + __device__ __forceinline__ static uint8_t from_radix(UnsignedT bits) { + return bits; + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint16_t; + static constexpr int BITS = 16; + + __device__ __forceinline__ static UnsignedT to_radix(uint16_t val) { + return val; + } + + __device__ __forceinline__ static uint16_t from_radix(UnsignedT bits) { + return bits; + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint32_t; + static constexpr int BITS = 32; + + __device__ __forceinline__ static UnsignedT to_radix(uint32_t val) { + return val; + } + + __device__ __forceinline__ static uint32_t from_radix(UnsignedT bits) { + return bits; + } +}; + +template <> +struct RadixTraits { + using UnsignedT = uint64_t; + static constexpr int BITS = 64; + + __device__ __forceinline__ static UnsignedT to_radix(uint64_t val) { + return val; + } + + __device__ __forceinline__ static uint64_t from_radix(UnsignedT bits) { + return bits; + } +}; + +template +__device__ __forceinline__ int extract_digit( + UnsignedT val, + int start_bit, + int num_bits) { + return (val >> start_bit) & ((1 << num_bits) - 1); +} + +template +__device__ __forceinline__ bool is_nan_value(T val) { + if constexpr (cuda::std::is_floating_point_v) { + return cuda::std::isnan(val); + } else if constexpr (cuda::std::is_same_v) { + return __hisnan(val); + } else if 
constexpr (cuda::std::is_same_v) { + return __hisnan(val); + } else { + return false; + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Warp-level utilities +/////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ T warp_reduce_sum(T val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xFFFFFFFF, val, offset); + } + return val; +} + +/////////////////////////////////////////////////////////////////////////////// +// Single-pass Radix Select for small arrays (fits in shared memory) +/////////////////////////////////////////////////////////////////////////////// + +template < + typename ValT, + typename OutT, + bool ARG_PARTITION, + int BLOCK_THREADS, + int ITEMS_PER_THREAD> +__global__ void radix_select_small_kernel( + const ValT* input, + OutT* output, + int kth, + int n, + int in_stride, + int out_stride, + int in_segment_stride, + int out_segment_stride) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + constexpr int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; + constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + + // Shared memory + __shared__ UnsignedT shared_keys[TILE_SIZE]; + __shared__ uint32_t shared_idxs[TILE_SIZE]; + __shared__ int shared_hist[RADIX_SIZE]; + __shared__ int shared_count[2]; + + int row = blockIdx.y; + const ValT* row_input = input + row * in_segment_stride; + OutT* row_output = output + row * out_segment_stride; + + int tile_n = min(n, TILE_SIZE); + + // Load data into shared memory + for (int i = threadIdx.x; i < TILE_SIZE; i += BLOCK_THREADS) { + if (i < tile_n) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + shared_keys[i] = key; + shared_idxs[i] = i; + } else { + shared_keys[i] = ~UnsignedT(0); + shared_idxs[i] = i; + } + } + __syncthreads(); + + // 
Radix select to find pivot + int k = kth + 1; + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + // Clear histogram + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + __syncthreads(); + + // Build histogram + for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomicAdd(&shared_hist[digit], 1); + } + } + __syncthreads(); + + // Find target bin (single thread) + if (threadIdx.x == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_count[0] = target_bin; + shared_count[1] = k; + } + __syncthreads(); + + int target_bin = shared_count[0]; + k = shared_count[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + __syncthreads(); + } + + // Output partitioned array + if (threadIdx.x == 0) { + shared_count[0] = 0; + } + __syncthreads(); + + // Phase 1: output elements less than pivot + for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key < target_prefix) { + int pos = atomicAdd(&shared_count[0], 1); + if (ARG_PARTITION) { + row_output[pos * out_stride] = shared_idxs[i]; + } else { + row_output[pos * out_stride] = + row_input[shared_idxs[i] * in_stride]; + } + } + } + __syncthreads(); + + // Phase 2: output elements equal to pivot + for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key == target_prefix) { + int pos = atomicAdd(&shared_count[0], 1); + if (ARG_PARTITION) { + row_output[pos * 
out_stride] = shared_idxs[i]; + } else { + row_output[pos * out_stride] = + row_input[shared_idxs[i] * in_stride]; + } + } + } + __syncthreads(); + + // Phase 3: output elements greater than pivot + for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key > target_prefix) { + int pos = atomicAdd(&shared_count[0], 1); + if (ARG_PARTITION) { + row_output[pos * out_stride] = shared_idxs[i]; + } else { + row_output[pos * out_stride] = + row_input[shared_idxs[i] * in_stride]; + } + } + } +} + +template < + typename ValT, + typename OutT, + bool ARG_PARTITION, + int BLOCK_THREADS, + int ITEMS_PER_THREAD> +__global__ void radix_select_small_nc_kernel( + const ValT* input, + OutT* output, + int kth, + int n, + int in_stride, + int out_stride, + const __grid_constant__ Shape nc_shape, + const __grid_constant__ Strides in_nc_strides, + const __grid_constant__ Strides out_nc_strides, + int nc_dim) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + constexpr int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; + constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + + __shared__ UnsignedT shared_keys[TILE_SIZE]; + __shared__ uint32_t shared_idxs[TILE_SIZE]; + __shared__ int shared_hist[RADIX_SIZE]; + __shared__ int shared_count[2]; + + int row = blockIdx.y; + int64_t in_block_idx = + elem_to_loc(int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); + int64_t out_block_idx = elem_to_loc( + int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); + const ValT* row_input = input + in_block_idx; + OutT* row_output = output + out_block_idx; + + int tile_n = min(n, TILE_SIZE); + + for (int i = threadIdx.x; i < TILE_SIZE; i += BLOCK_THREADS) { + if (i < tile_n) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + shared_keys[i] = key; + shared_idxs[i] = i; + } else { + shared_keys[i] = 
~UnsignedT(0); + shared_idxs[i] = i; + } + } + __syncthreads(); + + int k = kth + 1; + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + __syncthreads(); + + for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomicAdd(&shared_hist[digit], 1); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_count[0] = target_bin; + shared_count[1] = k; + } + __syncthreads(); + + int target_bin = shared_count[0]; + k = shared_count[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + __syncthreads(); + } + + if (threadIdx.x == 0) { + shared_count[0] = 0; + } + __syncthreads(); + + for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key < target_prefix) { + int pos = atomicAdd(&shared_count[0], 1); + if (ARG_PARTITION) { + row_output[pos * out_stride] = shared_idxs[i]; + } else { + row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; + } + } + } + __syncthreads(); + + for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key == target_prefix) { + int pos = atomicAdd(&shared_count[0], 1); + if (ARG_PARTITION) { + row_output[pos * out_stride] = shared_idxs[i]; + } else { + row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; + } + } + } + __syncthreads(); + + for (int i = 
threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key > target_prefix) { + int pos = atomicAdd(&shared_count[0], 1); + if (ARG_PARTITION) { + row_output[pos * out_stride] = shared_idxs[i]; + } else { + row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Large array streaming kernel (multi-pass, in-place) +/////////////////////////////////////////////////////////////////////////////// + +template < + typename ValT, + typename OutT, + bool ARG_PARTITION, + int BLOCK_THREADS> +__global__ void radix_select_large_streaming_kernel( + const ValT* input, + OutT* output, + int n, + int kth, + int in_stride, + int out_stride, + int in_segment_stride, + int out_segment_stride) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + + int row = blockIdx.y; + const ValT* row_input = input + row * in_segment_stride; + OutT* row_output = output + row * out_segment_stride; + + // Shared memory + __shared__ int shared_hist[RADIX_SIZE]; + __shared__ int shared_pivot_info[2]; + __shared__ int shared_counts[2]; + __shared__ int shared_output_counters[3]; + + int k = kth + 1; + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + // Multi-pass to find pivot + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + // Clear histogram + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + __syncthreads(); + + // Build histogram + bool is_contiguous = (in_stride == 1); + if (is_contiguous) { + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, 
start_bit, RADIX_BITS); + atomicAdd(&shared_hist[digit], 1); + } + } + } else { + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomicAdd(&shared_hist[digit], 1); + } + } + } + __syncthreads(); + + // Find target bin + if (threadIdx.x == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_pivot_info[0] = target_bin; + shared_pivot_info[1] = k; + } + __syncthreads(); + + int target_bin = shared_pivot_info[0]; + k = shared_pivot_info[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + // Initialize counters for next phase + if (threadIdx.x == 0) { + shared_counts[0] = 0; + shared_counts[1] = 0; + } + __syncthreads(); + } + + // Count partition sizes with warp reduction + int local_less = 0, local_equal = 0; + bool is_contiguous = (in_stride == 1); + + if (is_contiguous) { + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if (key < target_prefix) + local_less++; + else if (key == target_prefix) + local_equal++; + } + } else { + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if (key < target_prefix) + local_less++; + else if (key == target_prefix) + local_equal++; + } + } + + // Warp reduction + local_less = warp_reduce_sum(local_less); + local_equal = 
warp_reduce_sum(local_equal); + + // First lane of each warp aggregates to shared memory + int lane = threadIdx.x % WARP_SIZE; + if (lane == 0) { + atomicAdd(&shared_counts[0], local_less); + atomicAdd(&shared_counts[1], local_equal); + } + __syncthreads(); + + // Read final counts + int less_count = shared_counts[0]; + int equal_count = shared_counts[1]; + + // Initialize output counters + if (threadIdx.x == 0) { + shared_output_counters[0] = 0; + shared_output_counters[1] = 0; + shared_output_counters[2] = 0; + } + __syncthreads(); + + // Output partitioned elements + if (is_contiguous && out_stride == 1) { + // Fast path: both input and output are contiguous + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + int pos; + if (key < target_prefix) { + pos = atomicAdd(&shared_output_counters[0], 1); + } else if (key == target_prefix) { + pos = less_count + atomicAdd(&shared_output_counters[1], 1); + } else { + pos = less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); + } + + if (ARG_PARTITION) { + row_output[pos] = i; + } else { + row_output[pos] = val; + } + } + } else { + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + int pos; + if (key < target_prefix) { + pos = atomicAdd(&shared_output_counters[0], 1); + } else if (key == target_prefix) { + pos = less_count + atomicAdd(&shared_output_counters[1], 1); + } else { + pos = less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); + } + + if (ARG_PARTITION) { + row_output[pos * out_stride] = i; + } else { + row_output[pos * out_stride] = val; + } + } + } +} + +template < + typename ValT, + typename OutT, + bool ARG_PARTITION, + int BLOCK_THREADS> +__global__ void radix_select_large_streaming_nc_kernel( + const ValT* 
input, + OutT* output, + int n, + int kth, + int in_stride, + int out_stride, + const __grid_constant__ Shape nc_shape, + const __grid_constant__ Strides in_nc_strides, + const __grid_constant__ Strides out_nc_strides, + int nc_dim) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + + int row = blockIdx.y; + int64_t in_block_idx = + elem_to_loc(int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); + int64_t out_block_idx = elem_to_loc( + int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); + const ValT* row_input = input + in_block_idx; + OutT* row_output = output + out_block_idx; + + __shared__ int shared_hist[RADIX_SIZE]; + __shared__ int shared_pivot_info[2]; + __shared__ int shared_counts[2]; + __shared__ int shared_output_counters[3]; + + int k = kth + 1; + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + __syncthreads(); + + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if ((key & prefix_mask) == target_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomicAdd(&shared_hist[digit], 1); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_pivot_info[0] = target_bin; + shared_pivot_info[1] = k; + } + __syncthreads(); + + int target_bin = shared_pivot_info[0]; + k = shared_pivot_info[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << 
start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + if (threadIdx.x == 0) { + shared_counts[0] = 0; + shared_counts[1] = 0; + } + __syncthreads(); + } + + int local_less = 0, local_equal = 0; + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + if (key < target_prefix) + local_less++; + else if (key == target_prefix) + local_equal++; + } + + local_less = warp_reduce_sum(local_less); + local_equal = warp_reduce_sum(local_equal); + + int lane = threadIdx.x % WARP_SIZE; + if (lane == 0) { + atomicAdd(&shared_counts[0], local_less); + atomicAdd(&shared_counts[1], local_equal); + } + __syncthreads(); + + int less_count = shared_counts[0]; + int equal_count = shared_counts[1]; + + if (threadIdx.x == 0) { + shared_output_counters[0] = 0; + shared_output_counters[1] = 0; + shared_output_counters[2] = 0; + } + __syncthreads(); + + for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = Traits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + + int pos; + if (key < target_prefix) { + pos = atomicAdd(&shared_output_counters[0], 1); + } else if (key == target_prefix) { + pos = less_count + atomicAdd(&shared_output_counters[1], 1); + } else { + pos = less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); + } + + if (ARG_PARTITION) { + row_output[pos * out_stride] = i; + } else { + row_output[pos * out_stride] = val; + } + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 43756f7078..4df85df992 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -6,6 +6,7 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/fp16_math.cuh" +#include "mlx/backend/cuda/device/radix_select.cuh" #include 
"mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" @@ -1040,6 +1041,279 @@ void gpu_merge_sort( return single_block_sort(s, in, out, axis, bn, argsort); } +/////////////////////////////////////////////////////////////////////////////// +// Radix partition functions +/////////////////////////////////////////////////////////////////////////////// + +void gpu_radix_partition_small( + const Stream& s, + const array& in, + array& out, + int axis, + int kth, + bool arg_partition) { + int n_rows = in.size() / in.shape(axis); + + auto in_nc_str = in.strides(); + in_nc_str.erase(in_nc_str.begin() + axis); + + auto out_nc_str = out.strides(); + out_nc_str.erase(out_nc_str.begin() + axis); + + auto nc_shape = in.shape(); + nc_shape.erase(nc_shape.begin() + axis); + + int nc_dim = nc_shape.size(); + + int size_sorted_axis = in.shape(axis); + int64_t in_stride_sorted_axis = in.strides()[axis]; + int64_t out_stride_sorted_axis = out.strides()[axis]; + + bool contiguous = in.flags().contiguous; + auto check_strides = [](const array& x, int64_t sort_stride) { + int64_t min_stride = + *std::min_element(x.strides().begin(), x.strides().end()); + int64_t max_stride = + *std::max_element(x.strides().begin(), x.strides().end()); + return sort_stride == min_stride || sort_stride == max_stride; + }; + contiguous &= check_strides(in, in_stride_sorted_axis); + contiguous &= check_strides(out, out_stride_sorted_axis); + + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + encoder.set_input_array(in); + encoder.set_output_array(out); + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + if constexpr (!std::is_same_v) { + using ValT = cuda_type_t; + + constexpr int BLOCK_THREADS = 256; + constexpr int ITEMS_PER_THREAD = 8; + + dim3 grid(1, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_bool(arg_partition, [&](auto arg_tag) { + constexpr 
bool ARG_PARTITION = decltype(arg_tag)::value; + using OutT = std::conditional_t; + + if (contiguous) { + auto kernel = cu::radix_select_small_kernel< + ValT, + OutT, + ARG_PARTITION, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + + int64_t in_stride_segment_axis = INT64_MAX; + int64_t out_stride_segment_axis = INT64_MAX; + for (size_t i = 0; i < nc_shape.size(); i++) { + if (nc_shape[i] == 1) { + continue; + } + if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) { + throw std::runtime_error( + "[Partition::eval_gpu] Stride too large."); + } + in_stride_segment_axis = + std::min(in_stride_segment_axis, in_nc_str[i]); + out_stride_segment_axis = + std::min(out_stride_segment_axis, out_nc_str[i]); + } + + encoder.add_kernel_node( + kernel, + grid, + block, + 0, + gpu_ptr(in), + gpu_ptr(out), + kth, + size_sorted_axis, + static_cast(in_stride_sorted_axis), + static_cast(out_stride_sorted_axis), + static_cast(in_stride_segment_axis), + static_cast(out_stride_segment_axis)); + } else { + auto kernel = cu::radix_select_small_nc_kernel< + ValT, + OutT, + ARG_PARTITION, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + + auto nc_shape_param = const_param(nc_shape); + auto in_nc_strides_param = const_param(in_nc_str); + auto out_nc_strides_param = const_param(out_nc_str); + + encoder.add_kernel_node( + kernel, + grid, + block, + 0, + gpu_ptr(in), + gpu_ptr(out), + kth, + size_sorted_axis, + static_cast(in_stride_sorted_axis), + static_cast(out_stride_sorted_axis), + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); + } + }); + } else { + throw std::runtime_error( + "CUDA backend does not support sorting complex numbers"); + } + }); +} + +void gpu_radix_partition_large( + const Stream& s, + const array& in, + array& out, + int axis, + int kth, + bool arg_partition) { + int n_rows = in.size() / in.shape(axis); + + int size_sorted_axis = in.shape(axis); + int64_t in_stride_sorted_axis = in.strides()[axis]; + int64_t out_stride_sorted_axis = out.strides()[axis]; 
+ + auto in_nc_str = in.strides(); + in_nc_str.erase(in_nc_str.begin() + axis); + + auto out_nc_str = out.strides(); + out_nc_str.erase(out_nc_str.begin() + axis); + + auto nc_shape = in.shape(); + nc_shape.erase(nc_shape.begin() + axis); + + int nc_dim = nc_shape.size(); + + bool contiguous = in.flags().contiguous; + auto check_strides = [](const array& x, int64_t sort_stride) { + int64_t min_stride = + *std::min_element(x.strides().begin(), x.strides().end()); + int64_t max_stride = + *std::max_element(x.strides().begin(), x.strides().end()); + return sort_stride == min_stride || sort_stride == max_stride; + }; + contiguous &= check_strides(in, in_stride_sorted_axis); + contiguous &= check_strides(out, out_stride_sorted_axis); + + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + + encoder.set_input_array(in); + encoder.set_output_array(out); + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + if constexpr (!std::is_same_v) { + using ValT = cuda_type_t; + + constexpr int BLOCK_THREADS = 256; + + dim3 grid(1, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_bool(arg_partition, [&](auto arg_tag) { + constexpr bool ARG_PARTITION = decltype(arg_tag)::value; + using OutT = std::conditional_t; + + if (contiguous) { + auto kernel = cu::radix_select_large_streaming_kernel< + ValT, + OutT, + ARG_PARTITION, + BLOCK_THREADS>; + + int64_t in_stride_segment_axis = INT64_MAX; + int64_t out_stride_segment_axis = INT64_MAX; + for (size_t i = 0; i < nc_shape.size(); i++) { + if (nc_shape[i] == 1) { + continue; + } + in_stride_segment_axis = + std::min(in_stride_segment_axis, in_nc_str[i]); + out_stride_segment_axis = + std::min(out_stride_segment_axis, out_nc_str[i]); + } + + encoder.add_kernel_node( + kernel, + grid, + block, + 0, + gpu_ptr(in), + gpu_ptr(out), + size_sorted_axis, + kth, + static_cast(in_stride_sorted_axis), + static_cast(out_stride_sorted_axis), + 
static_cast(in_stride_segment_axis), + static_cast(out_stride_segment_axis)); + } else { + auto kernel = cu::radix_select_large_streaming_nc_kernel< + ValT, + OutT, + ARG_PARTITION, + BLOCK_THREADS>; + + auto nc_shape_param = const_param(nc_shape); + auto in_nc_strides_param = const_param(in_nc_str); + auto out_nc_strides_param = const_param(out_nc_str); + + encoder.add_kernel_node( + kernel, + grid, + block, + 0, + gpu_ptr(in), + gpu_ptr(out), + size_sorted_axis, + kth, + static_cast(in_stride_sorted_axis), + static_cast(out_stride_sorted_axis), + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); + } + }); + } else { + throw std::runtime_error( + "CUDA backend does not support sorting complex numbers"); + } + }); +} + +void gpu_radix_partition( + const Stream& s, + const array& in, + array& out, + int axis_, + int kth_, + bool arg_partition) { + int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; + int size_sorted_axis = in.shape(axis); + int kth = kth_ < 0 ? kth_ + size_sorted_axis : kth_; + + // Dispatch based on size + if (size_sorted_axis <= 2048) { + return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); + } else { + return gpu_radix_partition_large(s, in, out, axis, kth, arg_partition); + } +} + void gpu_sort( const Stream& s, const array& in, @@ -1066,12 +1340,12 @@ void Sort::eval_gpu(const std::vector& inputs, array& out) { void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ArgPartition::eval_gpu"); - gpu_sort(stream(), inputs[0], out, axis_, true); + gpu_radix_partition(stream(), inputs[0], out, axis_, kth_, true); } void Partition::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Partition::eval_gpu"); - gpu_sort(stream(), inputs[0], out, axis_, false); + gpu_radix_partition(stream(), inputs[0], out, axis_, kth_, false); } } // namespace mlx::core \ No newline at end of file From 5d4dd875c18a789fcecbea1891908d67cdf0fbd8 Mon Sep 17 00:00:00 2001 From: 
Lyxot Date: Mon, 9 Feb 2026 01:47:42 +0800 Subject: [PATCH 02/23] update fallback strategy --- mlx/backend/cuda/sort.cu | 79 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 4df85df992..c61a4fceae 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -9,6 +9,7 @@ #include "mlx/backend/cuda/device/radix_select.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" +#include "mlx/dtype.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -1295,6 +1296,77 @@ void gpu_radix_partition_large( }); } +struct FallbackLinearModel { + int max_rows; + int axis_intercept; + int axis_slope; + int axis_min; +}; + +int axis_threshold(const FallbackLinearModel& model, int n_rows) { + return std::max(model.axis_min, model.axis_intercept + model.axis_slope * n_rows); +} + +bool should_use_merge_sort_fallback_model( + const FallbackLinearModel& model, + int n_rows, + int size_sorted_axis) { + if (n_rows <= 0 || n_rows > model.max_rows) { + return false; + } + return size_sorted_axis >= axis_threshold(model, n_rows); +} + +bool is_integer_dtype(Dtype dtype) { + return dtype == int8 || dtype == int16 || dtype == int32 || dtype == int64 || + dtype == uint8 || dtype == uint16 || dtype == uint32 || dtype == uint64; +} + +FallbackLinearModel float_fallback_model(int dtype_size) { + return { + 8, + 24576, + 16384 / dtype_size, + 102400 / dtype_size, + }; +} + +FallbackLinearModel integer_fallback_model(int dtype_size) { + return { + dtype_size == 8 ? 
12 : 6, + 53248 / dtype_size, + 16384 / dtype_size, + 8192 / dtype_size, + }; +} + +bool should_use_merge_sort_fallback( + Dtype dtype, + int n_rows, + int size_sorted_axis) +{ + int dtype_size = size_of(dtype); + + if (dtype == float32) { + // Use fallback model or for small axis always use merge sort + return should_use_merge_sort_fallback_model(float_fallback_model(dtype_size), n_rows, size_sorted_axis) || + size_sorted_axis <= 512 || (n_rows <= 48 && size_sorted_axis <= 2048); + } else if (dtype == bfloat16 || dtype == float16) { + // Use fallback model or when batch is large and axis is small, merge sort wins + return should_use_merge_sort_fallback_model(float_fallback_model(dtype_size), n_rows, size_sorted_axis) || + (n_rows >= 512 && size_sorted_axis <= 512); + } else if (dtype == float64) { + // float64 is not supported on the GPU + return true; + } else if (is_integer_dtype(dtype) || dtype == bool_) { + // Use fallback model for all integer types and bool + return should_use_merge_sort_fallback_model(integer_fallback_model(dtype_size), n_rows, size_sorted_axis); + } else { + // Fallback for unknown or unsupported types + return true; + } +} + void gpu_radix_partition( const Stream& s, const array& in, @@ -1304,8 +1376,13 @@ void gpu_radix_partition( bool arg_partition) { int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; int size_sorted_axis = in.shape(axis); + int n_rows = in.size() / size_sorted_axis; int kth = kth_ < 0 ? 
kth_ + size_sorted_axis : kth_; + if (should_use_merge_sort_fallback(in.dtype(), n_rows, size_sorted_axis)) { + return gpu_merge_sort(s, in, out, axis, arg_partition); + } + // Dispatch based on size if (size_sorted_axis <= 2048) { return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); @@ -1348,4 +1425,4 @@ void Partition::eval_gpu(const std::vector& inputs, array& out) { gpu_radix_partition(stream(), inputs[0], out, axis_, kth_, false); } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core From 7fbf045e42cb3bf7c30efc1e2114fee27d27826b Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 9 Feb 2026 02:07:59 +0800 Subject: [PATCH 03/23] format code --- benchmarks/python/benchmark_radix_select.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/benchmarks/python/benchmark_radix_select.py b/benchmarks/python/benchmark_radix_select.py index 078d484cb1..858d21fef1 100644 --- a/benchmarks/python/benchmark_radix_select.py +++ b/benchmarks/python/benchmark_radix_select.py @@ -140,17 +140,15 @@ def main(): for dtype, dtype_name in dtypes: print(f"\nDtype: {dtype_name}") - print( - f"{'Config':<25} {'ArgPartition':>14} {'ArgSort':>12} {'Speedup':>10}" - ) + print(f"{'Config':<25} {'ArgPartition':>14} {'ArgSort':>12} {'Speedup':>10}") print("-" * 80) for b, v, k in configs: try: - argpart_ms = benchmark_argpartition(b, v, k, dtype, warmup=3, iterations=50) - argsort_ms = benchmark_argsort( - b, v, dtype, warmup=3, iterations=50 + argpart_ms = benchmark_argpartition( + b, v, k, dtype, warmup=3, iterations=50 ) + argsort_ms = benchmark_argsort(b, v, dtype, warmup=3, iterations=50) speedup = argsort_ms / argpart_ms config_str = f"b={b}, v={v}, k={k}" print( From 7c264f48ed40b3e72e660e8dbdfd32beff72927f Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 9 Feb 2026 03:07:36 +0800 Subject: [PATCH 04/23] cuda: fix radix partition row indexing and 64-bit strides Fix two correctness issues in CUDA radix 
partition/argpartition: - In the large contiguous radix path, stop deriving row bases from `row * min(non-axis stride)` and compute row offsets with `elem_to_loc(...)` using non-axis shape/strides (matching merge-sort indexing behavior). - Keep stride arguments 64-bit end-to-end in radix-select kernels and launches (remove narrowing to `int` and related `INT32_MAX` guard). This fixes incorrect row addressing for valid contiguous non-linear layouts (e.g. column-major with axis=0) and avoids silent misindexing on large strides. --- mlx/backend/cuda/device/radix_select.cuh | 34 ++++++++++-------- mlx/backend/cuda/sort.cu | 44 +++++++++--------------- 2 files changed, 37 insertions(+), 41 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index c87e2e3b2c..1b7d1c393e 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -270,10 +270,10 @@ __global__ void radix_select_small_kernel( OutT* output, int kth, int n, - int in_stride, - int out_stride, - int in_segment_stride, - int out_segment_stride) { + int64_t in_stride, + int64_t out_stride, + int64_t in_segment_stride, + int64_t out_segment_stride) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; @@ -423,8 +423,8 @@ __global__ void radix_select_small_nc_kernel( OutT* output, int kth, int n, - int in_stride, - int out_stride, + int64_t in_stride, + int64_t out_stride, const __grid_constant__ Shape nc_shape, const __grid_constant__ Strides in_nc_strides, const __grid_constant__ Strides out_nc_strides, @@ -572,17 +572,23 @@ __global__ void radix_select_large_streaming_kernel( OutT* output, int n, int kth, - int in_stride, - int out_stride, - int in_segment_stride, - int out_segment_stride) { + int64_t in_stride, + int64_t out_stride, + const __grid_constant__ Shape nc_shape, + const __grid_constant__ Strides in_nc_strides, + const __grid_constant__ Strides out_nc_strides, + int nc_dim) { 
using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; int row = blockIdx.y; - const ValT* row_input = input + row * in_segment_stride; - OutT* row_output = output + row * out_segment_stride; + int64_t in_block_idx = + elem_to_loc(int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); + int64_t out_block_idx = elem_to_loc( + int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); + const ValT* row_input = input + in_block_idx; + OutT* row_output = output + out_block_idx; // Shared memory __shared__ int shared_hist[RADIX_SIZE]; @@ -783,8 +789,8 @@ __global__ void radix_select_large_streaming_nc_kernel( OutT* output, int n, int kth, - int in_stride, - int out_stride, + int64_t in_stride, + int64_t out_stride, const __grid_constant__ Shape nc_shape, const __grid_constant__ Strides in_nc_strides, const __grid_constant__ Strides out_nc_strides, diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index c61a4fceae..8b2be4bd74 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1115,10 +1115,6 @@ void gpu_radix_partition_small( if (nc_shape[i] == 1) { continue; } - if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) { - throw std::runtime_error( - "[Partition::eval_gpu] Stride too large."); - } in_stride_segment_axis = std::min(in_stride_segment_axis, in_nc_str[i]); out_stride_segment_axis = @@ -1134,10 +1130,10 @@ void gpu_radix_partition_small( gpu_ptr(out), kth, size_sorted_axis, - static_cast(in_stride_sorted_axis), - static_cast(out_stride_sorted_axis), - static_cast(in_stride_segment_axis), - static_cast(out_stride_segment_axis)); + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis); } else { auto kernel = cu::radix_select_small_nc_kernel< ValT, @@ -1159,8 +1155,8 @@ void gpu_radix_partition_small( gpu_ptr(out), kth, size_sorted_axis, - static_cast(in_stride_sorted_axis), 
- static_cast(out_stride_sorted_axis), + in_stride_sorted_axis, + out_stride_sorted_axis, nc_shape_param, in_nc_strides_param, out_nc_strides_param, @@ -1236,17 +1232,9 @@ void gpu_radix_partition_large( ARG_PARTITION, BLOCK_THREADS>; - int64_t in_stride_segment_axis = INT64_MAX; - int64_t out_stride_segment_axis = INT64_MAX; - for (size_t i = 0; i < nc_shape.size(); i++) { - if (nc_shape[i] == 1) { - continue; - } - in_stride_segment_axis = - std::min(in_stride_segment_axis, in_nc_str[i]); - out_stride_segment_axis = - std::min(out_stride_segment_axis, out_nc_str[i]); - } + auto nc_shape_param = const_param(nc_shape); + auto in_nc_strides_param = const_param(in_nc_str); + auto out_nc_strides_param = const_param(out_nc_str); encoder.add_kernel_node( kernel, @@ -1257,10 +1245,12 @@ void gpu_radix_partition_large( gpu_ptr(out), size_sorted_axis, kth, - static_cast(in_stride_sorted_axis), - static_cast(out_stride_sorted_axis), - static_cast(in_stride_segment_axis), - static_cast(out_stride_segment_axis)); + in_stride_sorted_axis, + out_stride_sorted_axis, + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); } else { auto kernel = cu::radix_select_large_streaming_nc_kernel< ValT, @@ -1281,8 +1271,8 @@ void gpu_radix_partition_large( gpu_ptr(out), size_sorted_axis, kth, - static_cast(in_stride_sorted_axis), - static_cast(out_stride_sorted_axis), + in_stride_sorted_axis, + out_stride_sorted_axis, nc_shape_param, in_nc_strides_param, out_nc_strides_param, From 1764781345cd4d3f8a7bf9f44cc7b24f61ca575a Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 9 Feb 2026 04:27:52 +0800 Subject: [PATCH 05/23] cuda: remove radix-partition rank cap and use dynamic NC metadata Eliminate MAX_NDIM-based rank limits in CUDA radix partition/argpartition by switching radix kernels from fixed-size __grid_constant__ shape/stride params to dynamic device pointers for non-axis metadata. 
Changes: - Update radix kernels to take dynamic NC metadata pointers: - radix_select_small_nc_kernel - radix_select_large_streaming_kernel - radix_select_large_streaming_nc_kernel - In gpu_radix_partition_small/gpu_radix_partition_large: - allocate device buffers for nc_shape/in_nc_strides/out_nc_strides - copy host metadata with cudaMemcpyAsync - pass pointers into kernel launches - Remove MAX_NDIM-dependent routing so high-rank tensors can still use radix partition path. - Keep stride handling 64-bit end-to-end in radix launches/kernels. Also slightly widens fallback-model threshold range (without changing max_rows). --- mlx/backend/cuda/device/radix_select.cuh | 71 +++++------- mlx/backend/cuda/sort.cu | 142 +++++++++++++++++------ 2 files changed, 138 insertions(+), 75 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 1b7d1c393e..74cc9e38c3 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -2,11 +2,11 @@ #pragma once -#include "mlx/backend/cuda/device/utils.cuh" #include #include #include #include +#include "mlx/backend/cuda/device/utils.cuh" namespace mlx::core::cu { @@ -223,10 +223,8 @@ struct RadixTraits { }; template -__device__ __forceinline__ int extract_digit( - UnsignedT val, - int start_bit, - int num_bits) { +__device__ __forceinline__ int +extract_digit(UnsignedT val, int start_bit, int num_bits) { return (val >> start_bit) & ((1 << num_bits) - 1); } @@ -375,8 +373,7 @@ __global__ void radix_select_small_kernel( if (ARG_PARTITION) { row_output[pos * out_stride] = shared_idxs[i]; } else { - row_output[pos * out_stride] = - row_input[shared_idxs[i] * in_stride]; + row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; } } } @@ -390,8 +387,7 @@ __global__ void radix_select_small_kernel( if (ARG_PARTITION) { row_output[pos * out_stride] = shared_idxs[i]; } else { - row_output[pos * out_stride] = - row_input[shared_idxs[i] 
* in_stride]; + row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; } } } @@ -405,8 +401,7 @@ __global__ void radix_select_small_kernel( if (ARG_PARTITION) { row_output[pos * out_stride] = shared_idxs[i]; } else { - row_output[pos * out_stride] = - row_input[shared_idxs[i] * in_stride]; + row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; } } } @@ -425,9 +420,9 @@ __global__ void radix_select_small_nc_kernel( int n, int64_t in_stride, int64_t out_stride, - const __grid_constant__ Shape nc_shape, - const __grid_constant__ Strides in_nc_strides, - const __grid_constant__ Strides out_nc_strides, + const int32_t* nc_shape, + const int64_t* in_nc_strides, + const int64_t* out_nc_strides, int nc_dim) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; @@ -442,9 +437,9 @@ __global__ void radix_select_small_nc_kernel( int row = blockIdx.y; int64_t in_block_idx = - elem_to_loc(int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); - int64_t out_block_idx = elem_to_loc( - int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); + elem_to_loc(int64_t(row), nc_shape, in_nc_strides, nc_dim); + int64_t out_block_idx = + elem_to_loc(int64_t(row), nc_shape, out_nc_strides, nc_dim); const ValT* row_input = input + in_block_idx; OutT* row_output = output + out_block_idx; @@ -562,11 +557,7 @@ __global__ void radix_select_small_nc_kernel( // Large array streaming kernel (multi-pass, in-place) /////////////////////////////////////////////////////////////////////////////// -template < - typename ValT, - typename OutT, - bool ARG_PARTITION, - int BLOCK_THREADS> +template __global__ void radix_select_large_streaming_kernel( const ValT* input, OutT* output, @@ -574,9 +565,9 @@ __global__ void radix_select_large_streaming_kernel( int kth, int64_t in_stride, int64_t out_stride, - const __grid_constant__ Shape nc_shape, - const __grid_constant__ Strides in_nc_strides, - const __grid_constant__ Strides out_nc_strides, + 
const int32_t* nc_shape, + const int64_t* in_nc_strides, + const int64_t* out_nc_strides, int nc_dim) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; @@ -584,9 +575,9 @@ __global__ void radix_select_large_streaming_kernel( int row = blockIdx.y; int64_t in_block_idx = - elem_to_loc(int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); - int64_t out_block_idx = elem_to_loc( - int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); + elem_to_loc(int64_t(row), nc_shape, in_nc_strides, nc_dim); + int64_t out_block_idx = + elem_to_loc(int64_t(row), nc_shape, out_nc_strides, nc_dim); const ValT* row_input = input + in_block_idx; OutT* row_output = output + out_block_idx; @@ -744,7 +735,8 @@ __global__ void radix_select_large_streaming_kernel( } else if (key == target_prefix) { pos = less_count + atomicAdd(&shared_output_counters[1], 1); } else { - pos = less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); + pos = + less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); } if (ARG_PARTITION) { @@ -767,7 +759,8 @@ __global__ void radix_select_large_streaming_kernel( } else if (key == target_prefix) { pos = less_count + atomicAdd(&shared_output_counters[1], 1); } else { - pos = less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); + pos = + less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); } if (ARG_PARTITION) { @@ -779,11 +772,7 @@ __global__ void radix_select_large_streaming_kernel( } } -template < - typename ValT, - typename OutT, - bool ARG_PARTITION, - int BLOCK_THREADS> +template __global__ void radix_select_large_streaming_nc_kernel( const ValT* input, OutT* output, @@ -791,9 +780,9 @@ __global__ void radix_select_large_streaming_nc_kernel( int kth, int64_t in_stride, int64_t out_stride, - const __grid_constant__ Shape nc_shape, - const __grid_constant__ Strides in_nc_strides, - const __grid_constant__ Strides out_nc_strides, + const int32_t* nc_shape, + const 
int64_t* in_nc_strides, + const int64_t* out_nc_strides, int nc_dim) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; @@ -801,9 +790,9 @@ __global__ void radix_select_large_streaming_nc_kernel( int row = blockIdx.y; int64_t in_block_idx = - elem_to_loc(int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); - int64_t out_block_idx = elem_to_loc( - int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); + elem_to_loc(int64_t(row), nc_shape, in_nc_strides, nc_dim); + int64_t out_block_idx = + elem_to_loc(int64_t(row), nc_shape, out_nc_strides, nc_dim); const ValT* row_input = input + in_block_idx; OutT* row_output = output + out_block_idx; diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 8b2be4bd74..ad648291d3 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1086,6 +1086,47 @@ void gpu_radix_partition_small( encoder.set_input_array(in); encoder.set_output_array(out); + const int32_t* nc_shape_ptr = nullptr; + const int64_t* in_nc_strides_ptr = nullptr; + const int64_t* out_nc_strides_ptr = nullptr; + if (!contiguous && nc_dim > 0) { + array nc_shape_dev({nc_dim}, int32, nullptr, {}); + array in_nc_strides_dev({nc_dim}, int64, nullptr, {}); + array out_nc_strides_dev({nc_dim}, int64, nullptr, {}); + nc_shape_dev.set_data(cu::malloc_async(nc_shape_dev.nbytes(), encoder)); + in_nc_strides_dev.set_data( + cu::malloc_async(in_nc_strides_dev.nbytes(), encoder)); + out_nc_strides_dev.set_data( + cu::malloc_async(out_nc_strides_dev.nbytes(), encoder)); + + CHECK_CUDA_ERROR(cudaMemcpyAsync( + gpu_ptr(nc_shape_dev), + nc_shape.data(), + nc_shape_dev.nbytes(), + cudaMemcpyHostToDevice, + encoder.stream())); + CHECK_CUDA_ERROR(cudaMemcpyAsync( + gpu_ptr(in_nc_strides_dev), + in_nc_str.data(), + in_nc_strides_dev.nbytes(), + cudaMemcpyHostToDevice, + encoder.stream())); + CHECK_CUDA_ERROR(cudaMemcpyAsync( + gpu_ptr(out_nc_strides_dev), + out_nc_str.data(), + out_nc_strides_dev.nbytes(), + 
cudaMemcpyHostToDevice, + encoder.stream())); + + nc_shape_ptr = gpu_ptr(nc_shape_dev); + in_nc_strides_ptr = gpu_ptr(in_nc_strides_dev); + out_nc_strides_ptr = gpu_ptr(out_nc_strides_dev); + + encoder.add_temporary(nc_shape_dev); + encoder.add_temporary(in_nc_strides_dev); + encoder.add_temporary(out_nc_strides_dev); + } + dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); if constexpr (!std::is_same_v) { @@ -1142,10 +1183,6 @@ void gpu_radix_partition_small( BLOCK_THREADS, ITEMS_PER_THREAD>; - auto nc_shape_param = const_param(nc_shape); - auto in_nc_strides_param = const_param(in_nc_str); - auto out_nc_strides_param = const_param(out_nc_str); - encoder.add_kernel_node( kernel, grid, @@ -1157,9 +1194,9 @@ void gpu_radix_partition_small( size_sorted_axis, in_stride_sorted_axis, out_stride_sorted_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, + nc_shape_ptr, + in_nc_strides_ptr, + out_nc_strides_ptr, nc_dim); } }); @@ -1211,6 +1248,47 @@ void gpu_radix_partition_large( encoder.set_input_array(in); encoder.set_output_array(out); + const int32_t* nc_shape_ptr = nullptr; + const int64_t* in_nc_strides_ptr = nullptr; + const int64_t* out_nc_strides_ptr = nullptr; + if (nc_dim > 0) { + array nc_shape_dev({nc_dim}, int32, nullptr, {}); + array in_nc_strides_dev({nc_dim}, int64, nullptr, {}); + array out_nc_strides_dev({nc_dim}, int64, nullptr, {}); + nc_shape_dev.set_data(cu::malloc_async(nc_shape_dev.nbytes(), encoder)); + in_nc_strides_dev.set_data( + cu::malloc_async(in_nc_strides_dev.nbytes(), encoder)); + out_nc_strides_dev.set_data( + cu::malloc_async(out_nc_strides_dev.nbytes(), encoder)); + + CHECK_CUDA_ERROR(cudaMemcpyAsync( + gpu_ptr(nc_shape_dev), + nc_shape.data(), + nc_shape_dev.nbytes(), + cudaMemcpyHostToDevice, + encoder.stream())); + CHECK_CUDA_ERROR(cudaMemcpyAsync( + gpu_ptr(in_nc_strides_dev), + in_nc_str.data(), + in_nc_strides_dev.nbytes(), + cudaMemcpyHostToDevice, + encoder.stream())); 
+ CHECK_CUDA_ERROR(cudaMemcpyAsync( + gpu_ptr(out_nc_strides_dev), + out_nc_str.data(), + out_nc_strides_dev.nbytes(), + cudaMemcpyHostToDevice, + encoder.stream())); + + nc_shape_ptr = gpu_ptr(nc_shape_dev); + in_nc_strides_ptr = gpu_ptr(in_nc_strides_dev); + out_nc_strides_ptr = gpu_ptr(out_nc_strides_dev); + + encoder.add_temporary(nc_shape_dev); + encoder.add_temporary(in_nc_strides_dev); + encoder.add_temporary(out_nc_strides_dev); + } + dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); if constexpr (!std::is_same_v) { @@ -1232,10 +1310,6 @@ void gpu_radix_partition_large( ARG_PARTITION, BLOCK_THREADS>; - auto nc_shape_param = const_param(nc_shape); - auto in_nc_strides_param = const_param(in_nc_str); - auto out_nc_strides_param = const_param(out_nc_str); - encoder.add_kernel_node( kernel, grid, @@ -1247,9 +1321,9 @@ void gpu_radix_partition_large( kth, in_stride_sorted_axis, out_stride_sorted_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, + nc_shape_ptr, + in_nc_strides_ptr, + out_nc_strides_ptr, nc_dim); } else { auto kernel = cu::radix_select_large_streaming_nc_kernel< @@ -1258,10 +1332,6 @@ void gpu_radix_partition_large( ARG_PARTITION, BLOCK_THREADS>; - auto nc_shape_param = const_param(nc_shape); - auto in_nc_strides_param = const_param(in_nc_str); - auto out_nc_strides_param = const_param(out_nc_str); - encoder.add_kernel_node( kernel, grid, @@ -1273,9 +1343,9 @@ void gpu_radix_partition_large( kth, in_stride_sorted_axis, out_stride_sorted_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, + nc_shape_ptr, + in_nc_strides_ptr, + out_nc_strides_ptr, nc_dim); } }); @@ -1294,7 +1364,8 @@ struct FallbackLinearModel { }; int axis_threshold(const FallbackLinearModel& model, int n_rows) { - return std::max(model.axis_min, model.axis_intercept + model.axis_slope * n_rows); + return std::max( + model.axis_min, model.axis_intercept + model.axis_slope * n_rows); } bool 
should_use_merge_sort_fallback_model( @@ -1315,42 +1386,45 @@ bool is_integer_dtype(Dtype dtype) { FallbackLinearModel float_fallback_model(int dtype_size) { return { 8, - 24576, + 22528, 16384 / dtype_size, - 102400 / dtype_size, + 94208 / dtype_size, }; } FallbackLinearModel integer_fallback_model(int dtype_size) { return { dtype_size == 8 ? 12 : 6, - 53248 / dtype_size, + 51200 / dtype_size, 16384 / dtype_size, - 8192 / dtype_size, + 6144 / dtype_size, }; } bool should_use_merge_sort_fallback( Dtype dtype, int n_rows, - int size_sorted_axis) -{ + int size_sorted_axis) { int dtype_size = size_of(dtype); if (dtype == float32) { // Use fallback model or for small axis always use merge sort - return should_use_merge_sort_fallback_model(float_fallback_model(dtype_size), n_rows, size_sorted_axis) || - size_sorted_axis <= 512 || (n_rows <= 48 && size_sorted_axis <= 2048); + return should_use_merge_sort_fallback_model( + float_fallback_model(dtype_size), n_rows, size_sorted_axis) || + size_sorted_axis <= 512 || (n_rows <= 64 && size_sorted_axis <= 4096); } else if (dtype == bfloat16 || dtype == float16) { - // Use fallback model or when batch is large and axis is small, merge sort wins - return should_use_merge_sort_fallback_model(float_fallback_model(dtype_size), n_rows, size_sorted_axis) || - (n_rows >= 512 && size_sorted_axis <= 512); + // Use fallback model or when batch is large and axis is small, merge sort + // wins + return should_use_merge_sort_fallback_model( + float_fallback_model(dtype_size), n_rows, size_sorted_axis) || + (n_rows >= 512 && size_sorted_axis <= 512); } else if (dtype == float64) { // float64 is not supported on the GPU return true; } else if (is_integer_dtype(dtype) || dtype == bool_) { // Use fallback model for all integer types and bool - return should_use_merge_sort_fallback_model(integer_fallback_model(dtype_size), n_rows, size_sorted_axis); + return should_use_merge_sort_fallback_model( + integer_fallback_model(dtype_size), n_rows, 
size_sorted_axis); } else { // Fallback for unknown or unsupported types return true; From 40d162f6ea6ce99ef90cce9c8904910931d2b434 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 9 Feb 2026 15:02:11 +0800 Subject: [PATCH 06/23] unify radix partition kernels to reduce code duplication --- mlx/backend/cuda/device/radix_select.cuh | 312 ++--------------------- mlx/backend/cuda/sort.cu | 103 +++----- 2 files changed, 54 insertions(+), 361 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 74cc9e38c3..d34c9c2f00 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -261,6 +261,7 @@ template < typename ValT, typename OutT, bool ARG_PARTITION, + bool USE_SIMPLE_STRIDE, int BLOCK_THREADS, int ITEMS_PER_THREAD> __global__ void radix_select_small_kernel( @@ -271,7 +272,11 @@ __global__ void radix_select_small_kernel( int64_t in_stride, int64_t out_stride, int64_t in_segment_stride, - int64_t out_segment_stride) { + int64_t out_segment_stride, + const int32_t* nc_shape, + const int64_t* in_nc_strides, + const int64_t* out_nc_strides, + int nc_dim) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; @@ -285,8 +290,21 @@ __global__ void radix_select_small_kernel( __shared__ int shared_count[2]; int row = blockIdx.y; - const ValT* row_input = input + row * in_segment_stride; - OutT* row_output = output + row * out_segment_stride; + + // Compute row pointers based on addressing mode + const ValT* row_input; + OutT* row_output; + if constexpr (USE_SIMPLE_STRIDE) { + row_input = input + row * in_segment_stride; + row_output = output + row * out_segment_stride; + } else { + int64_t in_block_idx = + elem_to_loc(int64_t(row), nc_shape, in_nc_strides, nc_dim); + int64_t out_block_idx = + elem_to_loc(int64_t(row), nc_shape, out_nc_strides, nc_dim); + row_input = input + in_block_idx; + row_output = output + out_block_idx; + } int tile_n = min(n, 
TILE_SIZE); @@ -407,152 +425,6 @@ __global__ void radix_select_small_kernel( } } -template < - typename ValT, - typename OutT, - bool ARG_PARTITION, - int BLOCK_THREADS, - int ITEMS_PER_THREAD> -__global__ void radix_select_small_nc_kernel( - const ValT* input, - OutT* output, - int kth, - int n, - int64_t in_stride, - int64_t out_stride, - const int32_t* nc_shape, - const int64_t* in_nc_strides, - const int64_t* out_nc_strides, - int nc_dim) { - using Traits = RadixTraits; - using UnsignedT = typename Traits::UnsignedT; - - constexpr int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; - constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; - - __shared__ UnsignedT shared_keys[TILE_SIZE]; - __shared__ uint32_t shared_idxs[TILE_SIZE]; - __shared__ int shared_hist[RADIX_SIZE]; - __shared__ int shared_count[2]; - - int row = blockIdx.y; - int64_t in_block_idx = - elem_to_loc(int64_t(row), nc_shape, in_nc_strides, nc_dim); - int64_t out_block_idx = - elem_to_loc(int64_t(row), nc_shape, out_nc_strides, nc_dim); - const ValT* row_input = input + in_block_idx; - OutT* row_output = output + out_block_idx; - - int tile_n = min(n, TILE_SIZE); - - for (int i = threadIdx.x; i < TILE_SIZE; i += BLOCK_THREADS) { - if (i < tile_n) { - ValT val = row_input[i * in_stride]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - shared_keys[i] = key; - shared_idxs[i] = i; - } else { - shared_keys[i] = ~UnsignedT(0); - shared_idxs[i] = i; - } - } - __syncthreads(); - - int k = kth + 1; - UnsignedT target_prefix = 0; - UnsignedT prefix_mask = 0; - - for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { - int start_bit = pass * RADIX_BITS; - - for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { - shared_hist[i] = 0; - } - __syncthreads(); - - for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { - UnsignedT key = shared_keys[i]; - if ((key & prefix_mask) == target_prefix) { - int digit = extract_digit(key, 
start_bit, RADIX_BITS); - atomicAdd(&shared_hist[digit], 1); - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - int cumsum = 0; - int target_bin = 0; - for (int bin = 0; bin < RADIX_SIZE; bin++) { - int count = shared_hist[bin]; - if (cumsum + count >= k) { - target_bin = bin; - k = k - cumsum; - break; - } - cumsum += count; - } - shared_count[0] = target_bin; - shared_count[1] = k; - } - __syncthreads(); - - int target_bin = shared_count[0]; - k = shared_count[1]; - - UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; - target_prefix |= UnsignedT(target_bin) << start_bit; - prefix_mask |= digit_mask; - - __syncthreads(); - } - - if (threadIdx.x == 0) { - shared_count[0] = 0; - } - __syncthreads(); - - for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { - UnsignedT key = shared_keys[i]; - if (key < target_prefix) { - int pos = atomicAdd(&shared_count[0], 1); - if (ARG_PARTITION) { - row_output[pos * out_stride] = shared_idxs[i]; - } else { - row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; - } - } - } - __syncthreads(); - - for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { - UnsignedT key = shared_keys[i]; - if (key == target_prefix) { - int pos = atomicAdd(&shared_count[0], 1); - if (ARG_PARTITION) { - row_output[pos * out_stride] = shared_idxs[i]; - } else { - row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; - } - } - } - __syncthreads(); - - for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { - UnsignedT key = shared_keys[i]; - if (key > target_prefix) { - int pos = atomicAdd(&shared_count[0], 1); - if (ARG_PARTITION) { - row_output[pos * out_stride] = shared_idxs[i]; - } else { - row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; - } - } - } -} - /////////////////////////////////////////////////////////////////////////////// // Large array streaming kernel (multi-pass, in-place) 
/////////////////////////////////////////////////////////////////////////////// @@ -772,146 +644,4 @@ __global__ void radix_select_large_streaming_kernel( } } -template -__global__ void radix_select_large_streaming_nc_kernel( - const ValT* input, - OutT* output, - int n, - int kth, - int64_t in_stride, - int64_t out_stride, - const int32_t* nc_shape, - const int64_t* in_nc_strides, - const int64_t* out_nc_strides, - int nc_dim) { - using Traits = RadixTraits; - using UnsignedT = typename Traits::UnsignedT; - constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; - - int row = blockIdx.y; - int64_t in_block_idx = - elem_to_loc(int64_t(row), nc_shape, in_nc_strides, nc_dim); - int64_t out_block_idx = - elem_to_loc(int64_t(row), nc_shape, out_nc_strides, nc_dim); - const ValT* row_input = input + in_block_idx; - OutT* row_output = output + out_block_idx; - - __shared__ int shared_hist[RADIX_SIZE]; - __shared__ int shared_pivot_info[2]; - __shared__ int shared_counts[2]; - __shared__ int shared_output_counters[3]; - - int k = kth + 1; - UnsignedT target_prefix = 0; - UnsignedT prefix_mask = 0; - - for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { - int start_bit = pass * RADIX_BITS; - - for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { - shared_hist[i] = 0; - } - __syncthreads(); - - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - if ((key & prefix_mask) == target_prefix) { - int digit = extract_digit(key, start_bit, RADIX_BITS); - atomicAdd(&shared_hist[digit], 1); - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - int cumsum = 0; - int target_bin = 0; - for (int bin = 0; bin < RADIX_SIZE; bin++) { - int count = shared_hist[bin]; - if (cumsum + count >= k) { - target_bin = bin; - k = k - cumsum; - break; - } - cumsum += count; - } - shared_pivot_info[0] = target_bin; - 
shared_pivot_info[1] = k; - } - __syncthreads(); - - int target_bin = shared_pivot_info[0]; - k = shared_pivot_info[1]; - - UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; - target_prefix |= UnsignedT(target_bin) << start_bit; - prefix_mask |= digit_mask; - - if (threadIdx.x == 0) { - shared_counts[0] = 0; - shared_counts[1] = 0; - } - __syncthreads(); - } - - int local_less = 0, local_equal = 0; - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - if (key < target_prefix) - local_less++; - else if (key == target_prefix) - local_equal++; - } - - local_less = warp_reduce_sum(local_less); - local_equal = warp_reduce_sum(local_equal); - - int lane = threadIdx.x % WARP_SIZE; - if (lane == 0) { - atomicAdd(&shared_counts[0], local_less); - atomicAdd(&shared_counts[1], local_equal); - } - __syncthreads(); - - int less_count = shared_counts[0]; - int equal_count = shared_counts[1]; - - if (threadIdx.x == 0) { - shared_output_counters[0] = 0; - shared_output_counters[1] = 0; - shared_output_counters[2] = 0; - } - __syncthreads(); - - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - - int pos; - if (key < target_prefix) { - pos = atomicAdd(&shared_output_counters[0], 1); - } else if (key == target_prefix) { - pos = less_count + atomicAdd(&shared_output_counters[1], 1); - } else { - pos = less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); - } - - if (ARG_PARTITION) { - row_output[pos * out_stride] = i; - } else { - row_output[pos * out_stride] = val; - } - } -} - } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index ad648291d3..7f0f8e60d7 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ 
-1142,16 +1142,9 @@ void gpu_radix_partition_small( constexpr bool ARG_PARTITION = decltype(arg_tag)::value; using OutT = std::conditional_t; + int64_t in_stride_segment_axis = INT64_MAX; + int64_t out_stride_segment_axis = INT64_MAX; if (contiguous) { - auto kernel = cu::radix_select_small_kernel< - ValT, - OutT, - ARG_PARTITION, - BLOCK_THREADS, - ITEMS_PER_THREAD>; - - int64_t in_stride_segment_axis = INT64_MAX; - int64_t out_stride_segment_axis = INT64_MAX; for (size_t i = 0; i < nc_shape.size(); i++) { if (nc_shape[i] == 1) { continue; @@ -1161,25 +1154,16 @@ void gpu_radix_partition_small( out_stride_segment_axis = std::min(out_stride_segment_axis, out_nc_str[i]); } + } - encoder.add_kernel_node( - kernel, - grid, - block, - 0, - gpu_ptr(in), - gpu_ptr(out), - kth, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis); - } else { - auto kernel = cu::radix_select_small_nc_kernel< + dispatch_bool(contiguous, [&](auto contiguous_tag) { + constexpr bool USE_SIMPLE_STRIDE = decltype(contiguous_tag)::value; + + auto kernel = cu::radix_select_small_kernel< ValT, OutT, ARG_PARTITION, + USE_SIMPLE_STRIDE, BLOCK_THREADS, ITEMS_PER_THREAD>; @@ -1194,11 +1178,13 @@ void gpu_radix_partition_small( size_sorted_axis, in_stride_sorted_axis, out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, nc_shape_ptr, in_nc_strides_ptr, out_nc_strides_ptr, nc_dim); - } + }); }); } else { throw std::runtime_error( @@ -1303,51 +1289,28 @@ void gpu_radix_partition_large( constexpr bool ARG_PARTITION = decltype(arg_tag)::value; using OutT = std::conditional_t; - if (contiguous) { - auto kernel = cu::radix_select_large_streaming_kernel< - ValT, - OutT, - ARG_PARTITION, - BLOCK_THREADS>; - - encoder.add_kernel_node( - kernel, - grid, - block, - 0, - gpu_ptr(in), - gpu_ptr(out), - size_sorted_axis, - kth, - in_stride_sorted_axis, - out_stride_sorted_axis, - nc_shape_ptr, - in_nc_strides_ptr, - 
out_nc_strides_ptr, - nc_dim); - } else { - auto kernel = cu::radix_select_large_streaming_nc_kernel< - ValT, - OutT, - ARG_PARTITION, - BLOCK_THREADS>; + // Large kernel always uses elem_to_loc addressing + auto kernel = cu::radix_select_large_streaming_kernel< + ValT, + OutT, + ARG_PARTITION, + BLOCK_THREADS>; - encoder.add_kernel_node( - kernel, - grid, - block, - 0, - gpu_ptr(in), - gpu_ptr(out), - size_sorted_axis, - kth, - in_stride_sorted_axis, - out_stride_sorted_axis, - nc_shape_ptr, - in_nc_strides_ptr, - out_nc_strides_ptr, - nc_dim); - } + encoder.add_kernel_node( + kernel, + grid, + block, + 0, + gpu_ptr(in), + gpu_ptr(out), + size_sorted_axis, + kth, + in_stride_sorted_axis, + out_stride_sorted_axis, + nc_shape_ptr, + in_nc_strides_ptr, + out_nc_strides_ptr, + nc_dim); }); } else { throw std::runtime_error( @@ -1489,4 +1452,4 @@ void Partition::eval_gpu(const std::vector& inputs, array& out) { gpu_radix_partition(stream(), inputs[0], out, axis_, kth_, false); } -} // namespace mlx::core +} // namespace mlx::core \ No newline at end of file From 35de134edd7fa6ff5ecc0c4cca39b0ec12d205b6 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 9 Feb 2026 15:27:48 +0800 Subject: [PATCH 07/23] implement dynamic shared memory for radix partition small kernel --- mlx/backend/cuda/device/radix_select.cuh | 22 +++++++++++++++++----- mlx/backend/cuda/sort.cu | 11 ++++++++++- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index d34c9c2f00..0084b803bc 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -257,6 +257,15 @@ __device__ __forceinline__ T warp_reduce_sum(T val) { // Single-pass Radix Select for small arrays (fits in shared memory) /////////////////////////////////////////////////////////////////////////////// +// Helper to calculate required shared memory size for small kernel +template +constexpr size_t 
radix_select_small_shared_mem_size() { + return TILE_SIZE * sizeof(UnsignedT) + // shared_keys + TILE_SIZE * sizeof(uint32_t) + // shared_idxs + RADIX_SIZE * sizeof(int) + // shared_hist + 2 * sizeof(int); // shared_count +} + template < typename ValT, typename OutT, @@ -283,11 +292,14 @@ __global__ void radix_select_small_kernel( constexpr int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; - // Shared memory - __shared__ UnsignedT shared_keys[TILE_SIZE]; - __shared__ uint32_t shared_idxs[TILE_SIZE]; - __shared__ int shared_hist[RADIX_SIZE]; - __shared__ int shared_count[2]; + // Dynamic shared memory layout + extern __shared__ char shared_mem[]; + + // Calculate offsets for different arrays in shared memory + UnsignedT* shared_keys = reinterpret_cast(shared_mem); + uint32_t* shared_idxs = reinterpret_cast(shared_keys + TILE_SIZE); + int* shared_hist = reinterpret_cast(shared_idxs + TILE_SIZE); + int* shared_count = shared_hist + RADIX_SIZE; int row = blockIdx.y; diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 7f0f8e60d7..e0aae76779 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1167,11 +1167,20 @@ void gpu_radix_partition_small( BLOCK_THREADS, ITEMS_PER_THREAD>; + // Calculate dynamic shared memory size + using UnsignedT = typename cu::RadixTraits::UnsignedT; + constexpr int TILE_SIZE_VAL = BLOCK_THREADS * ITEMS_PER_THREAD; + constexpr size_t shared_mem_bytes = + TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys + TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs + 256 * sizeof(int) + // shared_hist (RADIX_SIZE=256) + 2 * sizeof(int); // shared_count + encoder.add_kernel_node( kernel, grid, block, - 0, + shared_mem_bytes, gpu_ptr(in), gpu_ptr(out), kth, From 498860fb27098d8056a5623600ed5e944297dd8e Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 10 Feb 2026 00:43:01 +0800 Subject: [PATCH 08/23] introduce multi-block-per-row and 
multi-row-per-block remove fallback strategy --- mlx/backend/cuda/device/radix_select.cuh | 328 +++++++++++++++++++ mlx/backend/cuda/sort.cu | 391 ++++++++++++++++++----- 2 files changed, 638 insertions(+), 81 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 0084b803bc..ee24dbae51 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -241,6 +241,17 @@ __device__ __forceinline__ bool is_nan_value(T val) { } } +template +__device__ __forceinline__ typename RadixTraits::UnsignedT +radix_key_with_nan_last(ValT val) { + using UnsignedT = typename RadixTraits::UnsignedT; + UnsignedT key = RadixTraits::to_radix(val); + if (is_nan_value(val)) { + key = ~UnsignedT(0); + } + return key; +} + /////////////////////////////////////////////////////////////////////////////// // Warp-level utilities /////////////////////////////////////////////////////////////////////////////// @@ -656,4 +667,321 @@ __global__ void radix_select_large_streaming_kernel( } } +/////////////////////////////////////////////////////////////////////////////// +// Tiled large-array kernels +// +// These kernels run with a 2D launch: +// - x-dimension tiles one row across multiple blocks (multi-block-per-row) +// - y-dimension packs multiple rows into one block group (multi-row-per-block) +/////////////////////////////////////////////////////////////////////////////// + +template +__global__ void radix_select_tiled_init_state_kernel( + UnsignedT* target_prefix, + UnsignedT* prefix_mask, + int* k_values, + int* row_hist, + int kth, + int n_rows, + int rows_per_block) { + int row_start = blockIdx.y * rows_per_block; + int row_end = min(n_rows, row_start + rows_per_block); + for (int row = row_start; row < row_end; ++row) { + if (threadIdx.x == 0) { + target_prefix[row] = UnsignedT(0); + prefix_mask[row] = UnsignedT(0); + k_values[row] = kth + 1; + } + int* hist = row_hist + row * RADIX_SIZE; + 
for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) { + hist[i] = 0; + } + } +} + +template +__global__ void radix_select_tiled_histogram_kernel( + const ValT* input, + int n, + int64_t in_stride, + int64_t in_segment_stride, + const typename RadixTraits::UnsignedT* target_prefix, + const typename RadixTraits::UnsignedT* prefix_mask, + int start_bit, + int blocks_per_row, + int n_rows, + int rows_per_block, + int* row_hist) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + int block_in_row = blockIdx.x; + int row_start = blockIdx.y * rows_per_block; + int row_end = min(n_rows, row_start + rows_per_block); + + int chunk = (n + blocks_per_row - 1) / blocks_per_row; + int start = block_in_row * chunk; + int end = min(n, start + chunk); + if (start >= n || row_start >= row_end) { + return; + } + + __shared__ int shared_hist[RADIX_SIZE]; + for (int row = row_start; row < row_end; ++row) { + const ValT* row_input = input + row * in_segment_stride; + UnsignedT row_prefix = target_prefix[row]; + UnsignedT row_mask = prefix_mask[row]; + + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + __syncthreads(); + + for (int i = start + threadIdx.x; i < end; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = radix_key_with_nan_last(val); + if ((key & row_mask) == row_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomicAdd(&shared_hist[digit], 1); + } + } + __syncthreads(); + + int* hist = row_hist + row * RADIX_SIZE; + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + atomicAdd(&hist[i], shared_hist[i]); + } + __syncthreads(); + } +} + +template +__global__ void radix_select_tiled_select_bin_kernel( + int* row_hist, + UnsignedT* target_prefix, + UnsignedT* prefix_mask, + int* k_values, + int clear_hist_for_next_pass, + int start_bit, + int n_rows, + int rows_per_block) { + int row_start = blockIdx.y * rows_per_block; + int row_end = 
min(n_rows, row_start + rows_per_block); + for (int row = row_start; row < row_end; ++row) { + int* hist = row_hist + row * RADIX_SIZE; + + if (threadIdx.x == 0) { + int k = k_values[row]; + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k -= cumsum; + break; + } + cumsum += count; + } + k_values[row] = k; + + UnsignedT digit_mask = + (UnsignedT((UnsignedT(1) << RADIX_BITS) - UnsignedT(1)) << start_bit); + target_prefix[row] |= UnsignedT(target_bin) << start_bit; + prefix_mask[row] |= digit_mask; + } + __syncthreads(); + + if (clear_hist_for_next_pass) { + for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) { + hist[i] = 0; + } + } + __syncthreads(); + } +} + +template +__global__ void radix_select_tiled_count_kernel( + const ValT* input, + int n, + int64_t in_stride, + int64_t in_segment_stride, + const typename RadixTraits::UnsignedT* target_prefix, + int blocks_per_row, + int n_rows, + int rows_per_block, + int* block_less, + int* block_equal) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + int block_in_row = blockIdx.x; + int row_start = blockIdx.y * rows_per_block; + int row_end = min(n_rows, row_start + rows_per_block); + + int chunk = (n + blocks_per_row - 1) / blocks_per_row; + int start = block_in_row * chunk; + int end = min(n, start + chunk); + if (row_start >= row_end) { + return; + } + + __shared__ int shared_counts[2]; + for (int row = row_start; row < row_end; ++row) { + int block_idx = row * blocks_per_row + block_in_row; + const ValT* row_input = input + row * in_segment_stride; + UnsignedT row_prefix = target_prefix[row]; + + int local_less = 0; + int local_equal = 0; + for (int i = start + threadIdx.x; i < end; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = radix_key_with_nan_last(val); + if (key < row_prefix) { + local_less++; + } else if (key == row_prefix) { 
+ local_equal++; + } + } + + local_less = warp_reduce_sum(local_less); + local_equal = warp_reduce_sum(local_equal); + + if (threadIdx.x == 0) { + shared_counts[0] = 0; + shared_counts[1] = 0; + } + __syncthreads(); + + if ((threadIdx.x % WARP_SIZE) == 0) { + atomicAdd(&shared_counts[0], local_less); + atomicAdd(&shared_counts[1], local_equal); + } + __syncthreads(); + + if (threadIdx.x == 0) { + block_less[block_idx] = shared_counts[0]; + block_equal[block_idx] = shared_counts[1]; + } + __syncthreads(); + } +} + +__global__ void radix_select_tiled_prefix_kernel( + int n, + int blocks_per_row, + int n_rows, + int rows_per_block, + const int* block_less, + const int* block_equal, + int* less_base, + int* equal_base, + int* greater_base) { + if (threadIdx.x != 0) { + return; + } + + int row_start = blockIdx.y * rows_per_block; + int row_end = min(n_rows, row_start + rows_per_block); + int chunk = (n + blocks_per_row - 1) / blocks_per_row; + + for (int row = row_start; row < row_end; ++row) { + int row_off = row * blocks_per_row; + int total_less = 0; + int total_equal = 0; + for (int b = 0; b < blocks_per_row; b++) { + int idx = row_off + b; + total_less += block_less[idx]; + total_equal += block_equal[idx]; + } + + int run_less = 0; + int run_equal = 0; + int run_greater = 0; + for (int b = 0; b < blocks_per_row; b++) { + int idx = row_off + b; + less_base[idx] = run_less; + equal_base[idx] = total_less + run_equal; + greater_base[idx] = total_less + total_equal + run_greater; + + int start = b * chunk; + int end = min(n, start + chunk); + int chunk_size = max(0, end - start); + int greater_count = chunk_size - block_less[idx] - block_equal[idx]; + + run_less += block_less[idx]; + run_equal += block_equal[idx]; + run_greater += greater_count; + } + } +} + +template +__global__ void radix_select_tiled_scatter_kernel( + const ValT* input, + OutT* output, + int n, + int64_t in_stride, + int64_t out_stride, + int64_t in_segment_stride, + int64_t out_segment_stride, + 
const typename RadixTraits::UnsignedT* target_prefix, + int blocks_per_row, + int n_rows, + int rows_per_block, + const int* less_base, + const int* equal_base, + const int* greater_base) { + using Traits = RadixTraits; + using UnsignedT = typename Traits::UnsignedT; + + int block_in_row = blockIdx.x; + int row_start = blockIdx.y * rows_per_block; + int row_end = min(n_rows, row_start + rows_per_block); + + int chunk = (n + blocks_per_row - 1) / blocks_per_row; + int start = block_in_row * chunk; + int end = min(n, start + chunk); + if (start >= n || row_start >= row_end) { + return; + } + + __shared__ int shared_out[3]; + for (int row = row_start; row < row_end; ++row) { + int block_idx = row * blocks_per_row + block_in_row; + const ValT* row_input = input + row * in_segment_stride; + OutT* row_output = output + row * out_segment_stride; + UnsignedT row_prefix = target_prefix[row]; + + if (threadIdx.x == 0) { + shared_out[0] = 0; + shared_out[1] = 0; + shared_out[2] = 0; + } + __syncthreads(); + + for (int i = start + threadIdx.x; i < end; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = radix_key_with_nan_last(val); + + int pos; + if (key < row_prefix) { + pos = less_base[block_idx] + atomicAdd(&shared_out[0], 1); + } else if (key == row_prefix) { + pos = equal_base[block_idx] + atomicAdd(&shared_out[1], 1); + } else { + pos = greater_base[block_idx] + atomicAdd(&shared_out[2], 1); + } + + if (ARG_PARTITION) { + row_output[pos * out_stride] = static_cast(i); + } else { + row_output[pos * out_stride] = static_cast(val); + } + } + __syncthreads(); + } +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index e0aae76779..f0f774e492 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1202,6 +1202,292 @@ void gpu_radix_partition_small( }); } +int64_t segment_stride_for_contiguous( + const Shape& shape_no_axis, + const Strides& strides_no_axis) { + int64_t stride = 
INT64_MAX; + for (size_t i = 0; i < shape_no_axis.size(); ++i) { + if (shape_no_axis[i] == 1) { + continue; + } + stride = std::min(stride, strides_no_axis[i]); + } + return (stride == INT64_MAX) ? int64_t(0) : stride; +} + +struct RadixLaunchPlan { + int blocks_per_row{1}; + int rows_per_block{1}; + + bool uses_tiled_launch() const { + return blocks_per_row > 1; + } +}; + +RadixLaunchPlan make_radix_tiled_launch_plan( + const Stream& s, + int n_rows, + int size_sorted_axis) { + if (n_rows <= 0 || size_sorted_axis <= 0 || size_sorted_axis < 8192) { + return {}; + } + + constexpr int kBlocksPerSmTarget = 4; + constexpr int kMinElemsPerBlock = 1024; + constexpr int kMaxBlocksPerRow = 32; + constexpr int kMaxRowsPerBlock = 4; + + int sm_count = 0; + CHECK_CUDA_ERROR(cudaDeviceGetAttribute( + &sm_count, cudaDevAttrMultiProcessorCount, s.device.index)); + sm_count = std::max(sm_count, 1); + + int target_blocks = std::max(1, sm_count * kBlocksPerSmTarget); + int needed_blocks_per_row = + std::max(1, (target_blocks + n_rows - 1) / n_rows); + + int max_blocks_by_work = std::max( + 1, std::min(kMaxBlocksPerRow, size_sorted_axis / kMinElemsPerBlock)); + int blocks_per_row = std::min(needed_blocks_per_row, max_blocks_by_work); + + if (blocks_per_row <= 1) { + return {}; + } + + int total_blocks = n_rows * blocks_per_row; + int chunk = (size_sorted_axis + blocks_per_row - 1) / blocks_per_row; + + int rows_per_block = 1; + if (total_blocks > target_blocks && chunk <= 2 * kMinElemsPerBlock) { + rows_per_block = std::min( + {kMaxRowsPerBlock, + n_rows, + std::max(1, (total_blocks + target_blocks - 1) / target_blocks)}); + } + return {blocks_per_row, rows_per_block}; +} + +Dtype unsigned_dtype_for_size(int size) { + switch (size) { + case 1: + return uint8; + case 2: + return uint16; + case 4: + return uint32; + case 8: + return uint64; + default: + throw std::runtime_error("Unsupported radix key size"); + } +} + +void gpu_radix_partition_large_tiled( + const Stream& s, + const 
array& in, + array& out, + int kth, + int n_rows, + int size_sorted_axis, + int64_t in_stride_sorted_axis, + int64_t out_stride_sorted_axis, + int64_t in_stride_segment_axis, + int64_t out_stride_segment_axis, + int blocks_per_row, + int rows_per_block, + bool arg_partition) { + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + encoder.set_input_array(in); + encoder.set_output_array(out); + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + if constexpr (!std::is_same_v) { + using ValT = cuda_type_t; + constexpr int BLOCK_THREADS = 256; + constexpr int NUM_PASSES = + (cu::RadixTraits::BITS + cu::RADIX_BITS - 1) / cu::RADIX_BITS; + using UnsignedT = typename cu::RadixTraits::UnsignedT; + + Dtype unsigned_dtype = unsigned_dtype_for_size(sizeof(UnsignedT)); + + array target_prefix_dev({n_rows}, unsigned_dtype, nullptr, {}); + array prefix_mask_dev({n_rows}, unsigned_dtype, nullptr, {}); + array k_values_dev({n_rows}, int32, nullptr, {}); + array row_hist_dev({n_rows, cu::RADIX_SIZE}, int32, nullptr, {}); + + int total_blocks = n_rows * blocks_per_row; + array block_less_dev({total_blocks}, int32, nullptr, {}); + array block_equal_dev({total_blocks}, int32, nullptr, {}); + array less_base_dev({total_blocks}, int32, nullptr, {}); + array equal_base_dev({total_blocks}, int32, nullptr, {}); + array greater_base_dev({total_blocks}, int32, nullptr, {}); + + auto allocate_temporary = [&](array& a) { + a.set_data(cu::malloc_async(a.nbytes(), encoder)); + encoder.add_temporary(a); + }; + allocate_temporary(target_prefix_dev); + allocate_temporary(prefix_mask_dev); + allocate_temporary(k_values_dev); + allocate_temporary(row_hist_dev); + allocate_temporary(block_less_dev); + allocate_temporary(block_equal_dev); + allocate_temporary(less_base_dev); + allocate_temporary(equal_base_dev); + allocate_temporary(greater_base_dev); + + int row_groups = (n_rows + rows_per_block - 1) / 
rows_per_block; + dim3 row_grid(1, row_groups, 1); + dim3 grid(blocks_per_row, row_groups, 1); + + encoder.set_output_array(target_prefix_dev); + encoder.set_output_array(prefix_mask_dev); + encoder.set_output_array(k_values_dev); + encoder.set_output_array(row_hist_dev); + encoder.add_kernel_node( + cu::radix_select_tiled_init_state_kernel, + row_grid, + dim3(32, 1, 1), + 0, + gpu_ptr(target_prefix_dev), + gpu_ptr(prefix_mask_dev), + gpu_ptr(k_values_dev), + gpu_ptr(row_hist_dev), + kth, + n_rows, + rows_per_block); + + for (int pass = NUM_PASSES - 1; pass >= 0; --pass) { + int start_bit = pass * cu::RADIX_BITS; + + encoder.set_input_array(in); + encoder.set_input_array(row_hist_dev); + encoder.set_input_array(target_prefix_dev); + encoder.set_input_array(prefix_mask_dev); + encoder.set_output_array(row_hist_dev); + encoder.add_kernel_node( + cu::radix_select_tiled_histogram_kernel, + grid, + dim3(BLOCK_THREADS, 1, 1), + 0, + gpu_ptr(in), + size_sorted_axis, + in_stride_sorted_axis, + in_stride_segment_axis, + gpu_ptr(target_prefix_dev), + gpu_ptr(prefix_mask_dev), + start_bit, + blocks_per_row, + n_rows, + rows_per_block, + gpu_ptr(row_hist_dev)); + + encoder.set_input_array(row_hist_dev); + encoder.set_input_array(target_prefix_dev); + encoder.set_input_array(prefix_mask_dev); + encoder.set_input_array(k_values_dev); + encoder.set_input_array(row_hist_dev); + encoder.set_output_array(target_prefix_dev); + encoder.set_output_array(prefix_mask_dev); + encoder.set_output_array(k_values_dev); + encoder.set_output_array(row_hist_dev); + encoder.add_kernel_node( + cu::radix_select_tiled_select_bin_kernel, + row_grid, + dim3(32, 1, 1), + 0, + gpu_ptr(row_hist_dev), + gpu_ptr(target_prefix_dev), + gpu_ptr(prefix_mask_dev), + gpu_ptr(k_values_dev), + pass > 0 ? 
1 : 0, + start_bit, + n_rows, + rows_per_block); + } + + encoder.set_input_array(in); + encoder.set_input_array(target_prefix_dev); + encoder.set_output_array(block_less_dev); + encoder.set_output_array(block_equal_dev); + encoder.add_kernel_node( + cu::radix_select_tiled_count_kernel, + grid, + dim3(BLOCK_THREADS, 1, 1), + 0, + gpu_ptr(in), + size_sorted_axis, + in_stride_sorted_axis, + in_stride_segment_axis, + gpu_ptr(target_prefix_dev), + blocks_per_row, + n_rows, + rows_per_block, + gpu_ptr(block_less_dev), + gpu_ptr(block_equal_dev)); + + encoder.set_input_array(block_less_dev); + encoder.set_input_array(block_equal_dev); + encoder.set_output_array(less_base_dev); + encoder.set_output_array(equal_base_dev); + encoder.set_output_array(greater_base_dev); + encoder.add_kernel_node( + cu::radix_select_tiled_prefix_kernel, + row_grid, + dim3(1, 1, 1), + 0, + size_sorted_axis, + blocks_per_row, + n_rows, + rows_per_block, + gpu_ptr(block_less_dev), + gpu_ptr(block_equal_dev), + gpu_ptr(less_base_dev), + gpu_ptr(equal_base_dev), + gpu_ptr(greater_base_dev)); + + dispatch_bool(arg_partition, [&](auto arg_tag) { + constexpr bool ARG_PARTITION = decltype(arg_tag)::value; + using OutT = std::conditional_t; + encoder.set_input_array(in); + encoder.set_input_array(target_prefix_dev); + encoder.set_input_array(less_base_dev); + encoder.set_input_array(equal_base_dev); + encoder.set_input_array(greater_base_dev); + encoder.set_output_array(out); + encoder.add_kernel_node( + cu::radix_select_tiled_scatter_kernel< + ValT, + OutT, + ARG_PARTITION, + BLOCK_THREADS>, + grid, + dim3(BLOCK_THREADS, 1, 1), + 0, + gpu_ptr(in), + gpu_ptr(out), + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + gpu_ptr(target_prefix_dev), + blocks_per_row, + n_rows, + rows_per_block, + gpu_ptr(less_base_dev), + gpu_ptr(equal_base_dev), + gpu_ptr(greater_base_dev)); + }); + } else { + throw std::runtime_error( + "CUDA backend 
does not support sorting complex numbers"); + } + }); +} + void gpu_radix_partition_large( const Stream& s, const array& in, @@ -1237,9 +1523,32 @@ void gpu_radix_partition_large( contiguous &= check_strides(in, in_stride_sorted_axis); contiguous &= check_strides(out, out_stride_sorted_axis); + if (contiguous) { + const auto plan = make_radix_tiled_launch_plan(s, n_rows, size_sorted_axis); + if (plan.uses_tiled_launch()) { + const int64_t in_stride_segment_axis = + segment_stride_for_contiguous(nc_shape, in_nc_str); + const int64_t out_stride_segment_axis = + segment_stride_for_contiguous(nc_shape, out_nc_str); + return gpu_radix_partition_large_tiled( + s, + in, + out, + kth, + n_rows, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + plan.blocks_per_row, + plan.rows_per_block, + arg_partition); + } + } + auto& encoder = cu::get_command_encoder(s); out.set_data(cu::malloc_async(out.nbytes(), encoder)); - encoder.set_input_array(in); encoder.set_output_array(out); @@ -1328,81 +1637,6 @@ void gpu_radix_partition_large( }); } -struct FallbackLinearModel { - int max_rows; - int axis_intercept; - int axis_slope; - int axis_min; -}; - -int axis_threshold(const FallbackLinearModel& model, int n_rows) { - return std::max( - model.axis_min, model.axis_intercept + model.axis_slope * n_rows); -} - -bool should_use_merge_sort_fallback_model( - const FallbackLinearModel& model, - int n_rows, - int size_sorted_axis) { - if (n_rows <= 0 || n_rows > model.max_rows) { - return false; - } - return size_sorted_axis >= axis_threshold(model, n_rows); -} - -bool is_integer_dtype(Dtype dtype) { - return dtype == int8 || dtype == int16 || dtype == int32 || dtype == int64 || - dtype == uint8 || dtype == uint16 || dtype == uint32 || dtype == uint64; -} - -FallbackLinearModel float_fallback_model(int dtype_size) { - return { - 8, - 22528, - 16384 / dtype_size, - 94208 / dtype_size, - }; -} - -FallbackLinearModel 
integer_fallback_model(int dtype_size) { - return { - dtype_size == 8 ? 12 : 6, - 51200 / dtype_size, - 16384 / dtype_size, - 6144 / dtype_size, - }; -} - -bool should_use_merge_sort_fallback( - Dtype dtype, - int n_rows, - int size_sorted_axis) { - int dtype_size = size_of(dtype); - - if (dtype == float32) { - // Use fallback model or for small axis always use merge sort - return should_use_merge_sort_fallback_model( - float_fallback_model(dtype_size), n_rows, size_sorted_axis) || - size_sorted_axis <= 512 || (n_rows <= 64 && size_sorted_axis <= 4096); - } else if (dtype == bfloat16 || dtype == float16) { - // Use fallback model or when batch is large and axis is small, merge sort - // wins - return should_use_merge_sort_fallback_model( - float_fallback_model(dtype_size), n_rows, size_sorted_axis) || - (n_rows >= 512 && size_sorted_axis <= 512); - } else if (dtype == float64) { - // float64 is not supported on the GPU - return true; - } else if (is_integer_dtype(dtype) || dtype == bool_) { - // Use fallback model for all integer types and bool - return should_use_merge_sort_fallback_model( - integer_fallback_model(dtype_size), n_rows, size_sorted_axis); - } else { - // Fallback for unknown or unsupported types - return true; - } -} - void gpu_radix_partition( const Stream& s, const array& in, @@ -1412,13 +1646,8 @@ void gpu_radix_partition( bool arg_partition) { int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; int size_sorted_axis = in.shape(axis); - int n_rows = in.size() / size_sorted_axis; int kth = kth_ < 0 ? 
kth_ + size_sorted_axis : kth_; - if (should_use_merge_sort_fallback(in.dtype(), n_rows, size_sorted_axis)) { - return gpu_merge_sort(s, in, out, axis, arg_partition); - } - // Dispatch based on size if (size_sorted_axis <= 2048) { return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); From 159e8c151383b61de3ab298743886c377d645fe5 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 10 Feb 2026 02:32:36 +0800 Subject: [PATCH 09/23] fix: make radix select tie-order deterministic --- mlx/backend/cuda/device/radix_select.cuh | 373 ++++++++++++++++------- 1 file changed, 270 insertions(+), 103 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index ee24dbae51..9a1c7a2f46 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -264,6 +264,53 @@ __device__ __forceinline__ T warp_reduce_sum(T val) { return val; } +template +__device__ __forceinline__ int block_exclusive_scan( + int val, + int* shared_warp_sums, + int* block_total = nullptr) { + static_assert(BLOCK_THREADS % WARP_SIZE == 0); + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + + int lane = threadIdx.x & (WARP_SIZE - 1); + int warp = threadIdx.x / WARP_SIZE; + + int inclusive = val; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int n = __shfl_up_sync(0xFFFFFFFF, inclusive, offset); + if (lane >= offset) { + inclusive += n; + } + } + + if (lane == WARP_SIZE - 1) { + shared_warp_sums[warp] = inclusive; + } + __syncthreads(); + + if (warp == 0) { + int warp_scan = (lane < NUM_WARPS) ? 
shared_warp_sums[lane] : 0; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int n = __shfl_up_sync(0xFFFFFFFF, warp_scan, offset); + if (lane >= offset) { + warp_scan += n; + } + } + + if (lane < NUM_WARPS) { + shared_warp_sums[lane] = warp_scan - shared_warp_sums[lane]; + } + if (block_total != nullptr && lane == NUM_WARPS - 1) { + *block_total = warp_scan; + } + } + __syncthreads(); + + return shared_warp_sums[warp] + inclusive - val; +} + /////////////////////////////////////////////////////////////////////////////// // Single-pass Radix Select for small arrays (fits in shared memory) /////////////////////////////////////////////////////////////////////////////// @@ -400,50 +447,48 @@ __global__ void radix_select_small_kernel( __syncthreads(); } - // Output partitioned array - if (threadIdx.x == 0) { - shared_count[0] = 0; - } - __syncthreads(); - - // Phase 1: output elements less than pivot + // Count per-thread bucket sizes once, then scatter in a single pass with + // deterministic per-thread offsets. 
+ int local_less = 0; + int local_equal = 0; for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { UnsignedT key = shared_keys[i]; if (key < target_prefix) { - int pos = atomicAdd(&shared_count[0], 1); - if (ARG_PARTITION) { - row_output[pos * out_stride] = shared_idxs[i]; - } else { - row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; - } + local_less++; + } else if (key == target_prefix) { + local_equal++; } } - __syncthreads(); - // Phase 2: output elements equal to pivot + int less_thread_offset = block_exclusive_scan( + local_less, shared_hist, &shared_count[0]); + int equal_thread_offset = block_exclusive_scan( + local_equal, shared_hist, &shared_count[1]); + + int q = tile_n / BLOCK_THREADS; + int r = tile_n - q * BLOCK_THREADS; + int prefix_total = int(threadIdx.x) * q + min(int(threadIdx.x), r); + int greater_thread_offset = + prefix_total - less_thread_offset - equal_thread_offset; + + int less_count = shared_count[0]; + int equal_count = shared_count[1]; + for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { UnsignedT key = shared_keys[i]; - if (key == target_prefix) { - int pos = atomicAdd(&shared_count[0], 1); - if (ARG_PARTITION) { - row_output[pos * out_stride] = shared_idxs[i]; - } else { - row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; - } + int pos; + if (key < target_prefix) { + pos = less_thread_offset++; + } else if (key == target_prefix) { + pos = less_count + equal_thread_offset++; + } else { + pos = less_count + equal_count + greater_thread_offset++; } - } - __syncthreads(); - // Phase 3: output elements greater than pivot - for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { - UnsignedT key = shared_keys[i]; - if (key > target_prefix) { - int pos = atomicAdd(&shared_count[0], 1); - if (ARG_PARTITION) { - row_output[pos * out_stride] = shared_idxs[i]; - } else { - row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; - } + if (ARG_PARTITION) { + row_output[pos * 
out_stride] = shared_idxs[i]; + } else { + row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; } } } @@ -480,7 +525,6 @@ __global__ void radix_select_large_streaming_kernel( __shared__ int shared_hist[RADIX_SIZE]; __shared__ int shared_pivot_info[2]; __shared__ int shared_counts[2]; - __shared__ int shared_output_counters[3]; int k = kth + 1; UnsignedT target_prefix = 0; @@ -551,16 +595,9 @@ __global__ void radix_select_large_streaming_kernel( UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; target_prefix |= UnsignedT(target_bin) << start_bit; prefix_mask |= digit_mask; - - // Initialize counters for next phase - if (threadIdx.x == 0) { - shared_counts[0] = 0; - shared_counts[1] = 0; - } - __syncthreads(); } - // Count partition sizes with warp reduction + // Count partition sizes. int local_less = 0, local_equal = 0; bool is_contiguous = (in_stride == 1); @@ -590,80 +627,136 @@ __global__ void radix_select_large_streaming_kernel( } } - // Warp reduction local_less = warp_reduce_sum(local_less); local_equal = warp_reduce_sum(local_equal); - // First lane of each warp aggregates to shared memory - int lane = threadIdx.x % WARP_SIZE; - if (lane == 0) { + if (threadIdx.x == 0) { + shared_counts[0] = 0; + shared_counts[1] = 0; + } + __syncthreads(); + + if ((threadIdx.x & (WARP_SIZE - 1)) == 0) { atomicAdd(&shared_counts[0], local_less); atomicAdd(&shared_counts[1], local_equal); } __syncthreads(); - // Read final counts int less_count = shared_counts[0]; int equal_count = shared_counts[1]; - // Initialize output counters + // Deterministic scatter in iteration order (0..n): this keeps output stable + // without thread-contention atomics in the hot scatter path. 
+ constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + int lane = threadIdx.x & (WARP_SIZE - 1); + int warp = threadIdx.x / WARP_SIZE; + + int* warp_less = shared_hist; + int* warp_equal = shared_hist + NUM_WARPS; + int* warp_greater = shared_hist + 2 * NUM_WARPS; + int* iter_counts = shared_hist + 3 * NUM_WARPS; + int* running_bases = iter_counts + 3; + if (threadIdx.x == 0) { - shared_output_counters[0] = 0; - shared_output_counters[1] = 0; - shared_output_counters[2] = 0; + running_bases[0] = 0; + running_bases[1] = less_count; + running_bases[2] = less_count + equal_count; } __syncthreads(); - // Output partitioned elements - if (is_contiguous && out_stride == 1) { - // Fast path: both input and output are contiguous - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i]; - UnsignedT key = Traits::to_radix(val); + for (int base_i = 0; base_i < n; base_i += BLOCK_THREADS) { + int i = base_i + threadIdx.x; + bool active = i < n; + + ValT val{}; + UnsignedT key = 0; + if (active) { + val = is_contiguous ? 
row_input[i] : row_input[i * in_stride]; + key = Traits::to_radix(val); if (is_nan_value(val)) { key = ~UnsignedT(0); } + } - int pos; - if (key < target_prefix) { - pos = atomicAdd(&shared_output_counters[0], 1); - } else if (key == target_prefix) { - pos = less_count + atomicAdd(&shared_output_counters[1], 1); - } else { - pos = - less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); + bool is_less = active && (key < target_prefix); + bool is_equal = active && (key == target_prefix); + bool is_greater = active && !is_less && !is_equal; + + unsigned less_mask = __ballot_sync(0xFFFFFFFF, is_less); + unsigned equal_mask = __ballot_sync(0xFFFFFFFF, is_equal); + unsigned greater_mask = __ballot_sync(0xFFFFFFFF, is_greater); + + unsigned lane_mask = (1u << lane) - 1u; + int less_rank = __popc(less_mask & lane_mask); + int equal_rank = __popc(equal_mask & lane_mask); + int greater_rank = __popc(greater_mask & lane_mask); + + if (lane == 0) { + warp_less[warp] = __popc(less_mask); + warp_equal[warp] = __popc(equal_mask); + warp_greater[warp] = __popc(greater_mask); + } + __syncthreads(); + + if (threadIdx.x == 0) { + int run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_less[w]; + warp_less[w] = run; + run += c; } + iter_counts[0] = run; - if (ARG_PARTITION) { - row_output[pos] = i; - } else { - row_output[pos] = val; + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_equal[w]; + warp_equal[w] = run; + run += c; } - } - } else { - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); + iter_counts[1] = run; + + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_greater[w]; + warp_greater[w] = run; + run += c; } + iter_counts[2] = run; + } + __syncthreads(); + if (active) { int pos; - if (key < target_prefix) { - pos = atomicAdd(&shared_output_counters[0], 1); - } else if (key == 
target_prefix) { - pos = less_count + atomicAdd(&shared_output_counters[1], 1); + if (is_less) { + pos = running_bases[0] + warp_less[warp] + less_rank; + } else if (is_equal) { + pos = running_bases[1] + warp_equal[warp] + equal_rank; } else { - pos = - less_count + equal_count + atomicAdd(&shared_output_counters[2], 1); + pos = running_bases[2] + warp_greater[warp] + greater_rank; } if (ARG_PARTITION) { - row_output[pos * out_stride] = i; + if (out_stride == 1) { + row_output[pos] = i; + } else { + row_output[pos * out_stride] = i; + } } else { - row_output[pos * out_stride] = val; + if (out_stride == 1) { + row_output[pos] = val; + } else { + row_output[pos * out_stride] = val; + } } } + __syncthreads(); + + if (threadIdx.x == 0) { + running_bases[0] += iter_counts[0]; + running_bases[1] += iter_counts[1]; + running_bases[2] += iter_counts[2]; + } + __syncthreads(); } } @@ -947,7 +1040,18 @@ __global__ void radix_select_tiled_scatter_kernel( return; } - __shared__ int shared_out[3]; + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + __shared__ int shared_warp_offsets[3 * NUM_WARPS]; + __shared__ int shared_iter_counts[3]; + __shared__ int shared_running_bases[3]; + + int lane = threadIdx.x & (WARP_SIZE - 1); + int warp = threadIdx.x / WARP_SIZE; + + int* warp_less = shared_warp_offsets; + int* warp_equal = shared_warp_offsets + NUM_WARPS; + int* warp_greater = shared_warp_offsets + 2 * NUM_WARPS; + for (int row = row_start; row < row_end; ++row) { int block_idx = row * blocks_per_row + block_in_row; const ValT* row_input = input + row * in_segment_stride; @@ -955,30 +1059,93 @@ __global__ void radix_select_tiled_scatter_kernel( UnsignedT row_prefix = target_prefix[row]; if (threadIdx.x == 0) { - shared_out[0] = 0; - shared_out[1] = 0; - shared_out[2] = 0; + shared_running_bases[0] = less_base[block_idx]; + shared_running_bases[1] = equal_base[block_idx]; + shared_running_bases[2] = greater_base[block_idx]; } __syncthreads(); - for (int i = start + 
threadIdx.x; i < end; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = radix_key_with_nan_last(val); + for (int base_i = start; base_i < end; base_i += BLOCK_THREADS) { + int i = base_i + threadIdx.x; + bool active = i < end; - int pos; - if (key < row_prefix) { - pos = less_base[block_idx] + atomicAdd(&shared_out[0], 1); - } else if (key == row_prefix) { - pos = equal_base[block_idx] + atomicAdd(&shared_out[1], 1); - } else { - pos = greater_base[block_idx] + atomicAdd(&shared_out[2], 1); + ValT val{}; + UnsignedT key = 0; + if (active) { + val = row_input[i * in_stride]; + key = radix_key_with_nan_last(val); } - if (ARG_PARTITION) { - row_output[pos * out_stride] = static_cast(i); - } else { - row_output[pos * out_stride] = static_cast(val); + bool is_less = active && (key < row_prefix); + bool is_equal = active && (key == row_prefix); + bool is_greater = active && !is_less && !is_equal; + + unsigned less_mask = __ballot_sync(0xFFFFFFFF, is_less); + unsigned equal_mask = __ballot_sync(0xFFFFFFFF, is_equal); + unsigned greater_mask = __ballot_sync(0xFFFFFFFF, is_greater); + + unsigned lane_mask = (1u << lane) - 1u; + int less_rank = __popc(less_mask & lane_mask); + int equal_rank = __popc(equal_mask & lane_mask); + int greater_rank = __popc(greater_mask & lane_mask); + + if (lane == 0) { + warp_less[warp] = __popc(less_mask); + warp_equal[warp] = __popc(equal_mask); + warp_greater[warp] = __popc(greater_mask); + } + __syncthreads(); + + if (threadIdx.x == 0) { + int run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_less[w]; + warp_less[w] = run; + run += c; + } + shared_iter_counts[0] = run; + + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_equal[w]; + warp_equal[w] = run; + run += c; + } + shared_iter_counts[1] = run; + + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_greater[w]; + warp_greater[w] = run; + run += c; + } + shared_iter_counts[2] = run; + } + __syncthreads(); + + if 
(active) { + int pos; + if (is_less) { + pos = shared_running_bases[0] + warp_less[warp] + less_rank; + } else if (is_equal) { + pos = shared_running_bases[1] + warp_equal[warp] + equal_rank; + } else { + pos = shared_running_bases[2] + warp_greater[warp] + greater_rank; + } + if (ARG_PARTITION) { + row_output[pos * out_stride] = static_cast(i); + } else { + row_output[pos * out_stride] = static_cast(val); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + shared_running_bases[0] += shared_iter_counts[0]; + shared_running_bases[1] += shared_iter_counts[1]; + shared_running_bases[2] += shared_iter_counts[2]; } + __syncthreads(); } __syncthreads(); } From b4ff82556a265a82da499470d08222603726cb1a Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 10 Feb 2026 02:32:58 +0800 Subject: [PATCH 10/23] update benchmark scripts --- benchmarks/python/benchmark_radix_select.py | 183 ----------- benchmarks/python/radix_select_bench.py | 327 ++++++++++++++++++++ 2 files changed, 327 insertions(+), 183 deletions(-) delete mode 100644 benchmarks/python/benchmark_radix_select.py create mode 100644 benchmarks/python/radix_select_bench.py diff --git a/benchmarks/python/benchmark_radix_select.py b/benchmarks/python/benchmark_radix_select.py deleted file mode 100644 index 858d21fef1..0000000000 --- a/benchmarks/python/benchmark_radix_select.py +++ /dev/null @@ -1,183 +0,0 @@ -#!/usr/bin/env python3 -""" -Benchmark script for MLX argpartition/partition operations. -Compares radix select implementation against full argsort. 
-""" - -import time - -import mlx.core as mx -import numpy as np - -GREEN = "\033[92m" -YELLOW = "\033[33m" -RED = "\033[91m" -RESET = "\033[0m" - - -def color_speedup(speedup): - s = f"{speedup:>5.2f}x" - if 0.9 <= speedup <= 1.1: - return f"{YELLOW}{s}{RESET}" - elif speedup > 1.1: - return f"{GREEN}{s}{RESET}" - else: - return f"{RED}{s}{RESET}" - - -def benchmark_argpartition(b, v, k, dtype=mx.bfloat16, warmup=5, iterations=100): - x = mx.random.uniform(shape=(b, v)).astype(dtype) - mx.eval(x) - for _ in range(warmup): - mx.eval(mx.argpartition(x, kth=k, axis=-1)) - start = time.perf_counter() - for _ in range(iterations): - mx.eval(mx.argpartition(x, kth=k, axis=-1)) - return (time.perf_counter() - start) / iterations * 1000 - - -def benchmark_argsort(b, v, dtype=mx.bfloat16, warmup=5, iterations=100): - x = mx.random.uniform(shape=(b, v)).astype(dtype) - mx.eval(x) - for _ in range(warmup): - mx.eval(mx.argsort(x, axis=-1)) - start = time.perf_counter() - for _ in range(iterations): - mx.eval(mx.argsort(x, axis=-1)) - return (time.perf_counter() - start) / iterations * 1000 - - -def verify_correctness(b, v, k): - x = mx.random.uniform(shape=(b, v)).astype(mx.float32) - mx.eval(x) - indices = mx.argpartition(x, kth=k, axis=-1) - mx.eval(indices) - x_np = np.array(x) - indices_np = np.array(indices) - for i in range(b): - pv = x_np[i, indices_np[i]] - assert np.all(pv[:k] <= pv[k]), f"Row {i}: elements before k not all <= kth" - assert np.all(pv[k + 1 :] >= pv[k]), f"Row {i}: elements after k not all >= kth" - return True - - -def sweep_boundary(dtype=mx.bfloat16, k_ratio=0.004, warmup=10, iterations=50): - dtype_name = str(dtype).split(".")[-1] - print(f"\nDtype={dtype_name} k=vocab*{k_ratio:.3f}") - print() - - batch_sizes = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 256, 512, 1024, 2048] - vocab_sizes = [512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] - - col_w = 10 - print(f"{'':>8}", end="") - for v in vocab_sizes: - label = f"v={v}" - print(f" 
{label:^{col_w}}", end="") - print() - - for b in batch_sizes: - print(f"b={b:<6}", end="") - for v in vocab_sizes: - k = max(1, int(v * k_ratio)) - try: - x = mx.random.uniform(shape=(b, v)).astype(dtype) - mx.eval(x) - for _ in range(warmup): - mx.eval(mx.argpartition(x, kth=k, axis=-1)) - start = time.perf_counter() - for _ in range(iterations): - mx.eval(mx.argpartition(x, kth=k, axis=-1)) - radix_ms = (time.perf_counter() - start) / iterations * 1000 - for _ in range(warmup): - mx.eval(mx.argsort(x, axis=-1)) - start = time.perf_counter() - for _ in range(iterations): - mx.eval(mx.argsort(x, axis=-1)) - argsort_ms = (time.perf_counter() - start) / iterations * 1000 - - speedup = argsort_ms / radix_ms - cell = color_speedup(speedup) - # pad accounting for invisible ANSI codes - print(f" {cell:^{col_w + len(GREEN) + len(RESET)}}", end="") - except Exception: - print(f" {'ERR':^{col_w}}", end="") - print() - - -def main(): - print("=" * 70) - print("MLX Radix Select Benchmark") - print("=" * 70) - - configs = [ - (2048, 8192, 32), - (2048, 4096, 32), - (1024, 4096, 16), - (512, 2048, 64), - (256, 1024, 32), - (128, 512, 16), - (1, 128000, 64), - (1, 512, 32), - (16, 8192, 32), - (32, 8192, 32), - (64, 8192, 32), - ] - - dtypes = [(mx.bfloat16, "bfloat16"), (mx.float32, "float32")] - - print("\n1. Correctness Verification") - print("-" * 40) - for b, v, k in configs: - try: - verify_correctness(b, v, k) - print(f" {GREEN}[PASS]{RESET} b={b}, v={v}, k={k}") - except AssertionError as e: - print(f" {RED}[FAIL]{RESET} b={b}, v={v}, k={k}: {e}") - - print("\n2. 
Performance Benchmarks") - print("-" * 70) - - for dtype, dtype_name in dtypes: - print(f"\nDtype: {dtype_name}") - print(f"{'Config':<25} {'ArgPartition':>14} {'ArgSort':>12} {'Speedup':>10}") - print("-" * 80) - - for b, v, k in configs: - try: - argpart_ms = benchmark_argpartition( - b, v, k, dtype, warmup=3, iterations=50 - ) - argsort_ms = benchmark_argsort(b, v, dtype, warmup=3, iterations=50) - speedup = argsort_ms / argpart_ms - config_str = f"b={b}, v={v}, k={k}" - print( - f"{config_str:<25} {argpart_ms:>12.3f}ms" - f" {argsort_ms:>10.3f}ms {color_speedup(speedup)}" - ) - except Exception as e: - print(f"b={b}, v={v}, k={k}: Error - {e}") - - print("\n3. Boundary Sweep") - print("-" * 70) - # sweep_boundary(mx.bool_) - sweep_boundary(mx.bfloat16) - # sweep_boundary(mx.float16) - sweep_boundary(mx.float32) - # sweep_boundary(mx.float64) - # sweep_boundary(mx.int8) - # sweep_boundary(mx.int16) - # sweep_boundary(mx.int32) - # sweep_boundary(mx.int64) - # sweep_boundary(mx.uint8) - # sweep_boundary(mx.uint16) - # sweep_boundary(mx.uint32) - # sweep_boundary(mx.uint64) - - print("\n" + "=" * 70) - print("Benchmark Complete") - print("=" * 70) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py new file mode 100644 index 0000000000..757492f31e --- /dev/null +++ b/benchmarks/python/radix_select_bench.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +""" +Benchmark script for MLX argpartition/partition operations. +Compares radix select implementation against full argsort. 
+""" + +import argparse +import time + +import mlx.core as mx +import numpy as np + +# Mapping from string names to MLX dtype objects +DTYPE_MAP = { + "bool": mx.bool_, + "bfloat16": mx.bfloat16, + "float16": mx.float16, + "float32": mx.float32, + "float64": mx.float64, + "int8": mx.int8, + "int16": mx.int16, + "int32": mx.int32, + "int64": mx.int64, + "uint8": mx.uint8, + "uint16": mx.uint16, + "uint32": mx.uint32, + "uint64": mx.uint64, +} + + +def parse_dtypes(dtype_str): + """Parse comma-separated dtype string into MLX dtype objects.""" + dtypes = [] + for dtype_str_item in dtype_str.split(","): + dtype_str_item = dtype_str_item.strip().lower() + if not dtype_str_item: + continue + if dtype_str_item not in DTYPE_MAP: + raise ValueError( + f"Unknown dtype: {dtype_str_item}. " + f"Supported dtypes: {', '.join(DTYPE_MAP.keys())}" + ) + dtypes.append((DTYPE_MAP[dtype_str_item], dtype_str_item)) + return dtypes + + +def benchmark_argpartition(b, v, k, dtype=mx.bfloat16, warmup=5, iterations=100): + x = mx.random.uniform(shape=(b, v)).astype(dtype) + mx.eval(x) + for _ in range(warmup): + mx.eval(mx.argpartition(x, kth=k, axis=-1)) + start = time.perf_counter() + for _ in range(iterations): + mx.eval(mx.argpartition(x, kth=k, axis=-1)) + return (time.perf_counter() - start) / iterations * 1000 + + +def benchmark_argsort(b, v, dtype=mx.bfloat16, warmup=5, iterations=100): + x = mx.random.uniform(shape=(b, v)).astype(dtype) + mx.eval(x) + for _ in range(warmup): + mx.eval(mx.argsort(x, axis=-1)) + start = time.perf_counter() + for _ in range(iterations): + mx.eval(mx.argsort(x, axis=-1)) + return (time.perf_counter() - start) / iterations * 1000 + + +def verify_correctness(b, v, k, dtype=mx.float32): + # Quantize random values to induce duplicates and stress tie handling. 
+ x = mx.random.uniform(shape=(b, v)) + x = mx.floor(x * 257.0).astype(mx.float32) + x = x.astype(dtype) + mx.eval(x) + + indices = mx.argpartition(x, kth=k, axis=-1) + mx.eval(indices) + + # NumPy does not always expose bfloat16 buffers reliably in this environment. + x_np = np.array(x.astype(mx.float32)) if dtype == mx.bfloat16 else np.array(x) + indices_np = np.array(indices) + is_float = np.issubdtype(x_np.dtype, np.floating) + + assert indices_np.shape == ( + b, + v, + ), f"Unexpected argpartition output shape: got {indices_np.shape}, expected {(b, v)}" + assert np.issubdtype( + indices_np.dtype, np.integer + ), f"Argpartition indices must be integer, got {indices_np.dtype}" + + for i in range(b): + row = x_np[i] + row_idx = indices_np[i] + + assert np.all( + (row_idx >= 0) & (row_idx < v) + ), f"Row {i}: out-of-range indices found" + assert ( + np.unique(row_idx).size == v + ), f"Row {i}: indices are not a permutation of [0, {v})" + + pv = row[row_idx] + pivot = pv[k] + left = pv[:k] + right = pv[k + 1 :] + + if is_float and np.isnan(pivot): + non_nan_count = np.count_nonzero(~np.isnan(row)) + assert ( + non_nan_count <= k + ), f"Row {i}: pivot is NaN before all finite values are placed" + assert np.all( + np.isnan(pv[k:]) + ), f"Row {i}: values after NaN pivot must all be NaN" + continue + + if is_float: + left_ok = np.all((~np.isnan(left)) & (left <= pivot)) + right_ok = np.all(np.isnan(right) | (right >= pivot)) + else: + left_ok = np.all(left <= pivot) + right_ok = np.all(right >= pivot) + + assert left_ok, f"Row {i}: elements before kth violate partition property" + assert right_ok, f"Row {i}: elements after kth violate partition property" + + # Rank consistency: kth must lie within [count(8}", end="") + for v in vocab_sizes: + label = f"v={v}" + print(f" {label:^{col_w}}", end="") + print() + + for b in batch_sizes: + print(f"b={b:<6}", end="") + for v in vocab_sizes: + k = max(1, int(v * k_ratio)) + try: + x = mx.random.uniform(shape=(b, 
v)).astype(dtype) + mx.eval(x) + for _ in range(warmup): + mx.eval(mx.argpartition(x, kth=k, axis=-1)) + start = time.perf_counter() + for _ in range(iterations): + mx.eval(mx.argpartition(x, kth=k, axis=-1)) + radix_ms = (time.perf_counter() - start) / iterations * 1000 + for _ in range(warmup): + mx.eval(mx.argsort(x, axis=-1)) + start = time.perf_counter() + for _ in range(iterations): + mx.eval(mx.argsort(x, axis=-1)) + argsort_ms = (time.perf_counter() - start) / iterations * 1000 + + if verify: + verify_correctness(b, v, k, dtype=dtype) + verify_tie_determinism(b, v, k, dtype=dtype) + + speedup = argsort_ms / radix_ms + cell = f"{speedup:>5.2f}x" + print(f" {cell:^{col_w}}", end="") + except Exception: + print(f" {'ERR':^{col_w}}", end="") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark MLX radix select implementation" + ) + parser.add_argument( + "--boundary-sweep", + action="store_true", + help="Enable boundary sweep test (default: disabled)", + ) + parser.add_argument( + "--verify", + action="store_true", + help="Enable correctness verification (default: disabled). " + "Disabled when --boundary-sweep is enabled.", + ) + parser.add_argument( + "--dtypes", + type=str, + default="bfloat16,float32", + help="Comma-separated data types to test (default: bfloat16,float32). " + "Supported: bool, bfloat16, float16, float32, float64, " + "int8, int16, int32, int64, uint8, uint16, uint32, uint64", + ) + args = parser.parse_args() + + print("=" * 70) + print("MLX Radix Select Benchmark") + print("=" * 70) + + configs = [ + (2048, 8192, 32), + (2048, 4096, 32), + (1024, 4096, 16), + (512, 2048, 64), + (256, 1024, 32), + (128, 512, 16), + (1, 128000, 64), + (1, 512, 32), + (16, 8192, 32), + (32, 8192, 32), + (64, 8192, 32), + ] + + try: + dtypes = parse_dtypes(args.dtypes) + except ValueError as e: + print(f"Error: {e}") + return + + if not args.boundary_sweep: + if args.verify: + print("\n1. 
Correctness Verification") + print("-" * 40) + for dtype, dtype_name in dtypes: + for b, v, k in configs: + try: + verify_correctness(b, v, k, dtype=dtype) + print(f" [PASS] b={b}, v={v}, k={k}, dtype={dtype_name}") + except AssertionError as e: + print(f" [FAIL] b={b}, v={v}, k={k}, dtype={dtype_name}: {e}") + + print("\n2. Tie Determinism Verification") + print("-" * 40) + for dtype, dtype_name in dtypes: + for b, v, k in configs: + try: + verify_tie_determinism(b=b, v=v, k=k, dtype=dtype) + print( + f" [PASS] all-equal input " + f"(b={b}, v={v}, k={k}), dtype={dtype_name}, runs=8" + ) + except AssertionError as e: + print( + f" [FAIL] all-equal input " + f"(b={b}, v={v}, k={k}), dtype={dtype_name}, runs=8: {e}" + ) + + print("\n3. Performance Benchmarks") + else: + print("\nPerformance Benchmarks") + print("-" * 70) + + for dtype, dtype_name in dtypes: + print(f"\nDtype: {dtype_name}") + print( + f"{'Config':<25} {'ArgPartition':>14} {'ArgSort':>12} {'Speedup':>10}" + ) + print("-" * 80) + + for b, v, k in configs: + try: + argpart_ms = benchmark_argpartition( + b, v, k, dtype, warmup=3, iterations=50 + ) + argsort_ms = benchmark_argsort(b, v, dtype, warmup=3, iterations=50) + speedup = argsort_ms / argpart_ms + config_str = f"b={b}, v={v}, k={k}" + print( + f"{config_str:<25} {argpart_ms:>12.3f}ms" + f" {argsort_ms:>10.3f}ms {speedup:>5.2f}x" + ) + except Exception as e: + print(f"b={b}, v={v}, k={k}: Error - {e}") + + if args.boundary_sweep: + print("\nBoundary Sweep" + (" (with verification)" if args.verify else "")) + print("-" * 70) + for dtype, dtype_name in dtypes: + sweep_boundary(dtype, verify=args.verify) + + print("\n" + "=" * 70) + print("Benchmark Complete") + print("=" * 70) + + +if __name__ == "__main__": + main() From 6a160ac401b2c30a7baff76506bb29a688fd4188 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 10 Feb 2026 03:04:17 +0800 Subject: [PATCH 11/23] fix: canonicalize signed zero in CUDA radix keys for deterministic radix select ties --- 
mlx/backend/cuda/device/radix_select.cuh | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 9a1c7a2f46..b73dcce384 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -35,6 +35,9 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(float val) { uint32_t bits = __float_as_uint(val); + if ((bits << 1) == 0) { + bits = 0; // Canonicalize +/-0.0 to +0.0 for stable equal-value ties. + } uint32_t mask = -int32_t(bits >> 31) | 0x80000000u; return bits ^ mask; } @@ -52,6 +55,9 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(double val) { uint64_t bits = __double_as_longlong(val); + if ((bits << 1) == 0) { + bits = 0; // Canonicalize +/-0.0 to +0.0 for stable equal-value ties. + } uint64_t mask = -int64_t(bits >> 63) | 0x8000000000000000ull; return bits ^ mask; } @@ -69,6 +75,9 @@ struct RadixTraits<__half> { __device__ __forceinline__ static UnsignedT to_radix(__half val) { uint16_t bits = __half_as_ushort(val); + if ((bits & 0x7FFFu) == 0) { + bits = 0; // Canonicalize +/-0 to +0 for stable equal-value ties. + } uint16_t mask = -int16_t(bits >> 15) | 0x8000u; return bits ^ mask; } @@ -86,6 +95,9 @@ struct RadixTraits<__nv_bfloat16> { __device__ __forceinline__ static UnsignedT to_radix(__nv_bfloat16 val) { uint16_t bits = __bfloat16_as_ushort(val); + if ((bits & 0x7FFFu) == 0) { + bits = 0; // Canonicalize +/-0 to +0 for stable equal-value ties. 
+ } uint16_t mask = -int16_t(bits >> 15) | 0x8000u; return bits ^ mask; } From 8054b150483a4dff02a241882e5bdf5db66deb49 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 11 Feb 2026 01:45:54 +0800 Subject: [PATCH 12/23] fix: make radix select tie-order deterministic for small kernel --- benchmarks/python/radix_select_bench.py | 10 ++ mlx/backend/cuda/device/radix_select.cuh | 120 +++++++++++++++++++---- 2 files changed, 109 insertions(+), 21 deletions(-) diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py index 757492f31e..e68059a7fb 100644 --- a/benchmarks/python/radix_select_bench.py +++ b/benchmarks/python/radix_select_bench.py @@ -156,6 +156,16 @@ def verify_tie_determinism(b=64, v=1024, k=None, dtype=mx.float32, axis=-1): f"{unique_outputs}/8 unique outputs for all-equal input " f"(shape=({b}, {v}), kth={k}, dtype={dtype})" ) + + # If deterministic, verify tie ordering matches original merge-sort order. + expected = mx.argsort(x, axis=axis) + mx.eval(expected) + expected_np = np.array(expected) + if not np.array_equal(outputs[0], expected_np): + raise AssertionError( + "Deterministic tie ordering does not match merge-sort baseline " + f"(shape=({b}, {v}), kth={k}, dtype={dtype})" + ) return True diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index b73dcce384..0c2aa2bfe6 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -472,36 +472,114 @@ __global__ void radix_select_small_kernel( } } - int less_thread_offset = block_exclusive_scan( + (void)block_exclusive_scan( local_less, shared_hist, &shared_count[0]); - int equal_thread_offset = block_exclusive_scan( + (void)block_exclusive_scan( local_equal, shared_hist, &shared_count[1]); - int q = tile_n / BLOCK_THREADS; - int r = tile_n - q * BLOCK_THREADS; - int prefix_total = int(threadIdx.x) * q + min(int(threadIdx.x), r); - int greater_thread_offset = - prefix_total - 
less_thread_offset - equal_thread_offset; - int less_count = shared_count[0]; int equal_count = shared_count[1]; - for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { - UnsignedT key = shared_keys[i]; - int pos; - if (key < target_prefix) { - pos = less_thread_offset++; - } else if (key == target_prefix) { - pos = less_count + equal_thread_offset++; - } else { - pos = less_count + equal_count + greater_thread_offset++; + // Scatter in increasing i order to keep tie behavior aligned with merge sort. + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + static_assert(3 * NUM_WARPS + 6 <= RADIX_SIZE); + + int lane = threadIdx.x & (WARP_SIZE - 1); + int warp = threadIdx.x / WARP_SIZE; + + int* warp_less = shared_hist; + int* warp_equal = shared_hist + NUM_WARPS; + int* warp_greater = shared_hist + 2 * NUM_WARPS; + int* iter_counts = shared_hist + 3 * NUM_WARPS; + int* running_bases = iter_counts + 3; + + if (threadIdx.x == 0) { + running_bases[0] = 0; + running_bases[1] = less_count; + running_bases[2] = less_count + equal_count; + } + __syncthreads(); + + for (int base_i = 0; base_i < tile_n; base_i += BLOCK_THREADS) { + int i = base_i + threadIdx.x; + bool active = i < tile_n; + + UnsignedT key = 0; + if (active) { + key = shared_keys[i]; } - if (ARG_PARTITION) { - row_output[pos * out_stride] = shared_idxs[i]; - } else { - row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; + bool is_less = active && (key < target_prefix); + bool is_equal = active && (key == target_prefix); + bool is_greater = active && !is_less && !is_equal; + + unsigned less_mask = __ballot_sync(0xFFFFFFFF, is_less); + unsigned equal_mask = __ballot_sync(0xFFFFFFFF, is_equal); + unsigned greater_mask = __ballot_sync(0xFFFFFFFF, is_greater); + + unsigned lane_mask = (1u << lane) - 1u; + int less_rank = __popc(less_mask & lane_mask); + int equal_rank = __popc(equal_mask & lane_mask); + int greater_rank = __popc(greater_mask & lane_mask); + + if (lane == 0) { + 
warp_less[warp] = __popc(less_mask); + warp_equal[warp] = __popc(equal_mask); + warp_greater[warp] = __popc(greater_mask); } + __syncthreads(); + + if (threadIdx.x == 0) { + int run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_less[w]; + warp_less[w] = run; + run += c; + } + iter_counts[0] = run; + + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_equal[w]; + warp_equal[w] = run; + run += c; + } + iter_counts[1] = run; + + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_greater[w]; + warp_greater[w] = run; + run += c; + } + iter_counts[2] = run; + } + __syncthreads(); + + if (active) { + int pos; + if (is_less) { + pos = running_bases[0] + warp_less[warp] + less_rank; + } else if (is_equal) { + pos = running_bases[1] + warp_equal[warp] + equal_rank; + } else { + pos = running_bases[2] + warp_greater[warp] + greater_rank; + } + + if (ARG_PARTITION) { + row_output[pos * out_stride] = shared_idxs[i]; + } else { + row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + running_bases[0] += iter_counts[0]; + running_bases[1] += iter_counts[1]; + running_bases[2] += iter_counts[2]; + } + __syncthreads(); } } From 999b3386092965c522ef9d0229ef3bad4aea5ba4 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 11 Feb 2026 02:30:16 +0800 Subject: [PATCH 13/23] fall back to merge sort when radix nc_dim > MAX_NDIM --- mlx/backend/cuda/device/radix_select.cuh | 24 ++--- mlx/backend/cuda/sort.cu | 106 ++++------------------- 2 files changed, 31 insertions(+), 99 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 0c2aa2bfe6..31e816a6d6 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -352,9 +352,9 @@ __global__ void radix_select_small_kernel( int64_t out_stride, int64_t in_segment_stride, int64_t out_segment_stride, - const int32_t* nc_shape, - const 
int64_t* in_nc_strides, - const int64_t* out_nc_strides, + const __grid_constant__ Shape nc_shape, + const __grid_constant__ Strides in_nc_strides, + const __grid_constant__ Strides out_nc_strides, int nc_dim) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; @@ -380,10 +380,10 @@ __global__ void radix_select_small_kernel( row_input = input + row * in_segment_stride; row_output = output + row * out_segment_stride; } else { - int64_t in_block_idx = - elem_to_loc(int64_t(row), nc_shape, in_nc_strides, nc_dim); - int64_t out_block_idx = - elem_to_loc(int64_t(row), nc_shape, out_nc_strides, nc_dim); + int64_t in_block_idx = elem_to_loc( + int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); + int64_t out_block_idx = elem_to_loc( + int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); row_input = input + in_block_idx; row_output = output + out_block_idx; } @@ -595,9 +595,9 @@ __global__ void radix_select_large_streaming_kernel( int kth, int64_t in_stride, int64_t out_stride, - const int32_t* nc_shape, - const int64_t* in_nc_strides, - const int64_t* out_nc_strides, + const __grid_constant__ Shape nc_shape, + const __grid_constant__ Strides in_nc_strides, + const __grid_constant__ Strides out_nc_strides, int nc_dim) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; @@ -605,9 +605,9 @@ __global__ void radix_select_large_streaming_kernel( int row = blockIdx.y; int64_t in_block_idx = - elem_to_loc(int64_t(row), nc_shape, in_nc_strides, nc_dim); + elem_to_loc(int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); int64_t out_block_idx = - elem_to_loc(int64_t(row), nc_shape, out_nc_strides, nc_dim); + elem_to_loc(int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); const ValT* row_input = input + in_block_idx; OutT* row_output = output + out_block_idx; diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index f0f774e492..f4a9bf7ccc 100644 --- a/mlx/backend/cuda/sort.cu +++ 
b/mlx/backend/cuda/sort.cu @@ -1086,46 +1086,9 @@ void gpu_radix_partition_small( encoder.set_input_array(in); encoder.set_output_array(out); - const int32_t* nc_shape_ptr = nullptr; - const int64_t* in_nc_strides_ptr = nullptr; - const int64_t* out_nc_strides_ptr = nullptr; - if (!contiguous && nc_dim > 0) { - array nc_shape_dev({nc_dim}, int32, nullptr, {}); - array in_nc_strides_dev({nc_dim}, int64, nullptr, {}); - array out_nc_strides_dev({nc_dim}, int64, nullptr, {}); - nc_shape_dev.set_data(cu::malloc_async(nc_shape_dev.nbytes(), encoder)); - in_nc_strides_dev.set_data( - cu::malloc_async(in_nc_strides_dev.nbytes(), encoder)); - out_nc_strides_dev.set_data( - cu::malloc_async(out_nc_strides_dev.nbytes(), encoder)); - - CHECK_CUDA_ERROR(cudaMemcpyAsync( - gpu_ptr(nc_shape_dev), - nc_shape.data(), - nc_shape_dev.nbytes(), - cudaMemcpyHostToDevice, - encoder.stream())); - CHECK_CUDA_ERROR(cudaMemcpyAsync( - gpu_ptr(in_nc_strides_dev), - in_nc_str.data(), - in_nc_strides_dev.nbytes(), - cudaMemcpyHostToDevice, - encoder.stream())); - CHECK_CUDA_ERROR(cudaMemcpyAsync( - gpu_ptr(out_nc_strides_dev), - out_nc_str.data(), - out_nc_strides_dev.nbytes(), - cudaMemcpyHostToDevice, - encoder.stream())); - - nc_shape_ptr = gpu_ptr(nc_shape_dev); - in_nc_strides_ptr = gpu_ptr(in_nc_strides_dev); - out_nc_strides_ptr = gpu_ptr(out_nc_strides_dev); - - encoder.add_temporary(nc_shape_dev); - encoder.add_temporary(in_nc_strides_dev); - encoder.add_temporary(out_nc_strides_dev); - } + auto nc_shape_param = const_param(nc_shape); + auto in_nc_strides_param = const_param(in_nc_str); + auto out_nc_strides_param = const_param(out_nc_str); dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); @@ -1189,9 +1152,9 @@ void gpu_radix_partition_small( out_stride_sorted_axis, in_stride_segment_axis, out_stride_segment_axis, - nc_shape_ptr, - in_nc_strides_ptr, - out_nc_strides_ptr, + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, nc_dim); 
}); }); @@ -1552,46 +1515,9 @@ void gpu_radix_partition_large( encoder.set_input_array(in); encoder.set_output_array(out); - const int32_t* nc_shape_ptr = nullptr; - const int64_t* in_nc_strides_ptr = nullptr; - const int64_t* out_nc_strides_ptr = nullptr; - if (nc_dim > 0) { - array nc_shape_dev({nc_dim}, int32, nullptr, {}); - array in_nc_strides_dev({nc_dim}, int64, nullptr, {}); - array out_nc_strides_dev({nc_dim}, int64, nullptr, {}); - nc_shape_dev.set_data(cu::malloc_async(nc_shape_dev.nbytes(), encoder)); - in_nc_strides_dev.set_data( - cu::malloc_async(in_nc_strides_dev.nbytes(), encoder)); - out_nc_strides_dev.set_data( - cu::malloc_async(out_nc_strides_dev.nbytes(), encoder)); - - CHECK_CUDA_ERROR(cudaMemcpyAsync( - gpu_ptr(nc_shape_dev), - nc_shape.data(), - nc_shape_dev.nbytes(), - cudaMemcpyHostToDevice, - encoder.stream())); - CHECK_CUDA_ERROR(cudaMemcpyAsync( - gpu_ptr(in_nc_strides_dev), - in_nc_str.data(), - in_nc_strides_dev.nbytes(), - cudaMemcpyHostToDevice, - encoder.stream())); - CHECK_CUDA_ERROR(cudaMemcpyAsync( - gpu_ptr(out_nc_strides_dev), - out_nc_str.data(), - out_nc_strides_dev.nbytes(), - cudaMemcpyHostToDevice, - encoder.stream())); - - nc_shape_ptr = gpu_ptr(nc_shape_dev); - in_nc_strides_ptr = gpu_ptr(in_nc_strides_dev); - out_nc_strides_ptr = gpu_ptr(out_nc_strides_dev); - - encoder.add_temporary(nc_shape_dev); - encoder.add_temporary(in_nc_strides_dev); - encoder.add_temporary(out_nc_strides_dev); - } + auto nc_shape_param = const_param(nc_shape); + auto in_nc_strides_param = const_param(in_nc_str); + auto out_nc_strides_param = const_param(out_nc_str); dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); @@ -1625,9 +1551,9 @@ void gpu_radix_partition_large( kth, in_stride_sorted_axis, out_stride_sorted_axis, - nc_shape_ptr, - in_nc_strides_ptr, - out_nc_strides_ptr, + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, nc_dim); }); } else { @@ -1647,6 +1573,12 @@ void 
gpu_radix_partition( int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; int size_sorted_axis = in.shape(axis); int kth = kth_ < 0 ? kth_ + size_sorted_axis : kth_; + int nc_dim = static_cast(in.ndim()) - 1; + + // Fixed-size const_param metadata is capped by MAX_NDIM. + if (nc_dim > MAX_NDIM) { + return gpu_merge_sort(s, in, out, axis, arg_partition); + } // Dispatch based on size if (size_sorted_axis <= 2048) { @@ -1690,4 +1622,4 @@ void Partition::eval_gpu(const std::vector& inputs, array& out) { gpu_radix_partition(stream(), inputs[0], out, axis_, kth_, false); } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core From d00af6216a64c8ffaad4cf438917e5b3bf00da76 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 16 Feb 2026 05:53:13 +0800 Subject: [PATCH 14/23] tune radix-select small-kernel launch and radix width --- benchmarks/python/radix_select_bench.py | 36 +++++- mlx/backend/cuda/device/radix_select.cuh | 40 +++---- mlx/backend/cuda/sort.cu | 133 ++++++++++++++++------- 3 files changed, 143 insertions(+), 66 deletions(-) diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py index e68059a7fb..30bfe978db 100644 --- a/benchmarks/python/radix_select_bench.py +++ b/benchmarks/python/radix_select_bench.py @@ -170,14 +170,27 @@ def verify_tie_determinism(b=64, v=1024, k=None, dtype=mx.float32, axis=-1): def sweep_boundary( - dtype=mx.bfloat16, k_ratio=0.004, warmup=10, iterations=50, verify=False + dtype=mx.bfloat16, + k_ratio=0.004, + warmup=10, + iterations=50, + verify=False, + small_kernel=False, ): dtype_name = str(dtype).split(".")[-1] print(f"\nDtype={dtype_name} k=vocab*{k_ratio:.3f}") print() - batch_sizes = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 256, 512, 1024, 2048] - vocab_sizes = [512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + batch_sizes = ( + [1, 4, 8, 16, 32, 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + if small_kernel + else [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 256, 
512, 1024, 2048] + ) + vocab_sizes = ( + [32, 64, 96, 128, 160, 192, 256, 384, 512, 1024, 2048] + if small_kernel + else [3072, 4096, 8192, 16384, 32768, 65536, 131072] + ) col_w = 10 print(f"{'':>8}", end="") @@ -227,6 +240,11 @@ def main(): action="store_true", help="Enable boundary sweep test (default: disabled)", ) + parser.add_argument( + "--small-kernel-sweep", + action="store_true", + help="Enable small-kernel-only sweep (axis <= 2048 by default)", + ) parser.add_argument( "--verify", action="store_true", @@ -267,7 +285,11 @@ def main(): print(f"Error: {e}") return - if not args.boundary_sweep: + if args.boundary_sweep and args.small_kernel_sweep: + print("Error: choose only one of --boundary-sweep or --small-kernel-sweep") + return + + if not args.boundary_sweep and not args.small_kernel_sweep: if args.verify: print("\n1. Correctness Verification") print("-" * 40) @@ -322,11 +344,13 @@ def main(): except Exception as e: print(f"b={b}, v={v}, k={k}: Error - {e}") - if args.boundary_sweep: + if args.boundary_sweep or args.small_kernel_sweep: print("\nBoundary Sweep" + (" (with verification)" if args.verify else "")) print("-" * 70) for dtype, dtype_name in dtypes: - sweep_boundary(dtype, verify=args.verify) + sweep_boundary( + dtype, verify=args.verify, small_kernel=args.small_kernel_sweep + ) print("\n" + "=" * 70) print("Benchmark Complete") diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 31e816a6d6..7de9585aea 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -20,6 +20,8 @@ namespace mlx::core::cu { // Radix configuration constexpr int RADIX_BITS = 8; constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 256 bins +constexpr int SMALL_RADIX_BITS = 5; +constexpr int SMALL_RADIX_SIZE = 1 << SMALL_RADIX_BITS; // 32 bins /////////////////////////////////////////////////////////////////////////////// // Bit manipulation for radix sorting @@ -328,12 +330,13 @@ 
__device__ __forceinline__ int block_exclusive_scan( /////////////////////////////////////////////////////////////////////////////// // Helper to calculate required shared memory size for small kernel -template +template constexpr size_t radix_select_small_shared_mem_size() { + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; return TILE_SIZE * sizeof(UnsignedT) + // shared_keys TILE_SIZE * sizeof(uint32_t) + // shared_idxs - RADIX_SIZE * sizeof(int) + // shared_hist - 2 * sizeof(int); // shared_count + SMALL_RADIX_SIZE * sizeof(int) + // shared_hist + (2 + 3 * NUM_WARPS + 6) * sizeof(int); // shared_count + scatter scratch } template < @@ -360,7 +363,8 @@ __global__ void radix_select_small_kernel( using UnsignedT = typename Traits::UnsignedT; constexpr int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; - constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + constexpr int NUM_PASSES = + (Traits::BITS + SMALL_RADIX_BITS - 1) / SMALL_RADIX_BITS; // Dynamic shared memory layout extern __shared__ char shared_mem[]; @@ -369,7 +373,14 @@ __global__ void radix_select_small_kernel( UnsignedT* shared_keys = reinterpret_cast(shared_mem); uint32_t* shared_idxs = reinterpret_cast(shared_keys + TILE_SIZE); int* shared_hist = reinterpret_cast(shared_idxs + TILE_SIZE); - int* shared_count = shared_hist + RADIX_SIZE; + int* shared_count = shared_hist + SMALL_RADIX_SIZE; + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + int* scatter_scratch = shared_count + 2; + int* warp_less = scatter_scratch; + int* warp_equal = warp_less + NUM_WARPS; + int* warp_greater = warp_equal + NUM_WARPS; + int* iter_counts = warp_greater + NUM_WARPS; + int* running_bases = iter_counts + 3; int row = blockIdx.y; @@ -413,10 +424,10 @@ __global__ void radix_select_small_kernel( UnsignedT prefix_mask = 0; for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { - int start_bit = pass * RADIX_BITS; + int start_bit = pass * SMALL_RADIX_BITS; // Clear histogram - for (int i = 
threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + for (int i = threadIdx.x; i < SMALL_RADIX_SIZE; i += BLOCK_THREADS) { shared_hist[i] = 0; } __syncthreads(); @@ -425,7 +436,7 @@ __global__ void radix_select_small_kernel( for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { UnsignedT key = shared_keys[i]; if ((key & prefix_mask) == target_prefix) { - int digit = extract_digit(key, start_bit, RADIX_BITS); + int digit = extract_digit(key, start_bit, SMALL_RADIX_BITS); atomicAdd(&shared_hist[digit], 1); } } @@ -435,7 +446,7 @@ __global__ void radix_select_small_kernel( if (threadIdx.x == 0) { int cumsum = 0; int target_bin = 0; - for (int bin = 0; bin < RADIX_SIZE; bin++) { + for (int bin = 0; bin < SMALL_RADIX_SIZE; bin++) { int count = shared_hist[bin]; if (cumsum + count >= k) { target_bin = bin; @@ -452,7 +463,7 @@ __global__ void radix_select_small_kernel( int target_bin = shared_count[0]; k = shared_count[1]; - UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + UnsignedT digit_mask = UnsignedT((1 << SMALL_RADIX_BITS) - 1) << start_bit; target_prefix |= UnsignedT(target_bin) << start_bit; prefix_mask |= digit_mask; @@ -481,18 +492,9 @@ __global__ void radix_select_small_kernel( int equal_count = shared_count[1]; // Scatter in increasing i order to keep tie behavior aligned with merge sort. 
- constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - static_assert(3 * NUM_WARPS + 6 <= RADIX_SIZE); - int lane = threadIdx.x & (WARP_SIZE - 1); int warp = threadIdx.x / WARP_SIZE; - int* warp_less = shared_hist; - int* warp_equal = shared_hist + NUM_WARPS; - int* warp_greater = shared_hist + 2 * NUM_WARPS; - int* iter_counts = shared_hist + 3 * NUM_WARPS; - int* running_bases = iter_counts + 3; - if (threadIdx.x == 0) { running_bases[0] = 0; running_bases[1] = less_count; diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index f4a9bf7ccc..f350dd607e 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1095,11 +1095,28 @@ void gpu_radix_partition_small( if constexpr (!std::is_same_v) { using ValT = cuda_type_t; - constexpr int BLOCK_THREADS = 256; - constexpr int ITEMS_PER_THREAD = 8; + int block_threads = 256; + if (size_sorted_axis <= 128) { + block_threads = 16; + } else if (size_sorted_axis <= 256) { + block_threads = 32; + } else if (size_sorted_axis <= 512) { + block_threads = 64; + } else if (size_sorted_axis <= 1024) { + block_threads = 128; + } - dim3 grid(1, n_rows, 1); - dim3 block(BLOCK_THREADS, 1, 1); + int items_per_thread = + (size_sorted_axis + block_threads - 1) / block_threads; + if (items_per_thread <= 1) { + items_per_thread = 1; + } else if (items_per_thread <= 2) { + items_per_thread = 2; + } else if (items_per_thread <= 4) { + items_per_thread = 4; + } else { + items_per_thread = 8; + } dispatch_bool(arg_partition, [&](auto arg_tag) { constexpr bool ARG_PARTITION = decltype(arg_tag)::value; @@ -1119,43 +1136,59 @@ void gpu_radix_partition_small( } } - dispatch_bool(contiguous, [&](auto contiguous_tag) { - constexpr bool USE_SIMPLE_STRIDE = decltype(contiguous_tag)::value; - - auto kernel = cu::radix_select_small_kernel< - ValT, - OutT, - ARG_PARTITION, - USE_SIMPLE_STRIDE, - BLOCK_THREADS, - ITEMS_PER_THREAD>; - - // Calculate dynamic shared memory size - using UnsignedT = typename 
cu::RadixTraits::UnsignedT; - constexpr int TILE_SIZE_VAL = BLOCK_THREADS * ITEMS_PER_THREAD; - constexpr size_t shared_mem_bytes = - TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys - TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs - 256 * sizeof(int) + // shared_hist (RADIX_SIZE=256) - 2 * sizeof(int); // shared_count - - encoder.add_kernel_node( - kernel, - grid, - block, - shared_mem_bytes, - gpu_ptr(in), - gpu_ptr(out), - kth, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, - nc_dim); + dispatch_block_dim(block_threads, [&](auto block_dim) { + constexpr int BLOCK_THREADS = block_dim(); + dim3 grid(1, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_radix_items_per_thread( + items_per_thread, [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + + dispatch_bool(contiguous, [&](auto contiguous_tag) { + constexpr bool USE_SIMPLE_STRIDE = + decltype(contiguous_tag)::value; + + auto kernel = cu::radix_select_small_kernel< + ValT, + OutT, + ARG_PARTITION, + USE_SIMPLE_STRIDE, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + + // Calculate dynamic shared memory size + using UnsignedT = typename cu::RadixTraits::UnsignedT; + constexpr int TILE_SIZE_VAL = + BLOCK_THREADS * ITEMS_PER_THREAD; + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + constexpr size_t shared_mem_bytes = + TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys + TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs + cu::SMALL_RADIX_SIZE * + sizeof(int) + // shared_hist for small kernel + (2 + 3 * NUM_WARPS + 6) * + sizeof(int); // shared_count + scatter scratch + + encoder.add_kernel_node( + kernel, + grid, + block, + shared_mem_bytes, + gpu_ptr(in), + gpu_ptr(out), + kth, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + nc_shape_param, + 
in_nc_strides_param, + out_nc_strides_param, + nc_dim); + }); + }); }); }); } else { @@ -1187,6 +1220,24 @@ struct RadixLaunchPlan { } }; +template +void dispatch_radix_items_per_thread(int items_per_thread, F&& f) { + switch (items_per_thread) { + case 1: + f(std::integral_constant{}); + break; + case 2: + f(std::integral_constant{}); + break; + case 4: + f(std::integral_constant{}); + break; + default: + f(std::integral_constant{}); + break; + } +} + RadixLaunchPlan make_radix_tiled_launch_plan( const Stream& s, int n_rows, From 9de3d09a2b3f206c3a8ebd8a46953afd43e04259 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 16 Feb 2026 06:50:58 +0800 Subject: [PATCH 15/23] remove multi-row-per-block --- mlx/backend/cuda/device/radix_select.cuh | 441 ++++++++++------------- mlx/backend/cuda/sort.cu | 72 +--- 2 files changed, 196 insertions(+), 317 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 7de9585aea..a0ec0047c4 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -329,16 +329,6 @@ __device__ __forceinline__ int block_exclusive_scan( // Single-pass Radix Select for small arrays (fits in shared memory) /////////////////////////////////////////////////////////////////////////////// -// Helper to calculate required shared memory size for small kernel -template -constexpr size_t radix_select_small_shared_mem_size() { - constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - return TILE_SIZE * sizeof(UnsignedT) + // shared_keys - TILE_SIZE * sizeof(uint32_t) + // shared_idxs - SMALL_RADIX_SIZE * sizeof(int) + // shared_hist - (2 + 3 * NUM_WARPS + 6) * sizeof(int); // shared_count + scatter scratch -} - template < typename ValT, typename OutT, @@ -857,7 +847,7 @@ __global__ void radix_select_large_streaming_kernel( // // These kernels run with a 2D launch: // - x-dimension tiles one row across multiple blocks (multi-block-per-row) -// - y-dimension packs 
multiple rows into one block group (multi-row-per-block) +// - y-dimension selects the row index /////////////////////////////////////////////////////////////////////////////// template @@ -866,21 +856,16 @@ __global__ void radix_select_tiled_init_state_kernel( UnsignedT* prefix_mask, int* k_values, int* row_hist, - int kth, - int n_rows, - int rows_per_block) { - int row_start = blockIdx.y * rows_per_block; - int row_end = min(n_rows, row_start + rows_per_block); - for (int row = row_start; row < row_end; ++row) { - if (threadIdx.x == 0) { - target_prefix[row] = UnsignedT(0); - prefix_mask[row] = UnsignedT(0); - k_values[row] = kth + 1; - } - int* hist = row_hist + row * RADIX_SIZE; - for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) { - hist[i] = 0; - } + int kth) { + int row = blockIdx.y; + if (threadIdx.x == 0) { + target_prefix[row] = UnsignedT(0); + prefix_mask[row] = UnsignedT(0); + k_values[row] = kth + 1; + } + int* hist = row_hist + row * RADIX_SIZE; + for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) { + hist[i] = 0; } } @@ -894,49 +879,43 @@ __global__ void radix_select_tiled_histogram_kernel( const typename RadixTraits::UnsignedT* prefix_mask, int start_bit, int blocks_per_row, - int n_rows, - int rows_per_block, int* row_hist) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; int block_in_row = blockIdx.x; - int row_start = blockIdx.y * rows_per_block; - int row_end = min(n_rows, row_start + rows_per_block); + int row = blockIdx.y; int chunk = (n + blocks_per_row - 1) / blocks_per_row; int start = block_in_row * chunk; int end = min(n, start + chunk); - if (start >= n || row_start >= row_end) { + if (start >= n) { return; } __shared__ int shared_hist[RADIX_SIZE]; - for (int row = row_start; row < row_end; ++row) { - const ValT* row_input = input + row * in_segment_stride; - UnsignedT row_prefix = target_prefix[row]; - UnsignedT row_mask = prefix_mask[row]; + const ValT* row_input = input + row * 
in_segment_stride; + UnsignedT row_prefix = target_prefix[row]; + UnsignedT row_mask = prefix_mask[row]; - for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { - shared_hist[i] = 0; - } - __syncthreads(); + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + __syncthreads(); - for (int i = start + threadIdx.x; i < end; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = radix_key_with_nan_last(val); - if ((key & row_mask) == row_prefix) { - int digit = extract_digit(key, start_bit, RADIX_BITS); - atomicAdd(&shared_hist[digit], 1); - } + for (int i = start + threadIdx.x; i < end; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = radix_key_with_nan_last(val); + if ((key & row_mask) == row_prefix) { + int digit = extract_digit(key, start_bit, RADIX_BITS); + atomicAdd(&shared_hist[digit], 1); } - __syncthreads(); + } + __syncthreads(); - int* hist = row_hist + row * RADIX_SIZE; - for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { - atomicAdd(&hist[i], shared_hist[i]); - } - __syncthreads(); + int* hist = row_hist + row * RADIX_SIZE; + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + atomicAdd(&hist[i], shared_hist[i]); } } @@ -947,42 +926,36 @@ __global__ void radix_select_tiled_select_bin_kernel( UnsignedT* prefix_mask, int* k_values, int clear_hist_for_next_pass, - int start_bit, - int n_rows, - int rows_per_block) { - int row_start = blockIdx.y * rows_per_block; - int row_end = min(n_rows, row_start + rows_per_block); - for (int row = row_start; row < row_end; ++row) { - int* hist = row_hist + row * RADIX_SIZE; + int start_bit) { + int row = blockIdx.y; + int* hist = row_hist + row * RADIX_SIZE; - if (threadIdx.x == 0) { - int k = k_values[row]; - int cumsum = 0; - int target_bin = 0; - for (int bin = 0; bin < RADIX_SIZE; bin++) { - int count = hist[bin]; - if (cumsum + count >= k) { - target_bin = bin; - k -= cumsum; - break; - } 
- cumsum += count; + if (threadIdx.x == 0) { + int k = k_values[row]; + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k -= cumsum; + break; } - k_values[row] = k; - - UnsignedT digit_mask = - (UnsignedT((UnsignedT(1) << RADIX_BITS) - UnsignedT(1)) << start_bit); - target_prefix[row] |= UnsignedT(target_bin) << start_bit; - prefix_mask[row] |= digit_mask; + cumsum += count; } - __syncthreads(); + k_values[row] = k; - if (clear_hist_for_next_pass) { - for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) { - hist[i] = 0; - } + UnsignedT digit_mask = + (UnsignedT((UnsignedT(1) << RADIX_BITS) - UnsignedT(1)) << start_bit); + target_prefix[row] |= UnsignedT(target_bin) << start_bit; + prefix_mask[row] |= digit_mask; + } + __syncthreads(); + + if (clear_hist_for_next_pass) { + for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) { + hist[i] = 0; } - __syncthreads(); } } @@ -994,111 +967,56 @@ __global__ void radix_select_tiled_count_kernel( int64_t in_segment_stride, const typename RadixTraits::UnsignedT* target_prefix, int blocks_per_row, - int n_rows, - int rows_per_block, int* block_less, int* block_equal) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; int block_in_row = blockIdx.x; - int row_start = blockIdx.y * rows_per_block; - int row_end = min(n_rows, row_start + rows_per_block); + int row = blockIdx.y; int chunk = (n + blocks_per_row - 1) / blocks_per_row; int start = block_in_row * chunk; int end = min(n, start + chunk); - if (row_start >= row_end) { + if (start >= n) { return; } __shared__ int shared_counts[2]; - for (int row = row_start; row < row_end; ++row) { - int block_idx = row * blocks_per_row + block_in_row; - const ValT* row_input = input + row * in_segment_stride; - UnsignedT row_prefix = target_prefix[row]; - - int local_less = 0; - int local_equal = 0; - for (int i = start + threadIdx.x; i < end; 
i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = radix_key_with_nan_last(val); - if (key < row_prefix) { - local_less++; - } else if (key == row_prefix) { - local_equal++; - } - } + int block_idx = row * blocks_per_row + block_in_row; + const ValT* row_input = input + row * in_segment_stride; + UnsignedT row_prefix = target_prefix[row]; - local_less = warp_reduce_sum(local_less); - local_equal = warp_reduce_sum(local_equal); - - if (threadIdx.x == 0) { - shared_counts[0] = 0; - shared_counts[1] = 0; + int local_less = 0; + int local_equal = 0; + for (int i = start + threadIdx.x; i < end; i += BLOCK_THREADS) { + ValT val = row_input[i * in_stride]; + UnsignedT key = radix_key_with_nan_last(val); + if (key < row_prefix) { + local_less++; + } else if (key == row_prefix) { + local_equal++; } - __syncthreads(); + } - if ((threadIdx.x % WARP_SIZE) == 0) { - atomicAdd(&shared_counts[0], local_less); - atomicAdd(&shared_counts[1], local_equal); - } - __syncthreads(); + local_less = warp_reduce_sum(local_less); + local_equal = warp_reduce_sum(local_equal); - if (threadIdx.x == 0) { - block_less[block_idx] = shared_counts[0]; - block_equal[block_idx] = shared_counts[1]; - } - __syncthreads(); + if (threadIdx.x == 0) { + shared_counts[0] = 0; + shared_counts[1] = 0; } -} + __syncthreads(); -__global__ void radix_select_tiled_prefix_kernel( - int n, - int blocks_per_row, - int n_rows, - int rows_per_block, - const int* block_less, - const int* block_equal, - int* less_base, - int* equal_base, - int* greater_base) { - if (threadIdx.x != 0) { - return; + if ((threadIdx.x % WARP_SIZE) == 0) { + atomicAdd(&shared_counts[0], local_less); + atomicAdd(&shared_counts[1], local_equal); } + __syncthreads(); - int row_start = blockIdx.y * rows_per_block; - int row_end = min(n_rows, row_start + rows_per_block); - int chunk = (n + blocks_per_row - 1) / blocks_per_row; - - for (int row = row_start; row < row_end; ++row) { - int row_off = row * blocks_per_row; - 
int total_less = 0; - int total_equal = 0; - for (int b = 0; b < blocks_per_row; b++) { - int idx = row_off + b; - total_less += block_less[idx]; - total_equal += block_equal[idx]; - } - - int run_less = 0; - int run_equal = 0; - int run_greater = 0; - for (int b = 0; b < blocks_per_row; b++) { - int idx = row_off + b; - less_base[idx] = run_less; - equal_base[idx] = total_less + run_equal; - greater_base[idx] = total_less + total_equal + run_greater; - - int start = b * chunk; - int end = min(n, start + chunk); - int chunk_size = max(0, end - start); - int greater_count = chunk_size - block_less[idx] - block_equal[idx]; - - run_less += block_less[idx]; - run_equal += block_equal[idx]; - run_greater += greater_count; - } + if (threadIdx.x == 0) { + block_less[block_idx] = shared_counts[0]; + block_equal[block_idx] = shared_counts[1]; } } @@ -1113,22 +1031,18 @@ __global__ void radix_select_tiled_scatter_kernel( int64_t out_segment_stride, const typename RadixTraits::UnsignedT* target_prefix, int blocks_per_row, - int n_rows, - int rows_per_block, - const int* less_base, - const int* equal_base, - const int* greater_base) { + const int* block_less, + const int* block_equal) { using Traits = RadixTraits; using UnsignedT = typename Traits::UnsignedT; int block_in_row = blockIdx.x; - int row_start = blockIdx.y * rows_per_block; - int row_end = min(n_rows, row_start + rows_per_block); + int row = blockIdx.y; int chunk = (n + blocks_per_row - 1) / blocks_per_row; int start = block_in_row * chunk; int end = min(n, start + chunk); - if (start >= n || row_start >= row_end) { + if (start >= n) { return; } @@ -1144,100 +1058,119 @@ __global__ void radix_select_tiled_scatter_kernel( int* warp_equal = shared_warp_offsets + NUM_WARPS; int* warp_greater = shared_warp_offsets + 2 * NUM_WARPS; - for (int row = row_start; row < row_end; ++row) { - int block_idx = row * blocks_per_row + block_in_row; - const ValT* row_input = input + row * in_segment_stride; - OutT* row_output = 
output + row * out_segment_stride; - UnsignedT row_prefix = target_prefix[row]; + int row_off = row * blocks_per_row; + const ValT* row_input = input + row * in_segment_stride; + OutT* row_output = output + row * out_segment_stride; + UnsignedT row_prefix = target_prefix[row]; - if (threadIdx.x == 0) { - shared_running_bases[0] = less_base[block_idx]; - shared_running_bases[1] = equal_base[block_idx]; - shared_running_bases[2] = greater_base[block_idx]; + if (threadIdx.x == 0) { + int total_less = 0; + int total_equal = 0; + for (int b = 0; b < blocks_per_row; ++b) { + int idx = row_off + b; + total_less += block_less[idx]; + total_equal += block_equal[idx]; } - __syncthreads(); - for (int base_i = start; base_i < end; base_i += BLOCK_THREADS) { - int i = base_i + threadIdx.x; - bool active = i < end; + int run_less = 0; + int run_equal = 0; + int run_greater = 0; + for (int b = 0; b < block_in_row; ++b) { + int idx = row_off + b; + int b_start = b * chunk; + int b_end = min(n, b_start + chunk); + int b_chunk_size = max(0, b_end - b_start); + int greater_count = b_chunk_size - block_less[idx] - block_equal[idx]; + run_less += block_less[idx]; + run_equal += block_equal[idx]; + run_greater += greater_count; + } - ValT val{}; - UnsignedT key = 0; - if (active) { - val = row_input[i * in_stride]; - key = radix_key_with_nan_last(val); - } + shared_running_bases[0] = run_less; + shared_running_bases[1] = total_less + run_equal; + shared_running_bases[2] = total_less + total_equal + run_greater; + } + __syncthreads(); - bool is_less = active && (key < row_prefix); - bool is_equal = active && (key == row_prefix); - bool is_greater = active && !is_less && !is_equal; + for (int base_i = start; base_i < end; base_i += BLOCK_THREADS) { + int i = base_i + threadIdx.x; + bool active = i < end; - unsigned less_mask = __ballot_sync(0xFFFFFFFF, is_less); - unsigned equal_mask = __ballot_sync(0xFFFFFFFF, is_equal); - unsigned greater_mask = __ballot_sync(0xFFFFFFFF, is_greater); + 
ValT val{}; + UnsignedT key = 0; + if (active) { + val = row_input[i * in_stride]; + key = radix_key_with_nan_last(val); + } - unsigned lane_mask = (1u << lane) - 1u; - int less_rank = __popc(less_mask & lane_mask); - int equal_rank = __popc(equal_mask & lane_mask); - int greater_rank = __popc(greater_mask & lane_mask); + bool is_less = active && (key < row_prefix); + bool is_equal = active && (key == row_prefix); + bool is_greater = active && !is_less && !is_equal; - if (lane == 0) { - warp_less[warp] = __popc(less_mask); - warp_equal[warp] = __popc(equal_mask); - warp_greater[warp] = __popc(greater_mask); - } - __syncthreads(); - - if (threadIdx.x == 0) { - int run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_less[w]; - warp_less[w] = run; - run += c; - } - shared_iter_counts[0] = run; + unsigned less_mask = __ballot_sync(0xFFFFFFFF, is_less); + unsigned equal_mask = __ballot_sync(0xFFFFFFFF, is_equal); + unsigned greater_mask = __ballot_sync(0xFFFFFFFF, is_greater); - run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_equal[w]; - warp_equal[w] = run; - run += c; - } - shared_iter_counts[1] = run; + unsigned lane_mask = (1u << lane) - 1u; + int less_rank = __popc(less_mask & lane_mask); + int equal_rank = __popc(equal_mask & lane_mask); + int greater_rank = __popc(greater_mask & lane_mask); - run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_greater[w]; - warp_greater[w] = run; - run += c; - } - shared_iter_counts[2] = run; + if (lane == 0) { + warp_less[warp] = __popc(less_mask); + warp_equal[warp] = __popc(equal_mask); + warp_greater[warp] = __popc(greater_mask); + } + __syncthreads(); + + if (threadIdx.x == 0) { + int run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_less[w]; + warp_less[w] = run; + run += c; } - __syncthreads(); - - if (active) { - int pos; - if (is_less) { - pos = shared_running_bases[0] + warp_less[warp] + less_rank; - } else if (is_equal) { - pos = shared_running_bases[1] + 
warp_equal[warp] + equal_rank; - } else { - pos = shared_running_bases[2] + warp_greater[warp] + greater_rank; - } - if (ARG_PARTITION) { - row_output[pos * out_stride] = static_cast(i); - } else { - row_output[pos * out_stride] = static_cast(val); - } + shared_iter_counts[0] = run; + + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_equal[w]; + warp_equal[w] = run; + run += c; } - __syncthreads(); + shared_iter_counts[1] = run; - if (threadIdx.x == 0) { - shared_running_bases[0] += shared_iter_counts[0]; - shared_running_bases[1] += shared_iter_counts[1]; - shared_running_bases[2] += shared_iter_counts[2]; + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_greater[w]; + warp_greater[w] = run; + run += c; } - __syncthreads(); + shared_iter_counts[2] = run; + } + __syncthreads(); + + if (active) { + int pos; + if (is_less) { + pos = shared_running_bases[0] + warp_less[warp] + less_rank; + } else if (is_equal) { + pos = shared_running_bases[1] + warp_equal[warp] + equal_rank; + } else { + pos = shared_running_bases[2] + warp_greater[warp] + greater_rank; + } + if (ARG_PARTITION) { + row_output[pos * out_stride] = static_cast(i); + } else { + row_output[pos * out_stride] = static_cast(val); + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + shared_running_bases[0] += shared_iter_counts[0]; + shared_running_bases[1] += shared_iter_counts[1]; + shared_running_bases[2] += shared_iter_counts[2]; } __syncthreads(); } diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index f350dd607e..bdce67ef30 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1213,7 +1213,6 @@ int64_t segment_stride_for_contiguous( struct RadixLaunchPlan { int blocks_per_row{1}; - int rows_per_block{1}; bool uses_tiled_launch() const { return blocks_per_row > 1; @@ -1249,7 +1248,6 @@ RadixLaunchPlan make_radix_tiled_launch_plan( constexpr int kBlocksPerSmTarget = 4; constexpr int kMinElemsPerBlock = 1024; constexpr int 
kMaxBlocksPerRow = 32; - constexpr int kMaxRowsPerBlock = 4; int sm_count = 0; CHECK_CUDA_ERROR(cudaDeviceGetAttribute( @@ -1267,18 +1265,7 @@ RadixLaunchPlan make_radix_tiled_launch_plan( if (blocks_per_row <= 1) { return {}; } - - int total_blocks = n_rows * blocks_per_row; - int chunk = (size_sorted_axis + blocks_per_row - 1) / blocks_per_row; - - int rows_per_block = 1; - if (total_blocks > target_blocks && chunk <= 2 * kMinElemsPerBlock) { - rows_per_block = std::min( - {kMaxRowsPerBlock, - n_rows, - std::max(1, (total_blocks + target_blocks - 1) / target_blocks)}); - } - return {blocks_per_row, rows_per_block}; + return {blocks_per_row}; } Dtype unsigned_dtype_for_size(int size) { @@ -1308,7 +1295,6 @@ void gpu_radix_partition_large_tiled( int64_t in_stride_segment_axis, int64_t out_stride_segment_axis, int blocks_per_row, - int rows_per_block, bool arg_partition) { auto& encoder = cu::get_command_encoder(s); out.set_data(cu::malloc_async(out.nbytes(), encoder)); @@ -1334,9 +1320,6 @@ void gpu_radix_partition_large_tiled( int total_blocks = n_rows * blocks_per_row; array block_less_dev({total_blocks}, int32, nullptr, {}); array block_equal_dev({total_blocks}, int32, nullptr, {}); - array less_base_dev({total_blocks}, int32, nullptr, {}); - array equal_base_dev({total_blocks}, int32, nullptr, {}); - array greater_base_dev({total_blocks}, int32, nullptr, {}); auto allocate_temporary = [&](array& a) { a.set_data(cu::malloc_async(a.nbytes(), encoder)); @@ -1348,13 +1331,9 @@ void gpu_radix_partition_large_tiled( allocate_temporary(row_hist_dev); allocate_temporary(block_less_dev); allocate_temporary(block_equal_dev); - allocate_temporary(less_base_dev); - allocate_temporary(equal_base_dev); - allocate_temporary(greater_base_dev); - int row_groups = (n_rows + rows_per_block - 1) / rows_per_block; - dim3 row_grid(1, row_groups, 1); - dim3 grid(blocks_per_row, row_groups, 1); + dim3 row_grid(1, n_rows, 1); + dim3 grid(blocks_per_row, n_rows, 1); 
encoder.set_output_array(target_prefix_dev); encoder.set_output_array(prefix_mask_dev); @@ -1369,9 +1348,7 @@ void gpu_radix_partition_large_tiled( gpu_ptr(prefix_mask_dev), gpu_ptr(k_values_dev), gpu_ptr(row_hist_dev), - kth, - n_rows, - rows_per_block); + kth); for (int pass = NUM_PASSES - 1; pass >= 0; --pass) { int start_bit = pass * cu::RADIX_BITS; @@ -1394,8 +1371,6 @@ void gpu_radix_partition_large_tiled( gpu_ptr(prefix_mask_dev), start_bit, blocks_per_row, - n_rows, - rows_per_block, gpu_ptr(row_hist_dev)); encoder.set_input_array(row_hist_dev); @@ -1417,9 +1392,7 @@ void gpu_radix_partition_large_tiled( gpu_ptr(prefix_mask_dev), gpu_ptr(k_values_dev), pass > 0 ? 1 : 0, - start_bit, - n_rows, - rows_per_block); + start_bit); } encoder.set_input_array(in); @@ -1437,39 +1410,16 @@ void gpu_radix_partition_large_tiled( in_stride_segment_axis, gpu_ptr(target_prefix_dev), blocks_per_row, - n_rows, - rows_per_block, gpu_ptr(block_less_dev), gpu_ptr(block_equal_dev)); - encoder.set_input_array(block_less_dev); - encoder.set_input_array(block_equal_dev); - encoder.set_output_array(less_base_dev); - encoder.set_output_array(equal_base_dev); - encoder.set_output_array(greater_base_dev); - encoder.add_kernel_node( - cu::radix_select_tiled_prefix_kernel, - row_grid, - dim3(1, 1, 1), - 0, - size_sorted_axis, - blocks_per_row, - n_rows, - rows_per_block, - gpu_ptr(block_less_dev), - gpu_ptr(block_equal_dev), - gpu_ptr(less_base_dev), - gpu_ptr(equal_base_dev), - gpu_ptr(greater_base_dev)); - dispatch_bool(arg_partition, [&](auto arg_tag) { constexpr bool ARG_PARTITION = decltype(arg_tag)::value; using OutT = std::conditional_t; encoder.set_input_array(in); encoder.set_input_array(target_prefix_dev); - encoder.set_input_array(less_base_dev); - encoder.set_input_array(equal_base_dev); - encoder.set_input_array(greater_base_dev); + encoder.set_input_array(block_less_dev); + encoder.set_input_array(block_equal_dev); encoder.set_output_array(out); encoder.add_kernel_node( 
cu::radix_select_tiled_scatter_kernel< @@ -1489,11 +1439,8 @@ void gpu_radix_partition_large_tiled( out_stride_segment_axis, gpu_ptr(target_prefix_dev), blocks_per_row, - n_rows, - rows_per_block, - gpu_ptr(less_base_dev), - gpu_ptr(equal_base_dev), - gpu_ptr(greater_base_dev)); + gpu_ptr(block_less_dev), + gpu_ptr(block_equal_dev)); }); } else { throw std::runtime_error( @@ -1556,7 +1503,6 @@ void gpu_radix_partition_large( in_stride_segment_axis, out_stride_segment_axis, plan.blocks_per_row, - plan.rows_per_block, arg_partition); } } From c034ff104ad1a9e90b6440cc8de66ae4db749676 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 16 Feb 2026 18:02:46 +0800 Subject: [PATCH 16/23] replace the fixed argpartition small-kernel cutoff with a runtime check based on estimated shared-memory usage and device limits --- benchmarks/python/radix_select_bench.py | 195 +++++++++++++++++++++--- mlx/backend/cuda/sort.cu | 159 +++++++++++++------ 2 files changed, 289 insertions(+), 65 deletions(-) diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py index 30bfe978db..49decd89b8 100644 --- a/benchmarks/python/radix_select_bench.py +++ b/benchmarks/python/radix_select_bench.py @@ -5,6 +5,7 @@ """ import argparse +import ctypes import time import mlx.core as mx @@ -27,6 +28,135 @@ "uint64": mx.uint64, } +# Benchmark-side model for cross-GPU small-kernel dispatch policy. 
+RADIX_ITEMS_BUCKETS = (1, 2, 4, 8, 12, 16, 24, 32, 48, 64) +MAX_RADIX_ITEMS_PER_THREAD = 64 +SMALL_RADIX_SIZE = 32 +WARP_SIZE = 32 +CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK = 8 +CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97 + +DTYPE_SIZE_BYTES = { + "bool_": 1, + "bfloat16": 2, + "float16": 2, + "float32": 4, + "float64": 8, + "int8": 1, + "int16": 2, + "int32": 4, + "int64": 8, + "uint8": 1, + "uint16": 2, + "uint32": 4, + "uint64": 8, +} + + +def _dtype_size_bytes(dtype): + dtype_name = str(dtype).split(".")[-1] + return DTYPE_SIZE_BYTES[dtype_name] + + +def _cuda_max_shared_mem_per_block(default=48 * 1024): + """Query max(base, optin) shared memory per block; fallback to 48KB.""" + try: + cudart = ctypes.CDLL("libcudart.so") + + cuda_get_device = cudart.cudaGetDevice + cuda_get_device.argtypes = [ctypes.POINTER(ctypes.c_int)] + cuda_get_device.restype = ctypes.c_int + + cuda_device_get_attribute = cudart.cudaDeviceGetAttribute + cuda_device_get_attribute.argtypes = [ + ctypes.POINTER(ctypes.c_int), + ctypes.c_int, + ctypes.c_int, + ] + cuda_device_get_attribute.restype = ctypes.c_int + + dev = ctypes.c_int() + if cuda_get_device(ctypes.byref(dev)) != 0: + return default + + smem_base = ctypes.c_int() + if ( + cuda_device_get_attribute( + ctypes.byref(smem_base), + CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK, + dev.value, + ) + != 0 + ): + return default + + smem_optin = ctypes.c_int() + optin_rc = cuda_device_get_attribute( + ctypes.byref(smem_optin), + CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + dev.value, + ) + if optin_rc == 0: + return max(int(smem_base.value), int(smem_optin.value)) + return int(smem_base.value) + except Exception: + return default + + +def _radix_small_block_threads(vocab_size): + if vocab_size <= 128: + return 16 + if vocab_size <= 256: + return 32 + if vocab_size <= 512: + return 64 + if vocab_size <= 1024: + return 128 + return 256 + + +def _radix_small_dispatch_items(required_items): + for bucket in RADIX_ITEMS_BUCKETS: + 
if required_items <= bucket: + return bucket + return None + + +def _radix_small_shared_mem_bytes(dtype_size, block_threads, items_per_thread): + tile_size = block_threads * items_per_thread + num_warps = block_threads // WARP_SIZE + return ( + tile_size * dtype_size + + tile_size * 4 + + SMALL_RADIX_SIZE * 4 + + (2 + 3 * num_warps + 6) * 4 + ) + + +def estimate_small_kernel_limit(dtype): + """Estimate max small-kernel axis for dtype under current CUDA radix policy.""" + dtype_size = _dtype_size_bytes(dtype) + smem_limit = _cuda_max_shared_mem_per_block() + max_axis = 0 + # 256 is the largest block_threads in sort.cu launch selection. + for v in range(1, 256 * MAX_RADIX_ITEMS_PER_THREAD + 1): + block_threads = _radix_small_block_threads(v) + required_items = (v + block_threads - 1) // block_threads + if required_items > MAX_RADIX_ITEMS_PER_THREAD: + continue + items_per_thread = _radix_small_dispatch_items(required_items) + if items_per_thread is None: + continue + if ( + _radix_small_shared_mem_bytes(dtype_size, block_threads, items_per_thread) + <= smem_limit + ): + max_axis = v + return { + "max_axis": max_axis, + "smem_limit": smem_limit, + } + def parse_dtypes(dtype_str): """Parse comma-separated dtype string into MLX dtype objects.""" @@ -169,7 +299,7 @@ def verify_tie_determinism(b=64, v=1024, k=None, dtype=mx.float32, axis=-1): return True -def sweep_boundary( +def sweep_kernel( dtype=mx.bfloat16, k_ratio=0.004, warmup=10, @@ -178,19 +308,40 @@ def sweep_boundary( small_kernel=False, ): dtype_name = str(dtype).split(".")[-1] - print(f"\nDtype={dtype_name} k=vocab*{k_ratio:.3f}") + limit = estimate_small_kernel_limit(dtype) + max_small_axis = limit["max_axis"] + smem_kb = limit["smem_limit"] / 1024.0 + print( + f"\nDtype={dtype_name} k=vocab*{k_ratio:.3f} " + f"small-kernel-limit≈{max_small_axis} " + f"smem={smem_kb:.1f}KB" + ) print() - batch_sizes = ( - [1, 4, 8, 16, 32, 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192] - if small_kernel - else [1, 2, 4, 8, 
16, 32, 48, 64, 96, 128, 256, 512, 1024, 2048] - ) - vocab_sizes = ( - [32, 64, 96, 128, 160, 192, 256, 384, 512, 1024, 2048] - if small_kernel - else [3072, 4096, 8192, 16384, 32768, 65536, 131072] - ) + candidate_vocab = { + 32, + 64, + 96, + 160, + 256, + 384, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, + 65536, + 131072, + } + + if small_kernel: + batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + vocab_sizes = sorted({int(v) for v in candidate_vocab if v <= max_small_axis}) + else: + batch_sizes = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 256, 512, 1024, 2048] + vocab_sizes = sorted({int(v) for v in candidate_vocab if v > max_small_axis}) col_w = 10 print(f"{'':>8}", end="") @@ -236,20 +387,20 @@ def main(): description="Benchmark MLX radix select implementation" ) parser.add_argument( - "--boundary-sweep", + "--large-kernel-sweep", action="store_true", - help="Enable boundary sweep test (default: disabled)", + help="Enable large-kernel-focused sweep (default: disabled)", ) parser.add_argument( "--small-kernel-sweep", action="store_true", - help="Enable small-kernel-only sweep (axis <= 2048 by default)", + help="Enable small-kernel-focused sweep around the estimated boundary", ) parser.add_argument( "--verify", action="store_true", help="Enable correctness verification (default: disabled). " - "Disabled when --boundary-sweep is enabled.", + "Disabled when --large-kernel-sweep is enabled.", ) parser.add_argument( "--dtypes", @@ -285,11 +436,11 @@ def main(): print(f"Error: {e}") return - if args.boundary_sweep and args.small_kernel_sweep: - print("Error: choose only one of --boundary-sweep or --small-kernel-sweep") + if args.large_kernel_sweep and args.small_kernel_sweep: + print("Error: choose only one of --large-kernel-sweep or --small-kernel-sweep") return - if not args.boundary_sweep and not args.small_kernel_sweep: + if not args.large_kernel_sweep and not args.small_kernel_sweep: if args.verify: print("\n1. 
Correctness Verification") print("-" * 40) @@ -344,11 +495,11 @@ def main(): except Exception as e: print(f"b={b}, v={v}, k={k}: Error - {e}") - if args.boundary_sweep or args.small_kernel_sweep: - print("\nBoundary Sweep" + (" (with verification)" if args.verify else "")) + if args.large_kernel_sweep or args.small_kernel_sweep: + print("\nKernel Sweep" + (" (with verification)" if args.verify else "")) print("-" * 70) for dtype, dtype_name in dtypes: - sweep_boundary( + sweep_kernel( dtype, verify=args.verify, small_kernel=args.small_kernel_sweep ) diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index bdce67ef30..7f5284dedd 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1046,6 +1046,113 @@ void gpu_merge_sort( // Radix partition functions /////////////////////////////////////////////////////////////////////////////// +// Upper bound for small-kernel tiling. Keep this aligned with the +// items-per-thread dispatch set and per-block shared-memory budget. 
+constexpr int MAX_RADIX_ITEMS_PER_THREAD = 64; + +int radix_small_block_threads(int size_sorted_axis) { + int block_threads = 256; + if (size_sorted_axis <= 128) { + block_threads = 16; + } else if (size_sorted_axis <= 256) { + block_threads = 32; + } else if (size_sorted_axis <= 512) { + block_threads = 64; + } else if (size_sorted_axis <= 1024) { + block_threads = 128; + } + return block_threads; +} + +template +void dispatch_radix_items_per_thread( + int size_sorted_axis, + int block_threads, + F&& f) { + int items_per_thread = (size_sorted_axis + block_threads - 1) / block_threads; + if (items_per_thread <= 1) { + f(std::integral_constant{}); + } else if (items_per_thread <= 2) { + f(std::integral_constant{}); + } else if (items_per_thread <= 4) { + f(std::integral_constant{}); + } else if (items_per_thread <= 8) { + f(std::integral_constant{}); + } else if (items_per_thread <= 12) { + f(std::integral_constant{}); + } else if (items_per_thread <= 16) { + f(std::integral_constant{}); + } else if (items_per_thread <= 24) { + f(std::integral_constant{}); + } else if (items_per_thread <= 32) { + f(std::integral_constant{}); + } else if (items_per_thread <= 48) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); + } +} + +size_t radix_small_shared_mem_bytes( + size_t key_size, + int block_threads, + int items_per_thread) { + size_t tile_size = static_cast(block_threads) * + static_cast(items_per_thread); + size_t num_warps = static_cast(block_threads / WARP_SIZE); + return tile_size * key_size + // shared_keys + tile_size * sizeof(uint32_t) + // shared_idxs + cu::SMALL_RADIX_SIZE * sizeof(int) + // shared_hist for small kernel + (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch +} + +int radix_max_shared_mem_per_block(const Stream& s) { + int max_shared_mem_per_block = 0; + CHECK_CUDA_ERROR(cudaDeviceGetAttribute( + &max_shared_mem_per_block, + cudaDevAttrMaxSharedMemoryPerBlock, + s.device.index)); + + int 
max_shared_mem_per_block_optin = 0; + cudaError_t optin_err = cudaDeviceGetAttribute( + &max_shared_mem_per_block_optin, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + s.device.index); + if (optin_err == cudaSuccess) { + max_shared_mem_per_block = + std::max(max_shared_mem_per_block, max_shared_mem_per_block_optin); + } else { + cudaGetLastError(); + } + return max_shared_mem_per_block; +} + +bool radix_small_fits_shared_memory( + const Stream& s, + Dtype dtype, + int size_sorted_axis) { + if (size_sorted_axis <= 0) { + return false; + } + + int block_threads = radix_small_block_threads(size_sorted_axis); + int required_items = (size_sorted_axis + block_threads - 1) / block_threads; + if (required_items > MAX_RADIX_ITEMS_PER_THREAD) { + return false; + } + + size_t required_shared_mem = 0; + dispatch_radix_items_per_thread( + size_sorted_axis, block_threads, [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + required_shared_mem = radix_small_shared_mem_bytes( + size_of(dtype), block_threads, ITEMS_PER_THREAD); + }); + + int max_shared_mem_per_block = radix_max_shared_mem_per_block(s); + return required_shared_mem <= static_cast(max_shared_mem_per_block); +} + void gpu_radix_partition_small( const Stream& s, const array& in, @@ -1095,28 +1202,7 @@ void gpu_radix_partition_small( if constexpr (!std::is_same_v) { using ValT = cuda_type_t; - int block_threads = 256; - if (size_sorted_axis <= 128) { - block_threads = 16; - } else if (size_sorted_axis <= 256) { - block_threads = 32; - } else if (size_sorted_axis <= 512) { - block_threads = 64; - } else if (size_sorted_axis <= 1024) { - block_threads = 128; - } - - int items_per_thread = - (size_sorted_axis + block_threads - 1) / block_threads; - if (items_per_thread <= 1) { - items_per_thread = 1; - } else if (items_per_thread <= 2) { - items_per_thread = 2; - } else if (items_per_thread <= 4) { - items_per_thread = 4; - } else { - items_per_thread = 8; - } + int block_threads = 
radix_small_block_threads(size_sorted_axis); dispatch_bool(arg_partition, [&](auto arg_tag) { constexpr bool ARG_PARTITION = decltype(arg_tag)::value; @@ -1142,7 +1228,7 @@ void gpu_radix_partition_small( dim3 block(BLOCK_THREADS, 1, 1); dispatch_radix_items_per_thread( - items_per_thread, [&](auto items_per_thread_tag) { + size_sorted_axis, block_threads, [&](auto items_per_thread_tag) { constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); dispatch_bool(contiguous, [&](auto contiguous_tag) { @@ -1170,6 +1256,11 @@ void gpu_radix_partition_small( (2 + 3 * NUM_WARPS + 6) * sizeof(int); // shared_count + scatter scratch + CHECK_CUDA_ERROR(cudaFuncSetAttribute( + kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + static_cast(shared_mem_bytes))); + encoder.add_kernel_node( kernel, grid, @@ -1219,24 +1310,6 @@ struct RadixLaunchPlan { } }; -template -void dispatch_radix_items_per_thread(int items_per_thread, F&& f) { - switch (items_per_thread) { - case 1: - f(std::integral_constant{}); - break; - case 2: - f(std::integral_constant{}); - break; - case 4: - f(std::integral_constant{}); - break; - default: - f(std::integral_constant{}); - break; - } -} - RadixLaunchPlan make_radix_tiled_launch_plan( const Stream& s, int n_rows, @@ -1577,8 +1650,8 @@ void gpu_radix_partition( return gpu_merge_sort(s, in, out, axis, arg_partition); } - // Dispatch based on size - if (size_sorted_axis <= 2048) { + // Dispatch based on whether the small kernel tile fits in shared memory. 
+ if (radix_small_fits_shared_memory(s, in.dtype(), size_sorted_axis)) { return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); } else { return gpu_radix_partition_large(s, in, out, axis, kth, arg_partition); From 56d484448484be62f2c46c1d5f9dcc2d11a55dfc Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 2 Mar 2026 20:31:47 +0800 Subject: [PATCH 17/23] remove large kernel --- benchmarks/python/radix_select_bench.py | 4 +- mlx/backend/cuda/device/radix_select.cuh | 737 +---------------------- mlx/backend/cuda/sort.cu | 353 +---------- 3 files changed, 32 insertions(+), 1062 deletions(-) diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py index 49decd89b8..1fcc3b436d 100644 --- a/benchmarks/python/radix_select_bench.py +++ b/benchmarks/python/radix_select_bench.py @@ -31,7 +31,7 @@ # Benchmark-side model for cross-GPU small-kernel dispatch policy. RADIX_ITEMS_BUCKETS = (1, 2, 4, 8, 12, 16, 24, 32, 48, 64) MAX_RADIX_ITEMS_PER_THREAD = 64 -SMALL_RADIX_SIZE = 32 +RADIX_SIZE = 32 WARP_SIZE = 32 CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK = 8 CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97 @@ -128,7 +128,7 @@ def _radix_small_shared_mem_bytes(dtype_size, block_threads, items_per_thread): return ( tile_size * dtype_size + tile_size * 4 - + SMALL_RADIX_SIZE * 4 + + RADIX_SIZE * 4 + (2 + 3 * num_warps + 6) * 4 ) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index a0ec0047c4..9584aafaca 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -4,7 +4,6 @@ #include #include -#include #include #include "mlx/backend/cuda/device/utils.cuh" @@ -17,11 +16,9 @@ namespace mlx::core::cu { // Uses IEEE 754 bit manipulation for correct floating-point ordering. 
/////////////////////////////////////////////////////////////////////////////// -// Radix configuration -constexpr int RADIX_BITS = 8; -constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 256 bins -constexpr int SMALL_RADIX_BITS = 5; -constexpr int SMALL_RADIX_SIZE = 1 << SMALL_RADIX_BITS; // 32 bins +// Radix configuration used by the small shared-memory kernel. +constexpr int RADIX_BITS = 5; +constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 32 bins /////////////////////////////////////////////////////////////////////////////// // Bit manipulation for radix sorting @@ -43,11 +40,6 @@ struct RadixTraits { uint32_t mask = -int32_t(bits >> 31) | 0x80000000u; return bits ^ mask; } - - __device__ __forceinline__ static float from_radix(UnsignedT bits) { - uint32_t mask = ((bits >> 31) - 1) | 0x80000000u; - return __uint_as_float(bits ^ mask); - } }; template <> @@ -63,11 +55,6 @@ struct RadixTraits { uint64_t mask = -int64_t(bits >> 63) | 0x8000000000000000ull; return bits ^ mask; } - - __device__ __forceinline__ static double from_radix(UnsignedT bits) { - uint64_t mask = ((bits >> 63) - 1) | 0x8000000000000000ull; - return __longlong_as_double(bits ^ mask); - } }; template <> @@ -83,11 +70,6 @@ struct RadixTraits<__half> { uint16_t mask = -int16_t(bits >> 15) | 0x8000u; return bits ^ mask; } - - __device__ __forceinline__ static __half from_radix(UnsignedT bits) { - uint16_t mask = ((bits >> 15) - 1) | 0x8000u; - return __ushort_as_half(bits ^ mask); - } }; template <> @@ -103,11 +85,6 @@ struct RadixTraits<__nv_bfloat16> { uint16_t mask = -int16_t(bits >> 15) | 0x8000u; return bits ^ mask; } - - __device__ __forceinline__ static __nv_bfloat16 from_radix(UnsignedT bits) { - uint16_t mask = ((bits >> 15) - 1) | 0x8000u; - return __ushort_as_bfloat16(bits ^ mask); - } }; template <> @@ -118,10 +95,6 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(int8_t val) { return static_cast(val) ^ 0x80u; } - - __device__ __forceinline__ static int8_t 
from_radix(UnsignedT bits) { - return static_cast(bits ^ 0x80u); - } }; template <> @@ -132,10 +105,6 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(int16_t val) { return static_cast(val) ^ 0x8000u; } - - __device__ __forceinline__ static int16_t from_radix(UnsignedT bits) { - return static_cast(bits ^ 0x8000u); - } }; template <> @@ -146,10 +115,6 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(int32_t val) { return static_cast(val) ^ 0x80000000u; } - - __device__ __forceinline__ static int32_t from_radix(UnsignedT bits) { - return static_cast(bits ^ 0x80000000u); - } }; template <> @@ -160,10 +125,6 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(int64_t val) { return static_cast(val) ^ 0x8000000000000000ull; } - - __device__ __forceinline__ static int64_t from_radix(UnsignedT bits) { - return static_cast(bits ^ 0x8000000000000000ull); - } }; template <> @@ -174,10 +135,6 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(bool val) { return static_cast(val); } - - __device__ __forceinline__ static bool from_radix(UnsignedT bits) { - return bits != 0; - } }; template <> @@ -188,10 +145,6 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(uint8_t val) { return val; } - - __device__ __forceinline__ static uint8_t from_radix(UnsignedT bits) { - return bits; - } }; template <> @@ -202,10 +155,6 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(uint16_t val) { return val; } - - __device__ __forceinline__ static uint16_t from_radix(UnsignedT bits) { - return bits; - } }; template <> @@ -216,10 +165,6 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT to_radix(uint32_t val) { return val; } - - __device__ __forceinline__ static uint32_t from_radix(UnsignedT bits) { - return bits; - } }; template <> @@ -230,54 +175,12 @@ struct RadixTraits { __device__ __forceinline__ static UnsignedT 
to_radix(uint64_t val) { return val; } - - __device__ __forceinline__ static uint64_t from_radix(UnsignedT bits) { - return bits; - } }; -template -__device__ __forceinline__ int -extract_digit(UnsignedT val, int start_bit, int num_bits) { - return (val >> start_bit) & ((1 << num_bits) - 1); -} - -template -__device__ __forceinline__ bool is_nan_value(T val) { - if constexpr (cuda::std::is_floating_point_v) { - return cuda::std::isnan(val); - } else if constexpr (cuda::std::is_same_v) { - return __hisnan(val); - } else if constexpr (cuda::std::is_same_v) { - return __hisnan(val); - } else { - return false; - } -} - -template -__device__ __forceinline__ typename RadixTraits::UnsignedT -radix_key_with_nan_last(ValT val) { - using UnsignedT = typename RadixTraits::UnsignedT; - UnsignedT key = RadixTraits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - return key; -} - /////////////////////////////////////////////////////////////////////////////// // Warp-level utilities /////////////////////////////////////////////////////////////////////////////// -template -__device__ __forceinline__ T warp_reduce_sum(T val) { - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down_sync(0xFFFFFFFF, val, offset); - } - return val; -} - template __device__ __forceinline__ int block_exclusive_scan( int val, @@ -326,7 +229,7 @@ __device__ __forceinline__ int block_exclusive_scan( } /////////////////////////////////////////////////////////////////////////////// -// Single-pass Radix Select for small arrays (fits in shared memory) +// Radix Select for small arrays (fits in shared memory) /////////////////////////////////////////////////////////////////////////////// template < @@ -353,8 +256,7 @@ __global__ void radix_select_small_kernel( using UnsignedT = typename Traits::UnsignedT; constexpr int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; - constexpr int NUM_PASSES = - (Traits::BITS + SMALL_RADIX_BITS - 1) / SMALL_RADIX_BITS; + constexpr int 
NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; // Dynamic shared memory layout extern __shared__ char shared_mem[]; @@ -363,7 +265,7 @@ __global__ void radix_select_small_kernel( UnsignedT* shared_keys = reinterpret_cast(shared_mem); uint32_t* shared_idxs = reinterpret_cast(shared_keys + TILE_SIZE); int* shared_hist = reinterpret_cast(shared_idxs + TILE_SIZE); - int* shared_count = shared_hist + SMALL_RADIX_SIZE; + int* shared_count = shared_hist + RADIX_SIZE; constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; int* scatter_scratch = shared_count + 2; int* warp_less = scatter_scratch; @@ -396,8 +298,20 @@ __global__ void radix_select_small_kernel( if (i < tile_n) { ValT val = row_input[i * in_stride]; UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); + if constexpr (cuda::std::is_floating_point_v) { + if (cuda::std::isnan(val)) { + key = ~UnsignedT(0); + } + } else if constexpr (cuda::std::is_same_v) { + if (__hisnan(val)) { + key = ~UnsignedT(0); + } + } else if constexpr (cuda::std::is_same_v) { + if (__hisnan(val)) { + key = ~UnsignedT(0); + } + } else { + // Non-floating types cannot produce NaN keys. 
} shared_keys[i] = key; shared_idxs[i] = i; @@ -414,10 +328,10 @@ __global__ void radix_select_small_kernel( UnsignedT prefix_mask = 0; for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { - int start_bit = pass * SMALL_RADIX_BITS; + int start_bit = pass * RADIX_BITS; // Clear histogram - for (int i = threadIdx.x; i < SMALL_RADIX_SIZE; i += BLOCK_THREADS) { + for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { shared_hist[i] = 0; } __syncthreads(); @@ -426,7 +340,7 @@ __global__ void radix_select_small_kernel( for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { UnsignedT key = shared_keys[i]; if ((key & prefix_mask) == target_prefix) { - int digit = extract_digit(key, start_bit, SMALL_RADIX_BITS); + int digit = (key >> start_bit) & ((1 << RADIX_BITS) - 1); atomicAdd(&shared_hist[digit], 1); } } @@ -436,7 +350,7 @@ __global__ void radix_select_small_kernel( if (threadIdx.x == 0) { int cumsum = 0; int target_bin = 0; - for (int bin = 0; bin < SMALL_RADIX_SIZE; bin++) { + for (int bin = 0; bin < RADIX_SIZE; bin++) { int count = shared_hist[bin]; if (cumsum + count >= k) { target_bin = bin; @@ -453,7 +367,7 @@ __global__ void radix_select_small_kernel( int target_bin = shared_count[0]; k = shared_count[1]; - UnsignedT digit_mask = UnsignedT((1 << SMALL_RADIX_BITS) - 1) << start_bit; + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; target_prefix |= UnsignedT(target_bin) << start_bit; prefix_mask |= digit_mask; @@ -575,605 +489,4 @@ __global__ void radix_select_small_kernel( } } -/////////////////////////////////////////////////////////////////////////////// -// Large array streaming kernel (multi-pass, in-place) -/////////////////////////////////////////////////////////////////////////////// - -template -__global__ void radix_select_large_streaming_kernel( - const ValT* input, - OutT* output, - int n, - int kth, - int64_t in_stride, - int64_t out_stride, - const __grid_constant__ Shape nc_shape, - const __grid_constant__ 
Strides in_nc_strides, - const __grid_constant__ Strides out_nc_strides, - int nc_dim) { - using Traits = RadixTraits; - using UnsignedT = typename Traits::UnsignedT; - constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; - - int row = blockIdx.y; - int64_t in_block_idx = - elem_to_loc(int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim); - int64_t out_block_idx = - elem_to_loc(int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim); - const ValT* row_input = input + in_block_idx; - OutT* row_output = output + out_block_idx; - - // Shared memory - __shared__ int shared_hist[RADIX_SIZE]; - __shared__ int shared_pivot_info[2]; - __shared__ int shared_counts[2]; - - int k = kth + 1; - UnsignedT target_prefix = 0; - UnsignedT prefix_mask = 0; - - // Multi-pass to find pivot - for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { - int start_bit = pass * RADIX_BITS; - - // Clear histogram - for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { - shared_hist[i] = 0; - } - __syncthreads(); - - // Build histogram - bool is_contiguous = (in_stride == 1); - if (is_contiguous) { - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - - if ((key & prefix_mask) == target_prefix) { - int digit = extract_digit(key, start_bit, RADIX_BITS); - atomicAdd(&shared_hist[digit], 1); - } - } - } else { - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - - if ((key & prefix_mask) == target_prefix) { - int digit = extract_digit(key, start_bit, RADIX_BITS); - atomicAdd(&shared_hist[digit], 1); - } - } - } - __syncthreads(); - - // Find target bin - if (threadIdx.x == 0) { - int cumsum = 0; - int target_bin = 0; - for (int bin = 0; bin < RADIX_SIZE; bin++) { - int count = shared_hist[bin]; 
- if (cumsum + count >= k) { - target_bin = bin; - k = k - cumsum; - break; - } - cumsum += count; - } - shared_pivot_info[0] = target_bin; - shared_pivot_info[1] = k; - } - __syncthreads(); - - int target_bin = shared_pivot_info[0]; - k = shared_pivot_info[1]; - - UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; - target_prefix |= UnsignedT(target_bin) << start_bit; - prefix_mask |= digit_mask; - } - - // Count partition sizes. - int local_less = 0, local_equal = 0; - bool is_contiguous = (in_stride == 1); - - if (is_contiguous) { - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - if (key < target_prefix) - local_less++; - else if (key == target_prefix) - local_equal++; - } - } else { - for (int i = threadIdx.x; i < n; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - if (key < target_prefix) - local_less++; - else if (key == target_prefix) - local_equal++; - } - } - - local_less = warp_reduce_sum(local_less); - local_equal = warp_reduce_sum(local_equal); - - if (threadIdx.x == 0) { - shared_counts[0] = 0; - shared_counts[1] = 0; - } - __syncthreads(); - - if ((threadIdx.x & (WARP_SIZE - 1)) == 0) { - atomicAdd(&shared_counts[0], local_less); - atomicAdd(&shared_counts[1], local_equal); - } - __syncthreads(); - - int less_count = shared_counts[0]; - int equal_count = shared_counts[1]; - - // Deterministic scatter in iteration order (0..n): this keeps output stable - // without thread-contention atomics in the hot scatter path. 
- constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - int lane = threadIdx.x & (WARP_SIZE - 1); - int warp = threadIdx.x / WARP_SIZE; - - int* warp_less = shared_hist; - int* warp_equal = shared_hist + NUM_WARPS; - int* warp_greater = shared_hist + 2 * NUM_WARPS; - int* iter_counts = shared_hist + 3 * NUM_WARPS; - int* running_bases = iter_counts + 3; - - if (threadIdx.x == 0) { - running_bases[0] = 0; - running_bases[1] = less_count; - running_bases[2] = less_count + equal_count; - } - __syncthreads(); - - for (int base_i = 0; base_i < n; base_i += BLOCK_THREADS) { - int i = base_i + threadIdx.x; - bool active = i < n; - - ValT val{}; - UnsignedT key = 0; - if (active) { - val = is_contiguous ? row_input[i] : row_input[i * in_stride]; - key = Traits::to_radix(val); - if (is_nan_value(val)) { - key = ~UnsignedT(0); - } - } - - bool is_less = active && (key < target_prefix); - bool is_equal = active && (key == target_prefix); - bool is_greater = active && !is_less && !is_equal; - - unsigned less_mask = __ballot_sync(0xFFFFFFFF, is_less); - unsigned equal_mask = __ballot_sync(0xFFFFFFFF, is_equal); - unsigned greater_mask = __ballot_sync(0xFFFFFFFF, is_greater); - - unsigned lane_mask = (1u << lane) - 1u; - int less_rank = __popc(less_mask & lane_mask); - int equal_rank = __popc(equal_mask & lane_mask); - int greater_rank = __popc(greater_mask & lane_mask); - - if (lane == 0) { - warp_less[warp] = __popc(less_mask); - warp_equal[warp] = __popc(equal_mask); - warp_greater[warp] = __popc(greater_mask); - } - __syncthreads(); - - if (threadIdx.x == 0) { - int run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_less[w]; - warp_less[w] = run; - run += c; - } - iter_counts[0] = run; - - run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_equal[w]; - warp_equal[w] = run; - run += c; - } - iter_counts[1] = run; - - run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_greater[w]; - warp_greater[w] = run; - run += c; - } - iter_counts[2] = 
run; - } - __syncthreads(); - - if (active) { - int pos; - if (is_less) { - pos = running_bases[0] + warp_less[warp] + less_rank; - } else if (is_equal) { - pos = running_bases[1] + warp_equal[warp] + equal_rank; - } else { - pos = running_bases[2] + warp_greater[warp] + greater_rank; - } - - if (ARG_PARTITION) { - if (out_stride == 1) { - row_output[pos] = i; - } else { - row_output[pos * out_stride] = i; - } - } else { - if (out_stride == 1) { - row_output[pos] = val; - } else { - row_output[pos * out_stride] = val; - } - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - running_bases[0] += iter_counts[0]; - running_bases[1] += iter_counts[1]; - running_bases[2] += iter_counts[2]; - } - __syncthreads(); - } -} - -/////////////////////////////////////////////////////////////////////////////// -// Tiled large-array kernels -// -// These kernels run with a 2D launch: -// - x-dimension tiles one row across multiple blocks (multi-block-per-row) -// - y-dimension selects the row index -/////////////////////////////////////////////////////////////////////////////// - -template -__global__ void radix_select_tiled_init_state_kernel( - UnsignedT* target_prefix, - UnsignedT* prefix_mask, - int* k_values, - int* row_hist, - int kth) { - int row = blockIdx.y; - if (threadIdx.x == 0) { - target_prefix[row] = UnsignedT(0); - prefix_mask[row] = UnsignedT(0); - k_values[row] = kth + 1; - } - int* hist = row_hist + row * RADIX_SIZE; - for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) { - hist[i] = 0; - } -} - -template -__global__ void radix_select_tiled_histogram_kernel( - const ValT* input, - int n, - int64_t in_stride, - int64_t in_segment_stride, - const typename RadixTraits::UnsignedT* target_prefix, - const typename RadixTraits::UnsignedT* prefix_mask, - int start_bit, - int blocks_per_row, - int* row_hist) { - using Traits = RadixTraits; - using UnsignedT = typename Traits::UnsignedT; - - int block_in_row = blockIdx.x; - int row = blockIdx.y; - - int chunk = 
(n + blocks_per_row - 1) / blocks_per_row; - int start = block_in_row * chunk; - int end = min(n, start + chunk); - if (start >= n) { - return; - } - - __shared__ int shared_hist[RADIX_SIZE]; - const ValT* row_input = input + row * in_segment_stride; - UnsignedT row_prefix = target_prefix[row]; - UnsignedT row_mask = prefix_mask[row]; - - for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { - shared_hist[i] = 0; - } - __syncthreads(); - - for (int i = start + threadIdx.x; i < end; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = radix_key_with_nan_last(val); - if ((key & row_mask) == row_prefix) { - int digit = extract_digit(key, start_bit, RADIX_BITS); - atomicAdd(&shared_hist[digit], 1); - } - } - __syncthreads(); - - int* hist = row_hist + row * RADIX_SIZE; - for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { - atomicAdd(&hist[i], shared_hist[i]); - } -} - -template -__global__ void radix_select_tiled_select_bin_kernel( - int* row_hist, - UnsignedT* target_prefix, - UnsignedT* prefix_mask, - int* k_values, - int clear_hist_for_next_pass, - int start_bit) { - int row = blockIdx.y; - int* hist = row_hist + row * RADIX_SIZE; - - if (threadIdx.x == 0) { - int k = k_values[row]; - int cumsum = 0; - int target_bin = 0; - for (int bin = 0; bin < RADIX_SIZE; bin++) { - int count = hist[bin]; - if (cumsum + count >= k) { - target_bin = bin; - k -= cumsum; - break; - } - cumsum += count; - } - k_values[row] = k; - - UnsignedT digit_mask = - (UnsignedT((UnsignedT(1) << RADIX_BITS) - UnsignedT(1)) << start_bit); - target_prefix[row] |= UnsignedT(target_bin) << start_bit; - prefix_mask[row] |= digit_mask; - } - __syncthreads(); - - if (clear_hist_for_next_pass) { - for (int i = threadIdx.x; i < RADIX_SIZE; i += blockDim.x) { - hist[i] = 0; - } - } -} - -template -__global__ void radix_select_tiled_count_kernel( - const ValT* input, - int n, - int64_t in_stride, - int64_t in_segment_stride, - const typename 
RadixTraits::UnsignedT* target_prefix, - int blocks_per_row, - int* block_less, - int* block_equal) { - using Traits = RadixTraits; - using UnsignedT = typename Traits::UnsignedT; - - int block_in_row = blockIdx.x; - int row = blockIdx.y; - - int chunk = (n + blocks_per_row - 1) / blocks_per_row; - int start = block_in_row * chunk; - int end = min(n, start + chunk); - if (start >= n) { - return; - } - - __shared__ int shared_counts[2]; - int block_idx = row * blocks_per_row + block_in_row; - const ValT* row_input = input + row * in_segment_stride; - UnsignedT row_prefix = target_prefix[row]; - - int local_less = 0; - int local_equal = 0; - for (int i = start + threadIdx.x; i < end; i += BLOCK_THREADS) { - ValT val = row_input[i * in_stride]; - UnsignedT key = radix_key_with_nan_last(val); - if (key < row_prefix) { - local_less++; - } else if (key == row_prefix) { - local_equal++; - } - } - - local_less = warp_reduce_sum(local_less); - local_equal = warp_reduce_sum(local_equal); - - if (threadIdx.x == 0) { - shared_counts[0] = 0; - shared_counts[1] = 0; - } - __syncthreads(); - - if ((threadIdx.x % WARP_SIZE) == 0) { - atomicAdd(&shared_counts[0], local_less); - atomicAdd(&shared_counts[1], local_equal); - } - __syncthreads(); - - if (threadIdx.x == 0) { - block_less[block_idx] = shared_counts[0]; - block_equal[block_idx] = shared_counts[1]; - } -} - -template -__global__ void radix_select_tiled_scatter_kernel( - const ValT* input, - OutT* output, - int n, - int64_t in_stride, - int64_t out_stride, - int64_t in_segment_stride, - int64_t out_segment_stride, - const typename RadixTraits::UnsignedT* target_prefix, - int blocks_per_row, - const int* block_less, - const int* block_equal) { - using Traits = RadixTraits; - using UnsignedT = typename Traits::UnsignedT; - - int block_in_row = blockIdx.x; - int row = blockIdx.y; - - int chunk = (n + blocks_per_row - 1) / blocks_per_row; - int start = block_in_row * chunk; - int end = min(n, start + chunk); - if (start >= n) { 
- return; - } - - constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - __shared__ int shared_warp_offsets[3 * NUM_WARPS]; - __shared__ int shared_iter_counts[3]; - __shared__ int shared_running_bases[3]; - - int lane = threadIdx.x & (WARP_SIZE - 1); - int warp = threadIdx.x / WARP_SIZE; - - int* warp_less = shared_warp_offsets; - int* warp_equal = shared_warp_offsets + NUM_WARPS; - int* warp_greater = shared_warp_offsets + 2 * NUM_WARPS; - - int row_off = row * blocks_per_row; - const ValT* row_input = input + row * in_segment_stride; - OutT* row_output = output + row * out_segment_stride; - UnsignedT row_prefix = target_prefix[row]; - - if (threadIdx.x == 0) { - int total_less = 0; - int total_equal = 0; - for (int b = 0; b < blocks_per_row; ++b) { - int idx = row_off + b; - total_less += block_less[idx]; - total_equal += block_equal[idx]; - } - - int run_less = 0; - int run_equal = 0; - int run_greater = 0; - for (int b = 0; b < block_in_row; ++b) { - int idx = row_off + b; - int b_start = b * chunk; - int b_end = min(n, b_start + chunk); - int b_chunk_size = max(0, b_end - b_start); - int greater_count = b_chunk_size - block_less[idx] - block_equal[idx]; - run_less += block_less[idx]; - run_equal += block_equal[idx]; - run_greater += greater_count; - } - - shared_running_bases[0] = run_less; - shared_running_bases[1] = total_less + run_equal; - shared_running_bases[2] = total_less + total_equal + run_greater; - } - __syncthreads(); - - for (int base_i = start; base_i < end; base_i += BLOCK_THREADS) { - int i = base_i + threadIdx.x; - bool active = i < end; - - ValT val{}; - UnsignedT key = 0; - if (active) { - val = row_input[i * in_stride]; - key = radix_key_with_nan_last(val); - } - - bool is_less = active && (key < row_prefix); - bool is_equal = active && (key == row_prefix); - bool is_greater = active && !is_less && !is_equal; - - unsigned less_mask = __ballot_sync(0xFFFFFFFF, is_less); - unsigned equal_mask = __ballot_sync(0xFFFFFFFF, is_equal); - 
unsigned greater_mask = __ballot_sync(0xFFFFFFFF, is_greater); - - unsigned lane_mask = (1u << lane) - 1u; - int less_rank = __popc(less_mask & lane_mask); - int equal_rank = __popc(equal_mask & lane_mask); - int greater_rank = __popc(greater_mask & lane_mask); - - if (lane == 0) { - warp_less[warp] = __popc(less_mask); - warp_equal[warp] = __popc(equal_mask); - warp_greater[warp] = __popc(greater_mask); - } - __syncthreads(); - - if (threadIdx.x == 0) { - int run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_less[w]; - warp_less[w] = run; - run += c; - } - shared_iter_counts[0] = run; - - run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_equal[w]; - warp_equal[w] = run; - run += c; - } - shared_iter_counts[1] = run; - - run = 0; - for (int w = 0; w < NUM_WARPS; ++w) { - int c = warp_greater[w]; - warp_greater[w] = run; - run += c; - } - shared_iter_counts[2] = run; - } - __syncthreads(); - - if (active) { - int pos; - if (is_less) { - pos = shared_running_bases[0] + warp_less[warp] + less_rank; - } else if (is_equal) { - pos = shared_running_bases[1] + warp_equal[warp] + equal_rank; - } else { - pos = shared_running_bases[2] + warp_greater[warp] + greater_rank; - } - if (ARG_PARTITION) { - row_output[pos * out_stride] = static_cast(i); - } else { - row_output[pos * out_stride] = static_cast(val); - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - shared_running_bases[0] += shared_iter_counts[0]; - shared_running_bases[1] += shared_iter_counts[1]; - shared_running_bases[2] += shared_iter_counts[2]; - } - __syncthreads(); - } -} - } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 7f5284dedd..9daef01865 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1102,7 +1102,7 @@ size_t radix_small_shared_mem_bytes( size_t num_warps = static_cast(block_threads / WARP_SIZE); return tile_size * key_size + // shared_keys tile_size * sizeof(uint32_t) + // shared_idxs - 
cu::SMALL_RADIX_SIZE * sizeof(int) + // shared_hist for small kernel + cu::RADIX_SIZE * sizeof(int) + // shared_hist for small kernel (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch } @@ -1251,7 +1251,7 @@ void gpu_radix_partition_small( constexpr size_t shared_mem_bytes = TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs - cu::SMALL_RADIX_SIZE * + cu::RADIX_SIZE * sizeof(int) + // shared_hist for small kernel (2 + 3 * NUM_WARPS + 6) * sizeof(int); // shared_count + scatter scratch @@ -1261,10 +1261,11 @@ void gpu_radix_partition_small( cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_mem_bytes))); - encoder.add_kernel_node( + encoder.add_kernel_node_ex( kernel, grid, block, + {}, shared_mem_bytes, gpu_ptr(in), gpu_ptr(out), @@ -1289,350 +1290,6 @@ void gpu_radix_partition_small( }); } -int64_t segment_stride_for_contiguous( - const Shape& shape_no_axis, - const Strides& strides_no_axis) { - int64_t stride = INT64_MAX; - for (size_t i = 0; i < shape_no_axis.size(); ++i) { - if (shape_no_axis[i] == 1) { - continue; - } - stride = std::min(stride, strides_no_axis[i]); - } - return (stride == INT64_MAX) ? 
int64_t(0) : stride; -} - -struct RadixLaunchPlan { - int blocks_per_row{1}; - - bool uses_tiled_launch() const { - return blocks_per_row > 1; - } -}; - -RadixLaunchPlan make_radix_tiled_launch_plan( - const Stream& s, - int n_rows, - int size_sorted_axis) { - if (n_rows <= 0 || size_sorted_axis <= 0 || size_sorted_axis < 8192) { - return {}; - } - - constexpr int kBlocksPerSmTarget = 4; - constexpr int kMinElemsPerBlock = 1024; - constexpr int kMaxBlocksPerRow = 32; - - int sm_count = 0; - CHECK_CUDA_ERROR(cudaDeviceGetAttribute( - &sm_count, cudaDevAttrMultiProcessorCount, s.device.index)); - sm_count = std::max(sm_count, 1); - - int target_blocks = std::max(1, sm_count * kBlocksPerSmTarget); - int needed_blocks_per_row = - std::max(1, (target_blocks + n_rows - 1) / n_rows); - - int max_blocks_by_work = std::max( - 1, std::min(kMaxBlocksPerRow, size_sorted_axis / kMinElemsPerBlock)); - int blocks_per_row = std::min(needed_blocks_per_row, max_blocks_by_work); - - if (blocks_per_row <= 1) { - return {}; - } - return {blocks_per_row}; -} - -Dtype unsigned_dtype_for_size(int size) { - switch (size) { - case 1: - return uint8; - case 2: - return uint16; - case 4: - return uint32; - case 8: - return uint64; - default: - throw std::runtime_error("Unsupported radix key size"); - } -} - -void gpu_radix_partition_large_tiled( - const Stream& s, - const array& in, - array& out, - int kth, - int n_rows, - int size_sorted_axis, - int64_t in_stride_sorted_axis, - int64_t out_stride_sorted_axis, - int64_t in_stride_segment_axis, - int64_t out_stride_segment_axis, - int blocks_per_row, - bool arg_partition) { - auto& encoder = cu::get_command_encoder(s); - out.set_data(cu::malloc_async(out.nbytes(), encoder)); - encoder.set_input_array(in); - encoder.set_output_array(out); - - dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using ValT = cuda_type_t; - constexpr int BLOCK_THREADS = 256; - constexpr 
int NUM_PASSES = - (cu::RadixTraits::BITS + cu::RADIX_BITS - 1) / cu::RADIX_BITS; - using UnsignedT = typename cu::RadixTraits::UnsignedT; - - Dtype unsigned_dtype = unsigned_dtype_for_size(sizeof(UnsignedT)); - - array target_prefix_dev({n_rows}, unsigned_dtype, nullptr, {}); - array prefix_mask_dev({n_rows}, unsigned_dtype, nullptr, {}); - array k_values_dev({n_rows}, int32, nullptr, {}); - array row_hist_dev({n_rows, cu::RADIX_SIZE}, int32, nullptr, {}); - - int total_blocks = n_rows * blocks_per_row; - array block_less_dev({total_blocks}, int32, nullptr, {}); - array block_equal_dev({total_blocks}, int32, nullptr, {}); - - auto allocate_temporary = [&](array& a) { - a.set_data(cu::malloc_async(a.nbytes(), encoder)); - encoder.add_temporary(a); - }; - allocate_temporary(target_prefix_dev); - allocate_temporary(prefix_mask_dev); - allocate_temporary(k_values_dev); - allocate_temporary(row_hist_dev); - allocate_temporary(block_less_dev); - allocate_temporary(block_equal_dev); - - dim3 row_grid(1, n_rows, 1); - dim3 grid(blocks_per_row, n_rows, 1); - - encoder.set_output_array(target_prefix_dev); - encoder.set_output_array(prefix_mask_dev); - encoder.set_output_array(k_values_dev); - encoder.set_output_array(row_hist_dev); - encoder.add_kernel_node( - cu::radix_select_tiled_init_state_kernel, - row_grid, - dim3(32, 1, 1), - 0, - gpu_ptr(target_prefix_dev), - gpu_ptr(prefix_mask_dev), - gpu_ptr(k_values_dev), - gpu_ptr(row_hist_dev), - kth); - - for (int pass = NUM_PASSES - 1; pass >= 0; --pass) { - int start_bit = pass * cu::RADIX_BITS; - - encoder.set_input_array(in); - encoder.set_input_array(row_hist_dev); - encoder.set_input_array(target_prefix_dev); - encoder.set_input_array(prefix_mask_dev); - encoder.set_output_array(row_hist_dev); - encoder.add_kernel_node( - cu::radix_select_tiled_histogram_kernel, - grid, - dim3(BLOCK_THREADS, 1, 1), - 0, - gpu_ptr(in), - size_sorted_axis, - in_stride_sorted_axis, - in_stride_segment_axis, - gpu_ptr(target_prefix_dev), - 
gpu_ptr(prefix_mask_dev), - start_bit, - blocks_per_row, - gpu_ptr(row_hist_dev)); - - encoder.set_input_array(row_hist_dev); - encoder.set_input_array(target_prefix_dev); - encoder.set_input_array(prefix_mask_dev); - encoder.set_input_array(k_values_dev); - encoder.set_input_array(row_hist_dev); - encoder.set_output_array(target_prefix_dev); - encoder.set_output_array(prefix_mask_dev); - encoder.set_output_array(k_values_dev); - encoder.set_output_array(row_hist_dev); - encoder.add_kernel_node( - cu::radix_select_tiled_select_bin_kernel, - row_grid, - dim3(32, 1, 1), - 0, - gpu_ptr(row_hist_dev), - gpu_ptr(target_prefix_dev), - gpu_ptr(prefix_mask_dev), - gpu_ptr(k_values_dev), - pass > 0 ? 1 : 0, - start_bit); - } - - encoder.set_input_array(in); - encoder.set_input_array(target_prefix_dev); - encoder.set_output_array(block_less_dev); - encoder.set_output_array(block_equal_dev); - encoder.add_kernel_node( - cu::radix_select_tiled_count_kernel, - grid, - dim3(BLOCK_THREADS, 1, 1), - 0, - gpu_ptr(in), - size_sorted_axis, - in_stride_sorted_axis, - in_stride_segment_axis, - gpu_ptr(target_prefix_dev), - blocks_per_row, - gpu_ptr(block_less_dev), - gpu_ptr(block_equal_dev)); - - dispatch_bool(arg_partition, [&](auto arg_tag) { - constexpr bool ARG_PARTITION = decltype(arg_tag)::value; - using OutT = std::conditional_t; - encoder.set_input_array(in); - encoder.set_input_array(target_prefix_dev); - encoder.set_input_array(block_less_dev); - encoder.set_input_array(block_equal_dev); - encoder.set_output_array(out); - encoder.add_kernel_node( - cu::radix_select_tiled_scatter_kernel< - ValT, - OutT, - ARG_PARTITION, - BLOCK_THREADS>, - grid, - dim3(BLOCK_THREADS, 1, 1), - 0, - gpu_ptr(in), - gpu_ptr(out), - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - gpu_ptr(target_prefix_dev), - blocks_per_row, - gpu_ptr(block_less_dev), - gpu_ptr(block_equal_dev)); - }); - } else { - throw 
std::runtime_error( - "CUDA backend does not support sorting complex numbers"); - } - }); -} - -void gpu_radix_partition_large( - const Stream& s, - const array& in, - array& out, - int axis, - int kth, - bool arg_partition) { - int n_rows = in.size() / in.shape(axis); - - int size_sorted_axis = in.shape(axis); - int64_t in_stride_sorted_axis = in.strides()[axis]; - int64_t out_stride_sorted_axis = out.strides()[axis]; - - auto in_nc_str = in.strides(); - in_nc_str.erase(in_nc_str.begin() + axis); - - auto out_nc_str = out.strides(); - out_nc_str.erase(out_nc_str.begin() + axis); - - auto nc_shape = in.shape(); - nc_shape.erase(nc_shape.begin() + axis); - - int nc_dim = nc_shape.size(); - - bool contiguous = in.flags().contiguous; - auto check_strides = [](const array& x, int64_t sort_stride) { - int64_t min_stride = - *std::min_element(x.strides().begin(), x.strides().end()); - int64_t max_stride = - *std::max_element(x.strides().begin(), x.strides().end()); - return sort_stride == min_stride || sort_stride == max_stride; - }; - contiguous &= check_strides(in, in_stride_sorted_axis); - contiguous &= check_strides(out, out_stride_sorted_axis); - - if (contiguous) { - const auto plan = make_radix_tiled_launch_plan(s, n_rows, size_sorted_axis); - if (plan.uses_tiled_launch()) { - const int64_t in_stride_segment_axis = - segment_stride_for_contiguous(nc_shape, in_nc_str); - const int64_t out_stride_segment_axis = - segment_stride_for_contiguous(nc_shape, out_nc_str); - return gpu_radix_partition_large_tiled( - s, - in, - out, - kth, - n_rows, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - plan.blocks_per_row, - arg_partition); - } - } - - auto& encoder = cu::get_command_encoder(s); - out.set_data(cu::malloc_async(out.nbytes(), encoder)); - encoder.set_input_array(in); - encoder.set_output_array(out); - - auto nc_shape_param = const_param(nc_shape); - auto in_nc_strides_param = 
const_param(in_nc_str); - auto out_nc_strides_param = const_param(out_nc_str); - - dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using ValT = cuda_type_t; - - constexpr int BLOCK_THREADS = 256; - - dim3 grid(1, n_rows, 1); - dim3 block(BLOCK_THREADS, 1, 1); - - dispatch_bool(arg_partition, [&](auto arg_tag) { - constexpr bool ARG_PARTITION = decltype(arg_tag)::value; - using OutT = std::conditional_t; - - // Large kernel always uses elem_to_loc addressing - auto kernel = cu::radix_select_large_streaming_kernel< - ValT, - OutT, - ARG_PARTITION, - BLOCK_THREADS>; - - encoder.add_kernel_node( - kernel, - grid, - block, - 0, - gpu_ptr(in), - gpu_ptr(out), - size_sorted_axis, - kth, - in_stride_sorted_axis, - out_stride_sorted_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, - nc_dim); - }); - } else { - throw std::runtime_error( - "CUDA backend does not support sorting complex numbers"); - } - }); -} - void gpu_radix_partition( const Stream& s, const array& in, @@ -1654,7 +1311,7 @@ void gpu_radix_partition( if (radix_small_fits_shared_memory(s, in.dtype(), size_sorted_axis)) { return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); } else { - return gpu_radix_partition_large(s, in, out, axis, kth, arg_partition); + return gpu_merge_sort(s, in, out, axis, arg_partition); } } From d924266e5ca15a5b6e810a25b3b3336834aa45f1 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Tue, 3 Mar 2026 00:22:11 +0800 Subject: [PATCH 18/23] remove runtime fit routing --- benchmarks/python/radix_select_bench.py | 61 ++----- mlx/backend/cuda/sort.cu | 202 ++++++++++-------------- 2 files changed, 97 insertions(+), 166 deletions(-) diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py index 1fcc3b436d..97a51f97d1 100644 --- a/benchmarks/python/radix_select_bench.py +++ b/benchmarks/python/radix_select_bench.py @@ -5,7 +5,6 @@ """ import 
argparse -import ctypes import time import mlx.core as mx @@ -33,8 +32,7 @@ MAX_RADIX_ITEMS_PER_THREAD = 64 RADIX_SIZE = 32 WARP_SIZE = 32 -CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK = 8 -CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97 +RADIX_SMALL_SHARED_MEM_BUDGET_BYTES = 48 * 1024 DTYPE_SIZE_BYTES = { "bool_": 1, @@ -58,51 +56,6 @@ def _dtype_size_bytes(dtype): return DTYPE_SIZE_BYTES[dtype_name] -def _cuda_max_shared_mem_per_block(default=48 * 1024): - """Query max(base, optin) shared memory per block; fallback to 48KB.""" - try: - cudart = ctypes.CDLL("libcudart.so") - - cuda_get_device = cudart.cudaGetDevice - cuda_get_device.argtypes = [ctypes.POINTER(ctypes.c_int)] - cuda_get_device.restype = ctypes.c_int - - cuda_device_get_attribute = cudart.cudaDeviceGetAttribute - cuda_device_get_attribute.argtypes = [ - ctypes.POINTER(ctypes.c_int), - ctypes.c_int, - ctypes.c_int, - ] - cuda_device_get_attribute.restype = ctypes.c_int - - dev = ctypes.c_int() - if cuda_get_device(ctypes.byref(dev)) != 0: - return default - - smem_base = ctypes.c_int() - if ( - cuda_device_get_attribute( - ctypes.byref(smem_base), - CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK, - dev.value, - ) - != 0 - ): - return default - - smem_optin = ctypes.c_int() - optin_rc = cuda_device_get_attribute( - ctypes.byref(smem_optin), - CUDA_DEV_ATTR_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, - dev.value, - ) - if optin_rc == 0: - return max(int(smem_base.value), int(smem_optin.value)) - return int(smem_base.value) - except Exception: - return default - - def _radix_small_block_threads(vocab_size): if vocab_size <= 128: return 16 @@ -134,9 +87,9 @@ def _radix_small_shared_mem_bytes(dtype_size, block_threads, items_per_thread): def estimate_small_kernel_limit(dtype): - """Estimate max small-kernel axis for dtype under current CUDA radix policy.""" + """Estimate max small-kernel axis for dtype under the fixed 48KB budget.""" dtype_size = _dtype_size_bytes(dtype) - smem_limit = _cuda_max_shared_mem_per_block() 
+ smem_limit = RADIX_SMALL_SHARED_MEM_BUDGET_BYTES max_axis = 0 # 256 is the largest block_threads in sort.cu launch selection. for v in range(1, 256 * MAX_RADIX_ITEMS_PER_THREAD + 1): @@ -341,7 +294,13 @@ def sweep_kernel( vocab_sizes = sorted({int(v) for v in candidate_vocab if v <= max_small_axis}) else: batch_sizes = [1, 2, 4, 8, 16, 32, 48, 64, 96, 128, 256, 512, 1024, 2048] - vocab_sizes = sorted({int(v) for v in candidate_vocab if v > small_kernel}) + vocab_sizes = sorted({int(v) for v in candidate_vocab if v > max_small_axis}) + + if not vocab_sizes: + print( + "No vocabulary sizes in sweep range for this dtype and shared-memory budget." + ) + return col_w = 10 print(f"{'':>8}", end="") diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 9daef01865..6512dfbf39 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1049,19 +1049,19 @@ void gpu_merge_sort( // Upper bound for small-kernel tiling. Keep this aligned with the // items-per-thread dispatch set and per-block shared-memory budget. 
constexpr int MAX_RADIX_ITEMS_PER_THREAD = 64; +constexpr size_t RADIX_SMALL_SHARED_MEM_BUDGET_BYTES = 48 * 1024; -int radix_small_block_threads(int size_sorted_axis) { - int block_threads = 256; - if (size_sorted_axis <= 128) { - block_threads = 16; - } else if (size_sorted_axis <= 256) { - block_threads = 32; +template +void dispatch_radix_small_block_threads(int size_sorted_axis, F&& f) { + if (size_sorted_axis <= 256) { + f(std::integral_constant{}); } else if (size_sorted_axis <= 512) { - block_threads = 64; + f(std::integral_constant{}); } else if (size_sorted_axis <= 1024) { - block_threads = 128; + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); } - return block_threads; } template @@ -1084,10 +1084,6 @@ void dispatch_radix_items_per_thread( f(std::integral_constant{}); } else if (items_per_thread <= 24) { f(std::integral_constant{}); - } else if (items_per_thread <= 32) { - f(std::integral_constant{}); - } else if (items_per_thread <= 48) { - f(std::integral_constant{}); } else { f(std::integral_constant{}); } @@ -1106,51 +1102,30 @@ size_t radix_small_shared_mem_bytes( (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch } -int radix_max_shared_mem_per_block(const Stream& s) { - int max_shared_mem_per_block = 0; - CHECK_CUDA_ERROR(cudaDeviceGetAttribute( - &max_shared_mem_per_block, - cudaDevAttrMaxSharedMemoryPerBlock, - s.device.index)); - - int max_shared_mem_per_block_optin = 0; - cudaError_t optin_err = cudaDeviceGetAttribute( - &max_shared_mem_per_block_optin, - cudaDevAttrMaxSharedMemoryPerBlockOptin, - s.device.index); - if (optin_err == cudaSuccess) { - max_shared_mem_per_block = - std::max(max_shared_mem_per_block, max_shared_mem_per_block_optin); - } else { - cudaGetLastError(); - } - return max_shared_mem_per_block; -} - -bool radix_small_fits_shared_memory( - const Stream& s, - Dtype dtype, - int size_sorted_axis) { +bool radix_small_fits_shared_memory(Dtype dtype, int size_sorted_axis) { if 
(size_sorted_axis <= 0) { return false; } - int block_threads = radix_small_block_threads(size_sorted_axis); - int required_items = (size_sorted_axis + block_threads - 1) / block_threads; - if (required_items > MAX_RADIX_ITEMS_PER_THREAD) { - return false; - } - size_t required_shared_mem = 0; - dispatch_radix_items_per_thread( - size_sorted_axis, block_threads, [&](auto items_per_thread_tag) { - constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); - required_shared_mem = radix_small_shared_mem_bytes( - size_of(dtype), block_threads, ITEMS_PER_THREAD); - }); + bool fits = false; + dispatch_radix_small_block_threads(size_sorted_axis, [&](auto block_dim_tag) { + constexpr int BLOCK_THREADS = block_dim_tag(); + int required_items = (size_sorted_axis + BLOCK_THREADS - 1) / BLOCK_THREADS; + if (required_items > MAX_RADIX_ITEMS_PER_THREAD) { + fits = false; + return; + } - int max_shared_mem_per_block = radix_max_shared_mem_per_block(s); - return required_shared_mem <= static_cast(max_shared_mem_per_block); + dispatch_radix_items_per_thread( + size_sorted_axis, BLOCK_THREADS, [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + required_shared_mem = radix_small_shared_mem_bytes( + size_of(dtype), BLOCK_THREADS, ITEMS_PER_THREAD); + fits = required_shared_mem <= RADIX_SMALL_SHARED_MEM_BUDGET_BYTES; + }); + }); + return fits; } void gpu_radix_partition_small( @@ -1202,8 +1177,6 @@ void gpu_radix_partition_small( if constexpr (!std::is_same_v) { using ValT = cuda_type_t; - int block_threads = radix_small_block_threads(size_sorted_axis); - dispatch_bool(arg_partition, [&](auto arg_tag) { constexpr bool ARG_PARTITION = decltype(arg_tag)::value; using OutT = std::conditional_t; @@ -1222,66 +1195,65 @@ void gpu_radix_partition_small( } } - dispatch_block_dim(block_threads, [&](auto block_dim) { - constexpr int BLOCK_THREADS = block_dim(); - dim3 grid(1, n_rows, 1); - dim3 block(BLOCK_THREADS, 1, 1); - - 
dispatch_radix_items_per_thread( - size_sorted_axis, block_threads, [&](auto items_per_thread_tag) { - constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); - - dispatch_bool(contiguous, [&](auto contiguous_tag) { - constexpr bool USE_SIMPLE_STRIDE = - decltype(contiguous_tag)::value; - - auto kernel = cu::radix_select_small_kernel< - ValT, - OutT, - ARG_PARTITION, - USE_SIMPLE_STRIDE, - BLOCK_THREADS, - ITEMS_PER_THREAD>; - - // Calculate dynamic shared memory size - using UnsignedT = typename cu::RadixTraits::UnsignedT; - constexpr int TILE_SIZE_VAL = - BLOCK_THREADS * ITEMS_PER_THREAD; - constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - constexpr size_t shared_mem_bytes = - TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys - TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs - cu::RADIX_SIZE * - sizeof(int) + // shared_hist for small kernel - (2 + 3 * NUM_WARPS + 6) * - sizeof(int); // shared_count + scatter scratch - - CHECK_CUDA_ERROR(cudaFuncSetAttribute( - kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - static_cast(shared_mem_bytes))); - - encoder.add_kernel_node_ex( - kernel, - grid, - block, - {}, - shared_mem_bytes, - gpu_ptr(in), - gpu_ptr(out), - kth, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, - nc_dim); - }); - }); - }); + dispatch_radix_small_block_threads( + size_sorted_axis, [&](auto block_dim_tag) { + constexpr int BLOCK_THREADS = block_dim_tag(); + dim3 grid(1, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_radix_items_per_thread( + size_sorted_axis, + BLOCK_THREADS, + [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + + dispatch_bool(contiguous, [&](auto contiguous_tag) { + constexpr bool USE_SIMPLE_STRIDE = + decltype(contiguous_tag)::value; + + auto kernel = cu::radix_select_small_kernel< + ValT, + OutT, + ARG_PARTITION, + 
USE_SIMPLE_STRIDE, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + + // Calculate dynamic shared memory size + using UnsignedT = + typename cu::RadixTraits::UnsignedT; + constexpr int TILE_SIZE_VAL = + BLOCK_THREADS * ITEMS_PER_THREAD; + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + constexpr size_t shared_mem_bytes = + TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys + TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs + cu::RADIX_SIZE * + sizeof(int) + // shared_hist for small kernel + (2 + 3 * NUM_WARPS + 6) * + sizeof(int); // shared_count + scatter scratch + + encoder.add_kernel_node_ex( + kernel, + grid, + block, + {}, + shared_mem_bytes, + gpu_ptr(in), + gpu_ptr(out), + kth, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); + }); + }); + }); }); } else { throw std::runtime_error( @@ -1308,7 +1280,7 @@ void gpu_radix_partition( } // Dispatch based on whether the small kernel tile fits in shared memory. 
- if (radix_small_fits_shared_memory(s, in.dtype(), size_sorted_axis)) { + if (radix_small_fits_shared_memory(in.dtype(), size_sorted_axis)) { return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); } else { return gpu_merge_sort(s, in, out, axis, arg_partition); From 16cc8b690c0d7d222ad21739f37e22a0ad718a1f Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 26 Mar 2026 14:23:04 +0800 Subject: [PATCH 19/23] apply pr comment --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/partition.cu | 267 +++++++++++++++++++++++++++++++ mlx/backend/cuda/sort.cu | 271 +++----------------------------- 3 files changed, 292 insertions(+), 247 deletions(-) create mode 100644 mlx/backend/cuda/partition.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 1cee777bbe..eb2db6f5bb 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -39,6 +39,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu + ${CMAKE_CURRENT_SOURCE_DIR}/partition.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu diff --git a/mlx/backend/cuda/partition.cu b/mlx/backend/cuda/partition.cu new file mode 100644 index 0000000000..c1fb3b3ca7 --- /dev/null +++ b/mlx/backend/cuda/partition.cu @@ -0,0 +1,267 @@ +// Copyright © 2026 Apple Inc. + +#include +#include +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/radix_select.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype.h" +#include "mlx/dtype_utils.h" + +namespace mlx::core { + +void gpu_partition_fallback( + const Stream& s, + const array& in, + array& out, + int axis, + bool arg_partition); + +namespace { + +// Upper bound for small-kernel tiling. Keep this aligned with the +// items-per-thread dispatch set and per-block shared-memory budget. 
+constexpr int MAX_RADIX_ITEMS_PER_THREAD = 64; +constexpr size_t RADIX_SMALL_SHARED_MEM_BUDGET_BYTES = 48 * 1024; + +template +void dispatch_radix_small_block_threads(int size_sorted_axis, F&& f) { + if (size_sorted_axis <= 256) { + f(std::integral_constant{}); + } else if (size_sorted_axis <= 512) { + f(std::integral_constant{}); + } else if (size_sorted_axis <= 1024) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); + } +} + +template +void dispatch_radix_items_per_thread( + int size_sorted_axis, + int block_threads, + F&& f) { + int items_per_thread = (size_sorted_axis + block_threads - 1) / block_threads; + if (items_per_thread <= 1) { + f(std::integral_constant{}); + } else if (items_per_thread <= 2) { + f(std::integral_constant{}); + } else if (items_per_thread <= 4) { + f(std::integral_constant{}); + } else if (items_per_thread <= 8) { + f(std::integral_constant{}); + } else if (items_per_thread <= 12) { + f(std::integral_constant{}); + } else if (items_per_thread <= 16) { + f(std::integral_constant{}); + } else if (items_per_thread <= 24) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); + } +} + +size_t radix_small_shared_mem_bytes( + size_t key_size, + int block_threads, + int items_per_thread) { + size_t tile_size = static_cast(block_threads) * + static_cast(items_per_thread); + size_t num_warps = static_cast(block_threads / WARP_SIZE); + return tile_size * key_size + // shared_keys + tile_size * sizeof(uint32_t) + // shared_idxs + cu::RADIX_SIZE * sizeof(int) + // shared_hist for small kernel + (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch +} + +bool radix_small_fits_shared_memory(Dtype dtype, int size_sorted_axis) { + if (size_sorted_axis <= 0) { + return false; + } + + size_t required_shared_mem = 0; + bool fits = false; + dispatch_radix_small_block_threads(size_sorted_axis, [&](auto block_dim_tag) { + constexpr int BLOCK_THREADS = block_dim_tag(); + int required_items = 
(size_sorted_axis + BLOCK_THREADS - 1) / BLOCK_THREADS; + if (required_items > MAX_RADIX_ITEMS_PER_THREAD) { + fits = false; + return; + } + + dispatch_radix_items_per_thread( + size_sorted_axis, BLOCK_THREADS, [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + required_shared_mem = radix_small_shared_mem_bytes( + size_of(dtype), BLOCK_THREADS, ITEMS_PER_THREAD); + fits = required_shared_mem <= RADIX_SMALL_SHARED_MEM_BUDGET_BYTES; + }); + }); + return fits; +} + +void gpu_radix_partition_small( + const Stream& s, + const array& in, + array& out, + int axis, + int kth, + bool arg_partition) { + int n_rows = in.size() / in.shape(axis); + + auto in_nc_str = in.strides(); + in_nc_str.erase(in_nc_str.begin() + axis); + + auto out_nc_str = out.strides(); + out_nc_str.erase(out_nc_str.begin() + axis); + + auto nc_shape = in.shape(); + nc_shape.erase(nc_shape.begin() + axis); + + int nc_dim = nc_shape.size(); + + int size_sorted_axis = in.shape(axis); + int64_t in_stride_sorted_axis = in.strides()[axis]; + int64_t out_stride_sorted_axis = out.strides()[axis]; + + bool contiguous = in.flags().contiguous; + auto check_strides = [](const array& x, int64_t sort_stride) { + int64_t min_stride = + *std::min_element(x.strides().begin(), x.strides().end()); + int64_t max_stride = + *std::max_element(x.strides().begin(), x.strides().end()); + return sort_stride == min_stride || sort_stride == max_stride; + }; + contiguous &= check_strides(in, in_stride_sorted_axis); + contiguous &= check_strides(out, out_stride_sorted_axis); + + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + encoder.set_input_array(in); + encoder.set_output_array(out); + + auto nc_shape_param = const_param(nc_shape); + auto in_nc_strides_param = const_param(in_nc_str); + auto out_nc_strides_param = const_param(out_nc_str); + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); 
+ if constexpr (!std::is_same_v) { + using ValT = cuda_type_t; + + dispatch_bool(arg_partition, [&](auto arg_tag) { + constexpr bool ARG_PARTITION = decltype(arg_tag)::value; + using OutT = std::conditional_t; + + int64_t in_stride_segment_axis = INT64_MAX; + int64_t out_stride_segment_axis = INT64_MAX; + if (contiguous) { + for (size_t i = 0; i < nc_shape.size(); i++) { + if (nc_shape[i] == 1) { + continue; + } + in_stride_segment_axis = + std::min(in_stride_segment_axis, in_nc_str[i]); + out_stride_segment_axis = + std::min(out_stride_segment_axis, out_nc_str[i]); + } + } + + dispatch_radix_small_block_threads( + size_sorted_axis, [&](auto block_dim_tag) { + constexpr int BLOCK_THREADS = block_dim_tag(); + dim3 grid(1, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_radix_items_per_thread( + size_sorted_axis, + BLOCK_THREADS, + [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + + dispatch_bool(contiguous, [&](auto contiguous_tag) { + constexpr bool USE_SIMPLE_STRIDE = + decltype(contiguous_tag)::value; + + auto kernel = cu::radix_select_small_kernel< + ValT, + OutT, + ARG_PARTITION, + USE_SIMPLE_STRIDE, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + + // Calculate dynamic shared memory size + using UnsignedT = + typename cu::RadixTraits::UnsignedT; + constexpr int TILE_SIZE_VAL = + BLOCK_THREADS * ITEMS_PER_THREAD; + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + constexpr size_t shared_mem_bytes = + TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys + TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs + cu::RADIX_SIZE * + sizeof(int) + // shared_hist for small kernel + (2 + 3 * NUM_WARPS + 6) * + sizeof(int); // shared_count + scatter scratch + + encoder.add_kernel_node_ex( + kernel, + grid, + block, + {}, + static_cast(shared_mem_bytes), + gpu_ptr(in), + gpu_ptr(out), + kth, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + 
nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); + }); + }); + }); + }); + } else { + throw std::runtime_error( + "CUDA backend does not support sorting complex numbers"); + } + }); +} + +} // namespace + +void gpu_partition( + const Stream& s, + const array& in, + array& out, + int axis_, + int kth_, + bool arg_partition) { + int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; + int size_sorted_axis = in.shape(axis); + int kth = kth_ < 0 ? kth_ + size_sorted_axis : kth_; + int nc_dim = static_cast(in.ndim()) - 1; + + // Fixed-size const_param metadata is capped by MAX_NDIM. + if (nc_dim > MAX_NDIM) { + return gpu_partition_fallback(s, in, out, axis, arg_partition); + } + + // Dispatch based on whether the small kernel tile fits in shared memory. + if (radix_small_fits_shared_memory(in.dtype(), size_sorted_axis)) { + return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); + } else { + return gpu_partition_fallback(s, in, out, axis, arg_partition); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 6512dfbf39..61ff3d3634 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -6,10 +6,8 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device/fp16_math.cuh" -#include "mlx/backend/cuda/device/radix_select.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" -#include "mlx/dtype.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -720,6 +718,21 @@ __global__ void mb_block_merge_kernel( } // namespace cu +void gpu_partition( + const Stream& s, + const array& in, + array& out, + int axis_, + int kth_, + bool arg_partition); + +void gpu_partition_fallback( + const Stream& s, + const array& in, + array& out, + int axis, + bool arg_partition); + namespace { void single_block_sort( @@ -1042,263 +1055,27 @@ void gpu_merge_sort( return single_block_sort(s, in, out, axis, bn, argsort); } 
-/////////////////////////////////////////////////////////////////////////////// -// Radix partition functions -/////////////////////////////////////////////////////////////////////////////// - -// Upper bound for small-kernel tiling. Keep this aligned with the -// items-per-thread dispatch set and per-block shared-memory budget. -constexpr int MAX_RADIX_ITEMS_PER_THREAD = 64; -constexpr size_t RADIX_SMALL_SHARED_MEM_BUDGET_BYTES = 48 * 1024; - -template -void dispatch_radix_small_block_threads(int size_sorted_axis, F&& f) { - if (size_sorted_axis <= 256) { - f(std::integral_constant{}); - } else if (size_sorted_axis <= 512) { - f(std::integral_constant{}); - } else if (size_sorted_axis <= 1024) { - f(std::integral_constant{}); - } else { - f(std::integral_constant{}); - } -} - -template -void dispatch_radix_items_per_thread( - int size_sorted_axis, - int block_threads, - F&& f) { - int items_per_thread = (size_sorted_axis + block_threads - 1) / block_threads; - if (items_per_thread <= 1) { - f(std::integral_constant{}); - } else if (items_per_thread <= 2) { - f(std::integral_constant{}); - } else if (items_per_thread <= 4) { - f(std::integral_constant{}); - } else if (items_per_thread <= 8) { - f(std::integral_constant{}); - } else if (items_per_thread <= 12) { - f(std::integral_constant{}); - } else if (items_per_thread <= 16) { - f(std::integral_constant{}); - } else if (items_per_thread <= 24) { - f(std::integral_constant{}); - } else { - f(std::integral_constant{}); - } -} - -size_t radix_small_shared_mem_bytes( - size_t key_size, - int block_threads, - int items_per_thread) { - size_t tile_size = static_cast(block_threads) * - static_cast(items_per_thread); - size_t num_warps = static_cast(block_threads / WARP_SIZE); - return tile_size * key_size + // shared_keys - tile_size * sizeof(uint32_t) + // shared_idxs - cu::RADIX_SIZE * sizeof(int) + // shared_hist for small kernel - (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch -} - -bool 
radix_small_fits_shared_memory(Dtype dtype, int size_sorted_axis) { - if (size_sorted_axis <= 0) { - return false; - } - - size_t required_shared_mem = 0; - bool fits = false; - dispatch_radix_small_block_threads(size_sorted_axis, [&](auto block_dim_tag) { - constexpr int BLOCK_THREADS = block_dim_tag(); - int required_items = (size_sorted_axis + BLOCK_THREADS - 1) / BLOCK_THREADS; - if (required_items > MAX_RADIX_ITEMS_PER_THREAD) { - fits = false; - return; - } - - dispatch_radix_items_per_thread( - size_sorted_axis, BLOCK_THREADS, [&](auto items_per_thread_tag) { - constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); - required_shared_mem = radix_small_shared_mem_bytes( - size_of(dtype), BLOCK_THREADS, ITEMS_PER_THREAD); - fits = required_shared_mem <= RADIX_SMALL_SHARED_MEM_BUDGET_BYTES; - }); - }); - return fits; -} - -void gpu_radix_partition_small( +void gpu_sort( const Stream& s, const array& in, array& out, int axis, - int kth, - bool arg_partition) { - int n_rows = in.size() / in.shape(axis); - - auto in_nc_str = in.strides(); - in_nc_str.erase(in_nc_str.begin() + axis); - - auto out_nc_str = out.strides(); - out_nc_str.erase(out_nc_str.begin() + axis); - - auto nc_shape = in.shape(); - nc_shape.erase(nc_shape.begin() + axis); - - int nc_dim = nc_shape.size(); - - int size_sorted_axis = in.shape(axis); - int64_t in_stride_sorted_axis = in.strides()[axis]; - int64_t out_stride_sorted_axis = out.strides()[axis]; - - bool contiguous = in.flags().contiguous; - auto check_strides = [](const array& x, int64_t sort_stride) { - int64_t min_stride = - *std::min_element(x.strides().begin(), x.strides().end()); - int64_t max_stride = - *std::max_element(x.strides().begin(), x.strides().end()); - return sort_stride == min_stride || sort_stride == max_stride; - }; - contiguous &= check_strides(in, in_stride_sorted_axis); - contiguous &= check_strides(out, out_stride_sorted_axis); - + bool argsort) { auto& encoder = cu::get_command_encoder(s); - 
out.set_data(cu::malloc_async(out.nbytes(), encoder)); - encoder.set_input_array(in); - encoder.set_output_array(out); - - auto nc_shape_param = const_param(nc_shape); - auto in_nc_strides_param = const_param(in_nc_str); - auto out_nc_strides_param = const_param(out_nc_str); - - dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using ValT = cuda_type_t; - - dispatch_bool(arg_partition, [&](auto arg_tag) { - constexpr bool ARG_PARTITION = decltype(arg_tag)::value; - using OutT = std::conditional_t; - - int64_t in_stride_segment_axis = INT64_MAX; - int64_t out_stride_segment_axis = INT64_MAX; - if (contiguous) { - for (size_t i = 0; i < nc_shape.size(); i++) { - if (nc_shape[i] == 1) { - continue; - } - in_stride_segment_axis = - std::min(in_stride_segment_axis, in_nc_str[i]); - out_stride_segment_axis = - std::min(out_stride_segment_axis, out_nc_str[i]); - } - } - - dispatch_radix_small_block_threads( - size_sorted_axis, [&](auto block_dim_tag) { - constexpr int BLOCK_THREADS = block_dim_tag(); - dim3 grid(1, n_rows, 1); - dim3 block(BLOCK_THREADS, 1, 1); - - dispatch_radix_items_per_thread( - size_sorted_axis, - BLOCK_THREADS, - [&](auto items_per_thread_tag) { - constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); - - dispatch_bool(contiguous, [&](auto contiguous_tag) { - constexpr bool USE_SIMPLE_STRIDE = - decltype(contiguous_tag)::value; - - auto kernel = cu::radix_select_small_kernel< - ValT, - OutT, - ARG_PARTITION, - USE_SIMPLE_STRIDE, - BLOCK_THREADS, - ITEMS_PER_THREAD>; - - // Calculate dynamic shared memory size - using UnsignedT = - typename cu::RadixTraits::UnsignedT; - constexpr int TILE_SIZE_VAL = - BLOCK_THREADS * ITEMS_PER_THREAD; - constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - constexpr size_t shared_mem_bytes = - TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys - TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs - cu::RADIX_SIZE * - sizeof(int) + // 
shared_hist for small kernel - (2 + 3 * NUM_WARPS + 6) * - sizeof(int); // shared_count + scatter scratch - - encoder.add_kernel_node_ex( - kernel, - grid, - block, - {}, - shared_mem_bytes, - gpu_ptr(in), - gpu_ptr(out), - kth, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, - nc_dim); - }); - }); - }); - }); - } else { - throw std::runtime_error( - "CUDA backend does not support sorting complex numbers"); - } - }); + gpu_merge_sort(s, in, out, axis, argsort); } -void gpu_radix_partition( - const Stream& s, - const array& in, - array& out, - int axis_, - int kth_, - bool arg_partition) { - int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; - int size_sorted_axis = in.shape(axis); - int kth = kth_ < 0 ? kth_ + size_sorted_axis : kth_; - int nc_dim = static_cast(in.ndim()) - 1; - - // Fixed-size const_param metadata is capped by MAX_NDIM. - if (nc_dim > MAX_NDIM) { - return gpu_merge_sort(s, in, out, axis, arg_partition); - } - - // Dispatch based on whether the small kernel tile fits in shared memory. 
- if (radix_small_fits_shared_memory(in.dtype(), size_sorted_axis)) { - return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); - } else { - return gpu_merge_sort(s, in, out, axis, arg_partition); - } -} +} // namespace -void gpu_sort( +void gpu_partition_fallback( const Stream& s, const array& in, array& out, int axis, - bool argsort) { - auto& encoder = cu::get_command_encoder(s); - gpu_merge_sort(s, in, out, axis, argsort); + bool arg_partition) { + gpu_merge_sort(s, in, out, axis, arg_partition); } -} // namespace - void ArgSort::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ArgSort::eval_gpu"); assert(inputs.size() == 1); @@ -1313,12 +1090,12 @@ void Sort::eval_gpu(const std::vector& inputs, array& out) { void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ArgPartition::eval_gpu"); - gpu_radix_partition(stream(), inputs[0], out, axis_, kth_, true); + gpu_partition(stream(), inputs[0], out, axis_, kth_, true); } void Partition::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Partition::eval_gpu"); - gpu_radix_partition(stream(), inputs[0], out, axis_, kth_, false); + gpu_partition(stream(), inputs[0], out, axis_, kth_, false); } } // namespace mlx::core From 3e5712fbdd4e148f2499501d207d64f8943dfc51 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 26 Mar 2026 18:52:21 +0800 Subject: [PATCH 20/23] add complex64 support --- benchmarks/python/radix_select_bench.py | 2 + mlx/backend/cuda/device/radix_select.cuh | 21 ++- mlx/backend/cuda/partition.cu | 158 +++++++++++------------ 3 files changed, 98 insertions(+), 83 deletions(-) diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py index 97a51f97d1..98acd876a5 100644 --- a/benchmarks/python/radix_select_bench.py +++ b/benchmarks/python/radix_select_bench.py @@ -25,6 +25,7 @@ "uint16": mx.uint16, "uint32": mx.uint32, "uint64": mx.uint64, + "complex64": mx.complex64, } # 
Benchmark-side model for cross-GPU small-kernel dispatch policy. @@ -48,6 +49,7 @@ "uint16": 2, "uint32": 4, "uint64": 8, + "complex64": 8, } diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 9584aafaca..10d36c9015 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -57,6 +57,25 @@ struct RadixTraits { } }; +template <> +struct RadixTraits { + using UnsignedT = uint64_t; + static constexpr int BITS = 64; + + __device__ __forceinline__ static UnsignedT to_radix(complex64_t val) { + float real = val.real(); + float imag = val.imag(); + if (cuda::std::isnan(real) || cuda::std::isnan(imag)) { + return ~UnsignedT(0); + } + + auto real_key = RadixTraits::to_radix(real); + auto imag_key = RadixTraits::to_radix(imag); + return (static_cast(real_key) << 32) | + static_cast(imag_key); + } +}; + template <> struct RadixTraits<__half> { using UnsignedT = uint16_t; @@ -472,7 +491,7 @@ __global__ void radix_select_small_kernel( pos = running_bases[2] + warp_greater[warp] + greater_rank; } - if (ARG_PARTITION) { + if constexpr (ARG_PARTITION) { row_output[pos * out_stride] = shared_idxs[i]; } else { row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; diff --git a/mlx/backend/cuda/partition.cu b/mlx/backend/cuda/partition.cu index c1fb3b3ca7..599f7930d8 100644 --- a/mlx/backend/cuda/partition.cu +++ b/mlx/backend/cuda/partition.cu @@ -149,91 +149,85 @@ void gpu_radix_partition_small( dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using ValT = cuda_type_t; - - dispatch_bool(arg_partition, [&](auto arg_tag) { - constexpr bool ARG_PARTITION = decltype(arg_tag)::value; - using OutT = std::conditional_t; - - int64_t in_stride_segment_axis = INT64_MAX; - int64_t out_stride_segment_axis = INT64_MAX; - if (contiguous) { - for (size_t i = 0; i < nc_shape.size(); i++) { - if (nc_shape[i] == 
1) { - continue; - } - in_stride_segment_axis = - std::min(in_stride_segment_axis, in_nc_str[i]); - out_stride_segment_axis = - std::min(out_stride_segment_axis, out_nc_str[i]); + using ValT = cuda_type_t; + + dispatch_bool(arg_partition, [&](auto arg_tag) { + constexpr bool ARG_PARTITION = decltype(arg_tag)::value; + using OutT = std::conditional_t; + + int64_t in_stride_segment_axis = INT64_MAX; + int64_t out_stride_segment_axis = INT64_MAX; + if (contiguous) { + for (size_t i = 0; i < nc_shape.size(); i++) { + if (nc_shape[i] == 1) { + continue; } + in_stride_segment_axis = + std::min(in_stride_segment_axis, in_nc_str[i]); + out_stride_segment_axis = + std::min(out_stride_segment_axis, out_nc_str[i]); } - - dispatch_radix_small_block_threads( - size_sorted_axis, [&](auto block_dim_tag) { - constexpr int BLOCK_THREADS = block_dim_tag(); - dim3 grid(1, n_rows, 1); - dim3 block(BLOCK_THREADS, 1, 1); - - dispatch_radix_items_per_thread( - size_sorted_axis, - BLOCK_THREADS, - [&](auto items_per_thread_tag) { - constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); - - dispatch_bool(contiguous, [&](auto contiguous_tag) { - constexpr bool USE_SIMPLE_STRIDE = - decltype(contiguous_tag)::value; - - auto kernel = cu::radix_select_small_kernel< - ValT, - OutT, - ARG_PARTITION, - USE_SIMPLE_STRIDE, - BLOCK_THREADS, - ITEMS_PER_THREAD>; - - // Calculate dynamic shared memory size - using UnsignedT = - typename cu::RadixTraits::UnsignedT; - constexpr int TILE_SIZE_VAL = - BLOCK_THREADS * ITEMS_PER_THREAD; - constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - constexpr size_t shared_mem_bytes = - TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys - TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs - cu::RADIX_SIZE * - sizeof(int) + // shared_hist for small kernel - (2 + 3 * NUM_WARPS + 6) * - sizeof(int); // shared_count + scatter scratch - - encoder.add_kernel_node_ex( - kernel, - grid, - block, - {}, - static_cast(shared_mem_bytes), - gpu_ptr(in), - gpu_ptr(out), - kth, 
- size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, - nc_dim); - }); + } + + dispatch_radix_small_block_threads( + size_sorted_axis, [&](auto block_dim_tag) { + constexpr int BLOCK_THREADS = block_dim_tag(); + dim3 grid(1, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_radix_items_per_thread( + size_sorted_axis, + BLOCK_THREADS, + [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + + dispatch_bool(contiguous, [&](auto contiguous_tag) { + constexpr bool USE_SIMPLE_STRIDE = + decltype(contiguous_tag)::value; + + auto kernel = cu::radix_select_small_kernel< + ValT, + OutT, + ARG_PARTITION, + USE_SIMPLE_STRIDE, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + + // Calculate dynamic shared memory size + using UnsignedT = typename cu::RadixTraits::UnsignedT; + constexpr int TILE_SIZE_VAL = + BLOCK_THREADS * ITEMS_PER_THREAD; + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; + constexpr size_t shared_mem_bytes = + TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys + TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs + cu::RADIX_SIZE * + sizeof(int) + // shared_hist for small kernel + (2 + 3 * NUM_WARPS + 6) * + sizeof(int); // shared_count + scatter scratch + + encoder.add_kernel_node_ex( + kernel, + grid, + block, + {}, + static_cast(shared_mem_bytes), + gpu_ptr(in), + gpu_ptr(out), + kth, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); }); - }); - }); - } else { - throw std::runtime_error( - "CUDA backend does not support sorting complex numbers"); - } + }); + }); + }); }); } From 2fca6476a71e99b68fa6987b190ec4c6b62a7652 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 27 Mar 2026 18:32:01 +0800 Subject: [PATCH 21/23] simplify --- 
benchmarks/python/radix_select_bench.py | 508 +++++++---------------- mlx/backend/cuda/device/radix_select.cuh | 45 +- mlx/backend/cuda/partition.cu | 30 +- 3 files changed, 174 insertions(+), 409 deletions(-) diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py index 98acd876a5..17f1f5c9dd 100644 --- a/benchmarks/python/radix_select_bench.py +++ b/benchmarks/python/radix_select_bench.py @@ -1,8 +1,5 @@ #!/usr/bin/env python3 -""" -Benchmark script for MLX argpartition/partition operations. -Compares radix select implementation against full argsort. -""" +"""Benchmark and verify MLX argpartition/partition (radix select).""" import argparse import time @@ -10,270 +7,142 @@ import mlx.core as mx import numpy as np -# Mapping from string names to MLX dtype objects -DTYPE_MAP = { - "bool": mx.bool_, - "bfloat16": mx.bfloat16, - "float16": mx.float16, - "float32": mx.float32, - "float64": mx.float64, - "int8": mx.int8, - "int16": mx.int16, - "int32": mx.int32, - "int64": mx.int64, - "uint8": mx.uint8, - "uint16": mx.uint16, - "uint32": mx.uint32, - "uint64": mx.uint64, - "complex64": mx.complex64, -} - -# Benchmark-side model for cross-GPU small-kernel dispatch policy. -RADIX_ITEMS_BUCKETS = (1, 2, 4, 8, 12, 16, 24, 32, 48, 64) + +def _resolve_dtype(name): + dt = getattr(mx, name, None) or getattr(mx, name + "_", None) + if dt is None or not isinstance(dt, mx.Dtype): + raise ValueError(f"Unknown dtype: {name}") + return dt + + +# Must match partition.cu dispatch constants. 
MAX_RADIX_ITEMS_PER_THREAD = 64 RADIX_SIZE = 32 WARP_SIZE = 32 -RADIX_SMALL_SHARED_MEM_BUDGET_BYTES = 48 * 1024 - -DTYPE_SIZE_BYTES = { - "bool_": 1, - "bfloat16": 2, - "float16": 2, - "float32": 4, - "float64": 8, - "int8": 1, - "int16": 2, - "int32": 4, - "int64": 8, - "uint8": 1, - "uint16": 2, - "uint32": 4, - "uint64": 8, - "complex64": 8, -} - - -def _dtype_size_bytes(dtype): - dtype_name = str(dtype).split(".")[-1] - return DTYPE_SIZE_BYTES[dtype_name] - - -def _radix_small_block_threads(vocab_size): - if vocab_size <= 128: - return 16 - if vocab_size <= 256: - return 32 - if vocab_size <= 512: - return 64 - if vocab_size <= 1024: - return 128 +SMEM_BUDGET = 48 * 1024 + +BLOCK_THRESHOLDS = [(256, 32), (512, 64), (1024, 128)] +ITEMS_BUCKETS = (1, 2, 4, 8, 12, 16, 24, 64) + + +def _dtype_name(dtype): + return str(dtype).split(".")[-1] + + +def _key_bytes(dtype): + return dtype.size + + +def _block_threads(axis_size): + for threshold, threads in BLOCK_THRESHOLDS: + if axis_size <= threshold: + return threads return 256 -def _radix_small_dispatch_items(required_items): - for bucket in RADIX_ITEMS_BUCKETS: - if required_items <= bucket: - return bucket +def _items_per_thread(axis_size, block_threads): + needed = (axis_size + block_threads - 1) // block_threads + for b in ITEMS_BUCKETS: + if needed <= b: + return b return None -def _radix_small_shared_mem_bytes(dtype_size, block_threads, items_per_thread): - tile_size = block_threads * items_per_thread - num_warps = block_threads // WARP_SIZE - return ( - tile_size * dtype_size - + tile_size * 4 - + RADIX_SIZE * 4 - + (2 + 3 * num_warps + 6) * 4 - ) +def _smem_bytes(key_size, block_threads, items_per_thread): + tile = block_threads * items_per_thread + warps = block_threads // WARP_SIZE + return tile * key_size + RADIX_SIZE * 4 + (2 + 3 * warps + 6) * 4 -def estimate_small_kernel_limit(dtype): - """Estimate max small-kernel axis for dtype under the fixed 48KB budget.""" - dtype_size = _dtype_size_bytes(dtype) - 
smem_limit = RADIX_SMALL_SHARED_MEM_BUDGET_BYTES - max_axis = 0 - # 256 is the largest block_threads in sort.cu launch selection. +def max_small_kernel_axis(dtype): + """Largest axis size the radix small kernel can handle for dtype.""" + ks = _key_bytes(dtype) + best = 0 for v in range(1, 256 * MAX_RADIX_ITEMS_PER_THREAD + 1): - block_threads = _radix_small_block_threads(v) - required_items = (v + block_threads - 1) // block_threads - if required_items > MAX_RADIX_ITEMS_PER_THREAD: - continue - items_per_thread = _radix_small_dispatch_items(required_items) - if items_per_thread is None: - continue - if ( - _radix_small_shared_mem_bytes(dtype_size, block_threads, items_per_thread) - <= smem_limit - ): - max_axis = v - return { - "max_axis": max_axis, - "smem_limit": smem_limit, - } - - -def parse_dtypes(dtype_str): - """Parse comma-separated dtype string into MLX dtype objects.""" - dtypes = [] - for dtype_str_item in dtype_str.split(","): - dtype_str_item = dtype_str_item.strip().lower() - if not dtype_str_item: + bt = _block_threads(v) + ipt = _items_per_thread(v, bt) + if ipt is None or ipt > MAX_RADIX_ITEMS_PER_THREAD: continue - if dtype_str_item not in DTYPE_MAP: - raise ValueError( - f"Unknown dtype: {dtype_str_item}. 
" - f"Supported dtypes: {', '.join(DTYPE_MAP.keys())}" - ) - dtypes.append((DTYPE_MAP[dtype_str_item], dtype_str_item)) - return dtypes + if _smem_bytes(ks, bt, ipt) <= SMEM_BUDGET: + best = v + return best -def benchmark_argpartition(b, v, k, dtype=mx.bfloat16, warmup=5, iterations=100): - x = mx.random.uniform(shape=(b, v)).astype(dtype) - mx.eval(x) - for _ in range(warmup): - mx.eval(mx.argpartition(x, kth=k, axis=-1)) - start = time.perf_counter() - for _ in range(iterations): - mx.eval(mx.argpartition(x, kth=k, axis=-1)) - return (time.perf_counter() - start) / iterations * 1000 - - -def benchmark_argsort(b, v, dtype=mx.bfloat16, warmup=5, iterations=100): - x = mx.random.uniform(shape=(b, v)).astype(dtype) - mx.eval(x) - for _ in range(warmup): - mx.eval(mx.argsort(x, axis=-1)) - start = time.perf_counter() - for _ in range(iterations): - mx.eval(mx.argsort(x, axis=-1)) - return (time.perf_counter() - start) / iterations * 1000 +def parse_dtypes(s): + return [_resolve_dtype(n.strip().lower()) for n in s.split(",") if n.strip()] def verify_correctness(b, v, k, dtype=mx.float32): - # Quantize random values to induce duplicates and stress tie handling. - x = mx.random.uniform(shape=(b, v)) - x = mx.floor(x * 257.0).astype(mx.float32) - x = x.astype(dtype) + x = mx.floor(mx.random.uniform(shape=(b, v)) * 257.0).astype(dtype) mx.eval(x) indices = mx.argpartition(x, kth=k, axis=-1) mx.eval(indices) - # NumPy does not always expose bfloat16 buffers reliably in this environment. 
x_np = np.array(x.astype(mx.float32)) if dtype == mx.bfloat16 else np.array(x) - indices_np = np.array(indices) + idx = np.array(indices) is_float = np.issubdtype(x_np.dtype, np.floating) - assert indices_np.shape == ( - b, - v, - ), f"Unexpected argpartition output shape: got {indices_np.shape}, expected {(b, v)}" - assert np.issubdtype( - indices_np.dtype, np.integer - ), f"Argpartition indices must be integer, got {indices_np.dtype}" - for i in range(b): - row = x_np[i] - row_idx = indices_np[i] + row, ri = x_np[i], idx[i] + assert np.unique(ri).size == v, f"Row {i}: not a permutation" - assert np.all( - (row_idx >= 0) & (row_idx < v) - ), f"Row {i}: out-of-range indices found" - assert ( - np.unique(row_idx).size == v - ), f"Row {i}: indices are not a permutation of [0, {v})" - - pv = row[row_idx] - pivot = pv[k] - left = pv[:k] - right = pv[k + 1 :] + pv = row[ri] + pivot, left, right = pv[k], pv[:k], pv[k + 1 :] if is_float and np.isnan(pivot): - non_nan_count = np.count_nonzero(~np.isnan(row)) - assert ( - non_nan_count <= k - ), f"Row {i}: pivot is NaN before all finite values are placed" - assert np.all( - np.isnan(pv[k:]) - ), f"Row {i}: values after NaN pivot must all be NaN" + assert np.all(np.isnan(pv[k:])), f"Row {i}: non-NaN after NaN pivot" continue if is_float: - left_ok = np.all((~np.isnan(left)) & (left <= pivot)) - right_ok = np.all(np.isnan(right) | (right >= pivot)) + assert np.all( + (~np.isnan(left)) & (left <= pivot) + ), f"Row {i}: left violation" + assert np.all( + np.isnan(right) | (right >= pivot) + ), f"Row {i}: right violation" else: - left_ok = np.all(left <= pivot) - right_ok = np.all(right >= pivot) - - assert left_ok, f"Row {i}: elements before kth violate partition property" - assert right_ok, f"Row {i}: elements after kth violate partition property" + assert np.all(left <= pivot), f"Row {i}: left violation" + assert np.all(right >= pivot), f"Row {i}: right violation" - # Rank consistency: kth must lie within [count( 
max_small_axis}) - - if not vocab_sizes: - print( - "No vocabulary sizes in sweep range for this dtype and shared-memory budget." - ) + ] + batches = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + vocabs = [v for v in vocabs if v <= limit] + + if not vocabs: + print(f"No vocab sizes in range for {name} (limit={limit}).") return - col_w = 10 - print(f"{'':>8}", end="") - for v in vocab_sizes: - label = f"v={v}" - print(f" {label:^{col_w}}", end="") - print() + print(f"\n**{name}** k=v*{k_ratio:.3f} small-kernel-limit={limit}\n") + print("| batch |", " | ".join(f"v={v}" for v in vocabs), "|") + print("|------:|", " | ".join("---:" for _ in vocabs), "|") - for b in batch_sizes: - print(f"b={b:<6}", end="") - for v in vocab_sizes: + for b in batches: + cells = [] + for v in vocabs: k = max(1, int(v * k_ratio)) try: x = mx.random.uniform(shape=(b, v)).astype(dtype) mx.eval(x) - for _ in range(warmup): - mx.eval(mx.argpartition(x, kth=k, axis=-1)) - start = time.perf_counter() - for _ in range(iterations): - mx.eval(mx.argpartition(x, kth=k, axis=-1)) - radix_ms = (time.perf_counter() - start) / iterations * 1000 - for _ in range(warmup): - mx.eval(mx.argsort(x, axis=-1)) - start = time.perf_counter() - for _ in range(iterations): - mx.eval(mx.argsort(x, axis=-1)) - argsort_ms = (time.perf_counter() - start) / iterations * 1000 - + ap = _bench( + x, lambda a: mx.argpartition(a, kth=k, axis=-1), warmup, iters + ) + ar = _bench(x, lambda a: mx.argsort(a, axis=-1), warmup, iters) if verify: - verify_correctness(b, v, k, dtype=dtype) - verify_tie_determinism(b, v, k, dtype=dtype) - - speedup = argsort_ms / radix_ms - cell = f"{speedup:>5.2f}x" - print(f" {cell:^{col_w}}", end="") + verify_correctness(b, v, k, dtype) + verify_determinism(b, v, k, dtype) + cells.append(f"{ar / ap:.2f}x") except Exception: - print(f" {'ERR':^{col_w}}", end="") - print() + cells.append("ERR") + print(f"| {b} |", " | ".join(cells), "|") def main(): - parser = 
argparse.ArgumentParser( - description="Benchmark MLX radix select implementation" - ) - parser.add_argument( - "--large-kernel-sweep", - action="store_true", - help="Enable large-kernel-focused sweep (default: disabled)", - ) - parser.add_argument( - "--small-kernel-sweep", - action="store_true", - help="Enable small-kernel-focused sweep around the estimated boundary", + p = argparse.ArgumentParser(description="Benchmark MLX radix select") + p.add_argument("--verify", action="store_true", help="Run correctness checks") + p.add_argument( + "--sweep", action="store_true", help="Sweep batch x vocab for small kernel" ) - parser.add_argument( - "--verify", - action="store_true", - help="Enable correctness verification (default: disabled). " - "Disabled when --large-kernel-sweep is enabled.", - ) - parser.add_argument( + p.add_argument( "--dtypes", - type=str, default="bfloat16,float32", - help="Comma-separated data types to test (default: bfloat16,float32). " - "Supported: bool, bfloat16, float16, float32, float64, " - "int8, int16, int32, int64, uint8, uint16, uint32, uint64", + help="Comma-separated dtypes (default: bfloat16,float32)", ) - args = parser.parse_args() - - print("=" * 70) - print("MLX Radix Select Benchmark") - print("=" * 70) + args = p.parse_args() + dtypes = parse_dtypes(args.dtypes) configs = [ (2048, 8192, 32), @@ -391,82 +218,41 @@ def main(): (64, 8192, 32), ] - try: - dtypes = parse_dtypes(args.dtypes) - except ValueError as e: - print(f"Error: {e}") - return - - if args.large_kernel_sweep and args.small_kernel_sweep: - print("Error: choose only one of --large-kernel-sweep or --small-kernel-sweep") - return - - if not args.large_kernel_sweep and not args.small_kernel_sweep: - if args.verify: - print("\n1. 
Correctness Verification") - print("-" * 40) - for dtype, dtype_name in dtypes: - for b, v, k in configs: - try: - verify_correctness(b, v, k, dtype=dtype) - print(f" [PASS] b={b}, v={v}, k={k}, dtype={dtype_name}") - except AssertionError as e: - print(f" [FAIL] b={b}, v={v}, k={k}, dtype={dtype_name}: {e}") - - print("\n2. Tie Determinism Verification") - print("-" * 40) - for dtype, dtype_name in dtypes: - for b, v, k in configs: - try: - verify_tie_determinism(b=b, v=v, k=k, dtype=dtype) - print( - f" [PASS] all-equal input " - f"(b={b}, v={v}, k={k}), dtype={dtype_name}, runs=8" - ) - except AssertionError as e: - print( - f" [FAIL] all-equal input " - f"(b={b}, v={v}, k={k}), dtype={dtype_name}, runs=8: {e}" - ) - - print("\n3. Performance Benchmarks") - else: - print("\nPerformance Benchmarks") - print("-" * 70) - - for dtype, dtype_name in dtypes: - print(f"\nDtype: {dtype_name}") - print( - f"{'Config':<25} {'ArgPartition':>14} {'ArgSort':>12} {'Speedup':>10}" - ) - print("-" * 80) - + if args.verify: + print("# Correctness\n") + for dtype in dtypes: + name = _dtype_name(dtype) for b, v, k in configs: try: - argpart_ms = benchmark_argpartition( - b, v, k, dtype, warmup=3, iterations=50 - ) - argsort_ms = benchmark_argsort(b, v, dtype, warmup=3, iterations=50) - speedup = argsort_ms / argpart_ms - config_str = f"b={b}, v={v}, k={k}" + verify_correctness(b, v, k, dtype) + verify_determinism(b, v, k, dtype) + print(f" PASS b={b} v={v} k={k} {name}") + except AssertionError as e: + print(f" FAIL b={b} v={v} k={k} {name}: {e}") + + if args.sweep: + print("# Sweep (speedup vs argsort)\n") + for dtype in dtypes: + sweep(dtype, verify=args.verify) + + if not args.verify and not args.sweep: + print("# Benchmark (speedup vs argsort)\n") + for dtype in dtypes: + name = _dtype_name(dtype) + print(f"\n**{name}**\n") + print("| config | argpartition | argsort | speedup |") + print("|--------|------------:|--------:|--------:|") + for b, v, k in configs: + try: + x = 
mx.random.uniform(shape=(b, v)).astype(dtype) + mx.eval(x) + ap = _bench(x, lambda a: mx.argpartition(a, kth=k, axis=-1), 3, 50) + ar = _bench(x, lambda a: mx.argsort(a, axis=-1), 3, 50) print( - f"{config_str:<25} {argpart_ms:>12.3f}ms" - f" {argsort_ms:>10.3f}ms {speedup:>5.2f}x" + f"| b={b} v={v} k={k} | {ap:.3f}ms | {ar:.3f}ms | {ar/ap:.2f}x |" ) except Exception as e: - print(f"b={b}, v={v}, k={k}: Error - {e}") - - if args.large_kernel_sweep or args.small_kernel_sweep: - print("\nKernel Sweep" + (" (with verification)" if args.verify else "")) - print("-" * 70) - for dtype, dtype_name in dtypes: - sweep_kernel( - dtype, verify=args.verify, small_kernel=args.small_kernel_sweep - ) - - print("\n" + "=" * 70) - print("Benchmark Complete") - print("=" * 70) + print(f"| b={b} v={v} k={k} | ERR | ERR | {e} |") if __name__ == "__main__": diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index 10d36c9015..be003d0688 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -33,6 +33,9 @@ struct RadixTraits { static constexpr int BITS = 32; __device__ __forceinline__ static UnsignedT to_radix(float val) { + if (cuda::std::isnan(val)) { + return ~UnsignedT(0); + } uint32_t bits = __float_as_uint(val); if ((bits << 1) == 0) { bits = 0; // Canonicalize +/-0.0 to +0.0 for stable equal-value ties. @@ -48,6 +51,9 @@ struct RadixTraits { static constexpr int BITS = 64; __device__ __forceinline__ static UnsignedT to_radix(double val) { + if (cuda::std::isnan(val)) { + return ~UnsignedT(0); + } uint64_t bits = __double_as_longlong(val); if ((bits << 1) == 0) { bits = 0; // Canonicalize +/-0.0 to +0.0 for stable equal-value ties. 
@@ -82,6 +88,9 @@ struct RadixTraits<__half> { static constexpr int BITS = 16; __device__ __forceinline__ static UnsignedT to_radix(__half val) { + if (cuda::std::isnan(val)) { + return ~UnsignedT(0); + } uint16_t bits = __half_as_ushort(val); if ((bits & 0x7FFFu) == 0) { bits = 0; // Canonicalize +/-0 to +0 for stable equal-value ties. @@ -97,6 +106,9 @@ struct RadixTraits<__nv_bfloat16> { static constexpr int BITS = 16; __device__ __forceinline__ static UnsignedT to_radix(__nv_bfloat16 val) { + if (cuda::std::isnan(val)) { + return ~UnsignedT(0); + } uint16_t bits = __bfloat16_as_ushort(val); if ((bits & 0x7FFFu) == 0) { bits = 0; // Canonicalize +/-0 to +0 for stable equal-value ties. @@ -282,8 +294,7 @@ __global__ void radix_select_small_kernel( // Calculate offsets for different arrays in shared memory UnsignedT* shared_keys = reinterpret_cast(shared_mem); - uint32_t* shared_idxs = reinterpret_cast(shared_keys + TILE_SIZE); - int* shared_hist = reinterpret_cast(shared_idxs + TILE_SIZE); + int* shared_hist = reinterpret_cast(shared_keys + TILE_SIZE); int* shared_count = shared_hist + RADIX_SIZE; constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; int* scatter_scratch = shared_count + 2; @@ -314,30 +325,8 @@ __global__ void radix_select_small_kernel( // Load data into shared memory for (int i = threadIdx.x; i < TILE_SIZE; i += BLOCK_THREADS) { - if (i < tile_n) { - ValT val = row_input[i * in_stride]; - UnsignedT key = Traits::to_radix(val); - if constexpr (cuda::std::is_floating_point_v) { - if (cuda::std::isnan(val)) { - key = ~UnsignedT(0); - } - } else if constexpr (cuda::std::is_same_v) { - if (__hisnan(val)) { - key = ~UnsignedT(0); - } - } else if constexpr (cuda::std::is_same_v) { - if (__hisnan(val)) { - key = ~UnsignedT(0); - } - } else { - // Non-floating types cannot produce NaN keys. - } - shared_keys[i] = key; - shared_idxs[i] = i; - } else { - shared_keys[i] = ~UnsignedT(0); - shared_idxs[i] = i; - } + shared_keys[i] = (i < tile_n) ? 
Traits::to_radix(row_input[i * in_stride]) + : ~UnsignedT(0); } __syncthreads(); @@ -492,9 +481,9 @@ __global__ void radix_select_small_kernel( } if constexpr (ARG_PARTITION) { - row_output[pos * out_stride] = shared_idxs[i]; + row_output[pos * out_stride] = i; } else { - row_output[pos * out_stride] = row_input[shared_idxs[i] * in_stride]; + row_output[pos * out_stride] = row_input[i * in_stride]; } } __syncthreads(); diff --git a/mlx/backend/cuda/partition.cu b/mlx/backend/cuda/partition.cu index 599f7930d8..289ddf5da3 100644 --- a/mlx/backend/cuda/partition.cu +++ b/mlx/backend/cuda/partition.cu @@ -72,8 +72,7 @@ size_t radix_small_shared_mem_bytes( static_cast(items_per_thread); size_t num_warps = static_cast(block_threads / WARP_SIZE); return tile_size * key_size + // shared_keys - tile_size * sizeof(uint32_t) + // shared_idxs - cu::RADIX_SIZE * sizeof(int) + // shared_hist for small kernel + cu::RADIX_SIZE * sizeof(int) + // shared_hist (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch } @@ -82,7 +81,6 @@ bool radix_small_fits_shared_memory(Dtype dtype, int size_sorted_axis) { return false; } - size_t required_shared_mem = 0; bool fits = false; dispatch_radix_small_block_threads(size_sorted_axis, [&](auto block_dim_tag) { constexpr int BLOCK_THREADS = block_dim_tag(); @@ -95,15 +93,15 @@ bool radix_small_fits_shared_memory(Dtype dtype, int size_sorted_axis) { dispatch_radix_items_per_thread( size_sorted_axis, BLOCK_THREADS, [&](auto items_per_thread_tag) { constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); - required_shared_mem = radix_small_shared_mem_bytes( - size_of(dtype), BLOCK_THREADS, ITEMS_PER_THREAD); - fits = required_shared_mem <= RADIX_SMALL_SHARED_MEM_BUDGET_BYTES; + fits = radix_small_shared_mem_bytes( + size_of(dtype), BLOCK_THREADS, ITEMS_PER_THREAD) <= + RADIX_SMALL_SHARED_MEM_BUDGET_BYTES; }); }); return fits; } -void gpu_radix_partition_small( +void gpu_partition_small( const Stream& s, const array& in, array& 
out, @@ -193,18 +191,10 @@ void gpu_radix_partition_small( BLOCK_THREADS, ITEMS_PER_THREAD>; - // Calculate dynamic shared memory size - using UnsignedT = typename cu::RadixTraits::UnsignedT; - constexpr int TILE_SIZE_VAL = - BLOCK_THREADS * ITEMS_PER_THREAD; - constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - constexpr size_t shared_mem_bytes = - TILE_SIZE_VAL * sizeof(UnsignedT) + // shared_keys - TILE_SIZE_VAL * sizeof(uint32_t) + // shared_idxs - cu::RADIX_SIZE * - sizeof(int) + // shared_hist for small kernel - (2 + 3 * NUM_WARPS + 6) * - sizeof(int); // shared_count + scatter scratch + size_t shared_mem_bytes = radix_small_shared_mem_bytes( + sizeof(typename cu::RadixTraits::UnsignedT), + BLOCK_THREADS, + ITEMS_PER_THREAD); encoder.add_kernel_node_ex( kernel, @@ -252,7 +242,7 @@ void gpu_partition( // Dispatch based on whether the small kernel tile fits in shared memory. if (radix_small_fits_shared_memory(in.dtype(), size_sorted_axis)) { - return gpu_radix_partition_small(s, in, out, axis, kth, arg_partition); + return gpu_partition_small(s, in, out, axis, kth, arg_partition); } else { return gpu_partition_fallback(s, in, out, axis, arg_partition); } From a402fd3d54a4fb992c7e8f21046f2c93a56345d8 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 30 Mar 2026 18:02:39 +0800 Subject: [PATCH 22/23] use cooperative_groups APIs --- mlx/backend/cuda/device/radix_select.cuh | 139 ++++++++++------------- mlx/backend/cuda/partition.cu | 38 +------ mlx/backend/cuda/sort.cu | 48 +++++--- 3 files changed, 97 insertions(+), 128 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index be003d0688..ddda3ebf4b 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -2,6 +2,8 @@ #pragma once +#include +#include #include #include #include @@ -209,54 +211,45 @@ struct RadixTraits { }; /////////////////////////////////////////////////////////////////////////////// -// 
Warp-level utilities +// Block-level utilities /////////////////////////////////////////////////////////////////////////////// +namespace cg = cooperative_groups; + template __device__ __forceinline__ int block_exclusive_scan( + cg::thread_block& block, int val, int* shared_warp_sums, int* block_total = nullptr) { static_assert(BLOCK_THREADS % WARP_SIZE == 0); constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - int lane = threadIdx.x & (WARP_SIZE - 1); - int warp = threadIdx.x / WARP_SIZE; + auto warp = cg::tiled_partition(block); + int inclusive = cg::inclusive_scan(warp, val, cg::plus()); - int inclusive = val; -#pragma unroll - for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { - int n = __shfl_up_sync(0xFFFFFFFF, inclusive, offset); - if (lane >= offset) { - inclusive += n; - } + if (warp.thread_rank() == WARP_SIZE - 1) { + shared_warp_sums[warp.meta_group_rank()] = inclusive; } + block.sync(); - if (lane == WARP_SIZE - 1) { - shared_warp_sums[warp] = inclusive; - } - __syncthreads(); - - if (warp == 0) { - int warp_scan = (lane < NUM_WARPS) ? shared_warp_sums[lane] : 0; -#pragma unroll - for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { - int n = __shfl_up_sync(0xFFFFFFFF, warp_scan, offset); - if (lane >= offset) { - warp_scan += n; - } - } + if (warp.meta_group_rank() == 0) { + int warp_val = warp.thread_rank() < NUM_WARPS + ? 
shared_warp_sums[warp.thread_rank()] + : 0; + int warp_scan = cg::inclusive_scan(warp, warp_val, cg::plus()); - if (lane < NUM_WARPS) { - shared_warp_sums[lane] = warp_scan - shared_warp_sums[lane]; + if (warp.thread_rank() < NUM_WARPS) { + shared_warp_sums[warp.thread_rank()] = + warp_scan - shared_warp_sums[warp.thread_rank()]; } - if (block_total != nullptr && lane == NUM_WARPS - 1) { + if (block_total != nullptr && warp.thread_rank() == NUM_WARPS - 1) { *block_total = warp_scan; } } - __syncthreads(); + block.sync(); - return shared_warp_sums[warp] + inclusive - val; + return shared_warp_sums[warp.meta_group_rank()] + inclusive - val; } /////////////////////////////////////////////////////////////////////////////// @@ -288,15 +281,15 @@ __global__ void radix_select_small_kernel( constexpr int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD; constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS; + constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; - // Dynamic shared memory layout - extern __shared__ char shared_mem[]; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); - // Calculate offsets for different arrays in shared memory + extern __shared__ char shared_mem[]; UnsignedT* shared_keys = reinterpret_cast(shared_mem); int* shared_hist = reinterpret_cast(shared_keys + TILE_SIZE); int* shared_count = shared_hist + RADIX_SIZE; - constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE; int* scatter_scratch = shared_count + 2; int* warp_less = scatter_scratch; int* warp_equal = warp_less + NUM_WARPS; @@ -306,7 +299,6 @@ __global__ void radix_select_small_kernel( int row = blockIdx.y; - // Compute row pointers based on addressing mode const ValT* row_input; OutT* row_output; if constexpr (USE_SIMPLE_STRIDE) { @@ -323,14 +315,12 @@ __global__ void radix_select_small_kernel( int tile_n = min(n, TILE_SIZE); - // Load data into shared memory - for (int i = threadIdx.x; i < TILE_SIZE; i += BLOCK_THREADS) { + for (int i = 
block.thread_rank(); i < TILE_SIZE; i += BLOCK_THREADS) { shared_keys[i] = (i < tile_n) ? Traits::to_radix(row_input[i * in_stride]) : ~UnsignedT(0); } - __syncthreads(); + block.sync(); - // Radix select to find pivot int k = kth + 1; UnsignedT target_prefix = 0; UnsignedT prefix_mask = 0; @@ -338,24 +328,22 @@ __global__ void radix_select_small_kernel( for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { int start_bit = pass * RADIX_BITS; - // Clear histogram - for (int i = threadIdx.x; i < RADIX_SIZE; i += BLOCK_THREADS) { + for (int i = block.thread_rank(); i < RADIX_SIZE; i += BLOCK_THREADS) { shared_hist[i] = 0; } - __syncthreads(); + block.sync(); - // Build histogram - for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + for (int i = block.thread_rank(); i < tile_n; i += BLOCK_THREADS) { UnsignedT key = shared_keys[i]; if ((key & prefix_mask) == target_prefix) { int digit = (key >> start_bit) & ((1 << RADIX_BITS) - 1); atomicAdd(&shared_hist[digit], 1); } } - __syncthreads(); + block.sync(); // Find target bin (single thread) - if (threadIdx.x == 0) { + if (block.thread_rank() == 0) { int cumsum = 0; int target_bin = 0; for (int bin = 0; bin < RADIX_SIZE; bin++) { @@ -370,7 +358,7 @@ __global__ void radix_select_small_kernel( shared_count[0] = target_bin; shared_count[1] = k; } - __syncthreads(); + block.sync(); int target_bin = shared_count[0]; k = shared_count[1]; @@ -379,14 +367,14 @@ __global__ void radix_select_small_kernel( target_prefix |= UnsignedT(target_bin) << start_bit; prefix_mask |= digit_mask; - __syncthreads(); + block.sync(); } // Count per-thread bucket sizes once, then scatter in a single pass with // deterministic per-thread offsets. 
int local_less = 0; int local_equal = 0; - for (int i = threadIdx.x; i < tile_n; i += BLOCK_THREADS) { + for (int i = block.thread_rank(); i < tile_n; i += BLOCK_THREADS) { UnsignedT key = shared_keys[i]; if (key < target_prefix) { local_less++; @@ -396,26 +384,23 @@ __global__ void radix_select_small_kernel( } (void)block_exclusive_scan( - local_less, shared_hist, &shared_count[0]); + block, local_less, shared_hist, &shared_count[0]); (void)block_exclusive_scan( - local_equal, shared_hist, &shared_count[1]); + block, local_equal, shared_hist, &shared_count[1]); int less_count = shared_count[0]; int equal_count = shared_count[1]; // Scatter in increasing i order to keep tie behavior aligned with merge sort. - int lane = threadIdx.x & (WARP_SIZE - 1); - int warp = threadIdx.x / WARP_SIZE; - - if (threadIdx.x == 0) { + if (block.thread_rank() == 0) { running_bases[0] = 0; running_bases[1] = less_count; running_bases[2] = less_count + equal_count; } - __syncthreads(); + block.sync(); for (int base_i = 0; base_i < tile_n; base_i += BLOCK_THREADS) { - int i = base_i + threadIdx.x; + int i = base_i + block.thread_rank(); bool active = i < tile_n; UnsignedT key = 0; @@ -427,23 +412,23 @@ __global__ void radix_select_small_kernel( bool is_equal = active && (key == target_prefix); bool is_greater = active && !is_less && !is_equal; - unsigned less_mask = __ballot_sync(0xFFFFFFFF, is_less); - unsigned equal_mask = __ballot_sync(0xFFFFFFFF, is_equal); - unsigned greater_mask = __ballot_sync(0xFFFFFFFF, is_greater); + unsigned less_ballot = warp.ballot(is_less); + unsigned equal_ballot = warp.ballot(is_equal); + unsigned greater_ballot = warp.ballot(is_greater); - unsigned lane_mask = (1u << lane) - 1u; - int less_rank = __popc(less_mask & lane_mask); - int equal_rank = __popc(equal_mask & lane_mask); - int greater_rank = __popc(greater_mask & lane_mask); + unsigned lane_mask = (1u << warp.thread_rank()) - 1u; + int less_rank = __popc(less_ballot & lane_mask); + int equal_rank 
= __popc(equal_ballot & lane_mask); + int greater_rank = __popc(greater_ballot & lane_mask); - if (lane == 0) { - warp_less[warp] = __popc(less_mask); - warp_equal[warp] = __popc(equal_mask); - warp_greater[warp] = __popc(greater_mask); + if (warp.thread_rank() == 0) { + warp_less[warp.meta_group_rank()] = __popc(less_ballot); + warp_equal[warp.meta_group_rank()] = __popc(equal_ballot); + warp_greater[warp.meta_group_rank()] = __popc(greater_ballot); } - __syncthreads(); + block.sync(); - if (threadIdx.x == 0) { + if (block.thread_rank() == 0) { int run = 0; for (int w = 0; w < NUM_WARPS; ++w) { int c = warp_less[w]; @@ -468,16 +453,18 @@ __global__ void radix_select_small_kernel( } iter_counts[2] = run; } - __syncthreads(); + block.sync(); if (active) { int pos; if (is_less) { - pos = running_bases[0] + warp_less[warp] + less_rank; + pos = running_bases[0] + warp_less[warp.meta_group_rank()] + less_rank; } else if (is_equal) { - pos = running_bases[1] + warp_equal[warp] + equal_rank; + pos = + running_bases[1] + warp_equal[warp.meta_group_rank()] + equal_rank; } else { - pos = running_bases[2] + warp_greater[warp] + greater_rank; + pos = running_bases[2] + warp_greater[warp.meta_group_rank()] + + greater_rank; } if constexpr (ARG_PARTITION) { @@ -486,14 +473,14 @@ __global__ void radix_select_small_kernel( row_output[pos * out_stride] = row_input[i * in_stride]; } } - __syncthreads(); + block.sync(); - if (threadIdx.x == 0) { + if (block.thread_rank() == 0) { running_bases[0] += iter_counts[0]; running_bases[1] += iter_counts[1]; running_bases[2] += iter_counts[2]; } - __syncthreads(); + block.sync(); } } diff --git a/mlx/backend/cuda/partition.cu b/mlx/backend/cuda/partition.cu index 289ddf5da3..04482f9330 100644 --- a/mlx/backend/cuda/partition.cu +++ b/mlx/backend/cuda/partition.cu @@ -12,13 +12,6 @@ namespace mlx::core { -void gpu_partition_fallback( - const Stream& s, - const array& in, - array& out, - int axis, - bool arg_partition); - namespace { // Upper 
bound for small-kernel tiling. Keep this aligned with the @@ -76,7 +69,9 @@ size_t radix_small_shared_mem_bytes( (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch } -bool radix_small_fits_shared_memory(Dtype dtype, int size_sorted_axis) { +} // namespace + +bool gpu_partition_small_fits(Dtype dtype, int size_sorted_axis) { if (size_sorted_axis <= 0) { return false; } @@ -221,31 +216,4 @@ void gpu_partition_small( }); } -} // namespace - -void gpu_partition( - const Stream& s, - const array& in, - array& out, - int axis_, - int kth_, - bool arg_partition) { - int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; - int size_sorted_axis = in.shape(axis); - int kth = kth_ < 0 ? kth_ + size_sorted_axis : kth_; - int nc_dim = static_cast(in.ndim()) - 1; - - // Fixed-size const_param metadata is capped by MAX_NDIM. - if (nc_dim > MAX_NDIM) { - return gpu_partition_fallback(s, in, out, axis, arg_partition); - } - - // Dispatch based on whether the small kernel tile fits in shared memory. 
- if (radix_small_fits_shared_memory(in.dtype(), size_sorted_axis)) { - return gpu_partition_small(s, in, out, axis, kth, arg_partition); - } else { - return gpu_partition_fallback(s, in, out, axis, arg_partition); - } -} - } // namespace mlx::core diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 61ff3d3634..271f0841c8 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -718,19 +718,14 @@ __global__ void mb_block_merge_kernel( } // namespace cu -void gpu_partition( - const Stream& s, - const array& in, - array& out, - int axis_, - int kth_, - bool arg_partition); +bool gpu_partition_small_fits(Dtype dtype, int size_sorted_axis); -void gpu_partition_fallback( +void gpu_partition_small( const Stream& s, const array& in, array& out, int axis, + int kth, bool arg_partition); namespace { @@ -1067,15 +1062,6 @@ void gpu_sort( } // namespace -void gpu_partition_fallback( - const Stream& s, - const array& in, - array& out, - int axis, - bool arg_partition) { - gpu_merge_sort(s, in, out, axis, arg_partition); -} - void ArgSort::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ArgSort::eval_gpu"); assert(inputs.size() == 1); @@ -1088,6 +1074,34 @@ void Sort::eval_gpu(const std::vector& inputs, array& out) { gpu_sort(stream(), inputs[0], out, axis_, false); } +namespace { + +void gpu_partition( + const Stream& s, + const array& in, + array& out, + int axis_, + int kth_, + bool arg_partition) { + int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; + int size_sorted_axis = in.shape(axis); + int kth = kth_ < 0 ? kth_ + size_sorted_axis : kth_; + + // Fixed-size const_param metadata is capped by MAX_NDIM. + if (in.ndim() > MAX_NDIM) { + return gpu_merge_sort(s, in, out, axis, arg_partition); + } + + // Dispatch based on whether the small kernel tile fits in shared memory. 
+ if (gpu_partition_small_fits(in.dtype(), size_sorted_axis)) { + return gpu_partition_small(s, in, out, axis, kth, arg_partition); + } else { + return gpu_merge_sort(s, in, out, axis, arg_partition); + } +} + +} // namespace + void ArgPartition::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("ArgPartition::eval_gpu"); gpu_partition(stream(), inputs[0], out, axis_, kth_, true); From b32554df48c82c1875939f86070ea0a16c9b80cb Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 30 Mar 2026 18:21:38 +0800 Subject: [PATCH 23/23] refactor: use static shared memory for radix select kernel --- mlx/backend/cuda/device/radix_select.cuh | 18 +++--- mlx/backend/cuda/partition.cu | 82 ++++++++++++------------ 2 files changed, 48 insertions(+), 52 deletions(-) diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh index ddda3ebf4b..e5e0b6ae0e 100644 --- a/mlx/backend/cuda/device/radix_select.cuh +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -286,16 +286,14 @@ __global__ void radix_select_small_kernel( auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - extern __shared__ char shared_mem[]; - UnsignedT* shared_keys = reinterpret_cast(shared_mem); - int* shared_hist = reinterpret_cast(shared_keys + TILE_SIZE); - int* shared_count = shared_hist + RADIX_SIZE; - int* scatter_scratch = shared_count + 2; - int* warp_less = scatter_scratch; - int* warp_equal = warp_less + NUM_WARPS; - int* warp_greater = warp_equal + NUM_WARPS; - int* iter_counts = warp_greater + NUM_WARPS; - int* running_bases = iter_counts + 3; + __shared__ UnsignedT shared_keys[TILE_SIZE]; + __shared__ int shared_hist[RADIX_SIZE]; + __shared__ int shared_count[2]; + __shared__ int warp_less[NUM_WARPS]; + __shared__ int warp_equal[NUM_WARPS]; + __shared__ int warp_greater[NUM_WARPS]; + __shared__ int iter_counts[3]; + __shared__ int running_bases[3]; int row = blockIdx.y; diff --git a/mlx/backend/cuda/partition.cu 
b/mlx/backend/cuda/partition.cu index 04482f9330..0032c1abb5 100644 --- a/mlx/backend/cuda/partition.cu +++ b/mlx/backend/cuda/partition.cu @@ -57,13 +57,12 @@ void dispatch_radix_items_per_thread( } } -size_t radix_small_shared_mem_bytes( +constexpr size_t radix_small_shared_mem_bytes( size_t key_size, - int block_threads, - int items_per_thread) { - size_t tile_size = static_cast(block_threads) * - static_cast(items_per_thread); - size_t num_warps = static_cast(block_threads / WARP_SIZE); + size_t block_threads, + size_t items_per_thread) { + size_t tile_size = block_threads * items_per_thread; + size_t num_warps = block_threads / WARP_SIZE; return tile_size * key_size + // shared_keys cu::RADIX_SIZE * sizeof(int) + // shared_hist (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch @@ -174,42 +173,41 @@ void gpu_partition_small( [&](auto items_per_thread_tag) { constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); - dispatch_bool(contiguous, [&](auto contiguous_tag) { - constexpr bool USE_SIMPLE_STRIDE = - decltype(contiguous_tag)::value; - - auto kernel = cu::radix_select_small_kernel< - ValT, - OutT, - ARG_PARTITION, - USE_SIMPLE_STRIDE, - BLOCK_THREADS, - ITEMS_PER_THREAD>; - - size_t shared_mem_bytes = radix_small_shared_mem_bytes( - sizeof(typename cu::RadixTraits::UnsignedT), - BLOCK_THREADS, - ITEMS_PER_THREAD); - - encoder.add_kernel_node_ex( - kernel, - grid, - block, - {}, - static_cast(shared_mem_bytes), - gpu_ptr(in), - gpu_ptr(out), - kth, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, - nc_dim); - }); + constexpr size_t SMEM = radix_small_shared_mem_bytes( + sizeof(typename cu::RadixTraits::UnsignedT), + BLOCK_THREADS, + ITEMS_PER_THREAD); + if constexpr (SMEM <= RADIX_SMALL_SHARED_MEM_BUDGET_BYTES) { + dispatch_bool(contiguous, [&](auto contiguous_tag) { + constexpr bool 
USE_SIMPLE_STRIDE = + decltype(contiguous_tag)::value; + + auto kernel = cu::radix_select_small_kernel< + ValT, + OutT, + ARG_PARTITION, + USE_SIMPLE_STRIDE, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(in), + gpu_ptr(out), + kth, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); + }); + } }); }); });