diff --git a/benchmarks/python/radix_select_bench.py b/benchmarks/python/radix_select_bench.py new file mode 100644 index 0000000000..17f1f5c9dd --- /dev/null +++ b/benchmarks/python/radix_select_bench.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +"""Benchmark and verify MLX argpartition/partition (radix select).""" + +import argparse +import time + +import mlx.core as mx +import numpy as np + + +def _resolve_dtype(name): + dt = getattr(mx, name, None) or getattr(mx, name + "_", None) + if dt is None or not isinstance(dt, mx.Dtype): + raise ValueError(f"Unknown dtype: {name}") + return dt + + +# Must match partition.cu dispatch constants. +MAX_RADIX_ITEMS_PER_THREAD = 64 +RADIX_SIZE = 32 +WARP_SIZE = 32 +SMEM_BUDGET = 48 * 1024 + +BLOCK_THRESHOLDS = [(256, 32), (512, 64), (1024, 128)] +ITEMS_BUCKETS = (1, 2, 4, 8, 12, 16, 24, 64) + + +def _dtype_name(dtype): + return str(dtype).split(".")[-1] + + +def _key_bytes(dtype): + return dtype.size + + +def _block_threads(axis_size): + for threshold, threads in BLOCK_THRESHOLDS: + if axis_size <= threshold: + return threads + return 256 + + +def _items_per_thread(axis_size, block_threads): + needed = (axis_size + block_threads - 1) // block_threads + for b in ITEMS_BUCKETS: + if needed <= b: + return b + return None + + +def _smem_bytes(key_size, block_threads, items_per_thread): + tile = block_threads * items_per_thread + warps = block_threads // WARP_SIZE + return tile * key_size + RADIX_SIZE * 4 + (2 + 3 * warps + 6) * 4 + + +def max_small_kernel_axis(dtype): + """Largest axis size the radix small kernel can handle for dtype.""" + ks = _key_bytes(dtype) + best = 0 + for v in range(1, 256 * MAX_RADIX_ITEMS_PER_THREAD + 1): + bt = _block_threads(v) + ipt = _items_per_thread(v, bt) + if ipt is None or ipt > MAX_RADIX_ITEMS_PER_THREAD: + continue + if _smem_bytes(ks, bt, ipt) <= SMEM_BUDGET: + best = v + return best + + +def parse_dtypes(s): + return [_resolve_dtype(n.strip().lower()) for n in s.split(",") if n.strip()] + + 
+def verify_correctness(b, v, k, dtype=mx.float32): + x = mx.floor(mx.random.uniform(shape=(b, v)) * 257.0).astype(dtype) + mx.eval(x) + + indices = mx.argpartition(x, kth=k, axis=-1) + mx.eval(indices) + + x_np = np.array(x.astype(mx.float32)) if dtype == mx.bfloat16 else np.array(x) + idx = np.array(indices) + is_float = np.issubdtype(x_np.dtype, np.floating) + + for i in range(b): + row, ri = x_np[i], idx[i] + assert np.unique(ri).size == v, f"Row {i}: not a permutation" + + pv = row[ri] + pivot, left, right = pv[k], pv[:k], pv[k + 1 :] + + if is_float and np.isnan(pivot): + assert np.all(np.isnan(pv[k:])), f"Row {i}: non-NaN after NaN pivot" + continue + + if is_float: + assert np.all( + (~np.isnan(left)) & (left <= pivot) + ), f"Row {i}: left violation" + assert np.all( + np.isnan(right) | (right >= pivot) + ), f"Row {i}: right violation" + else: + assert np.all(left <= pivot), f"Row {i}: left violation" + assert np.all(right >= pivot), f"Row {i}: right violation" + + less = np.count_nonzero(row < pivot) + leq = np.count_nonzero(row <= pivot) + assert less <= k < leq, f"Row {i}: rank inconsistent" + + +def verify_determinism(b, v, k, dtype=mx.float32): + x = mx.zeros((b, v), dtype=dtype) + mx.eval(x) + + outputs = [] + for _ in range(8): + idx = mx.argpartition(x, kth=k, axis=-1) + mx.eval(idx) + outputs.append(np.array(idx)) + + assert len({o.tobytes() for o in outputs}) == 1, "Non-deterministic tie ordering" + + expected = mx.argsort(x, axis=-1) + mx.eval(expected) + assert np.array_equal( + outputs[0], np.array(expected) + ), "Tie order differs from argsort" + + +def _bench(x, fn, warmup=10, iters=50): + for _ in range(warmup): + mx.eval(fn(x)) + t0 = time.perf_counter() + for _ in range(iters): + mx.eval(fn(x)) + return (time.perf_counter() - t0) / iters * 1000 + + +def sweep(dtype, k_ratio=0.004, warmup=10, iters=50, verify=False): + limit = max_small_kernel_axis(dtype) + name = _dtype_name(dtype) + + vocabs = [ + 32, + 64, + 96, + 160, + 256, + 384, + 
512, + 1024, + 2048, + 4096, + 8192, + 16384, + 32768, + 65536, + 131072, + ] + batches = [1, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + vocabs = [v for v in vocabs if v <= limit] + + if not vocabs: + print(f"No vocab sizes in range for {name} (limit={limit}).") + return + + print(f"\n**{name}** k=v*{k_ratio:.3f} small-kernel-limit={limit}\n") + print("| batch |", " | ".join(f"v={v}" for v in vocabs), "|") + print("|------:|", " | ".join("---:" for _ in vocabs), "|") + + for b in batches: + cells = [] + for v in vocabs: + k = max(1, int(v * k_ratio)) + try: + x = mx.random.uniform(shape=(b, v)).astype(dtype) + mx.eval(x) + ap = _bench( + x, lambda a: mx.argpartition(a, kth=k, axis=-1), warmup, iters + ) + ar = _bench(x, lambda a: mx.argsort(a, axis=-1), warmup, iters) + if verify: + verify_correctness(b, v, k, dtype) + verify_determinism(b, v, k, dtype) + cells.append(f"{ar / ap:.2f}x") + except Exception: + cells.append("ERR") + print(f"| {b} |", " | ".join(cells), "|") + + +def main(): + p = argparse.ArgumentParser(description="Benchmark MLX radix select") + p.add_argument("--verify", action="store_true", help="Run correctness checks") + p.add_argument( + "--sweep", action="store_true", help="Sweep batch x vocab for small kernel" + ) + p.add_argument( + "--dtypes", + default="bfloat16,float32", + help="Comma-separated dtypes (default: bfloat16,float32)", + ) + args = p.parse_args() + dtypes = parse_dtypes(args.dtypes) + + configs = [ + (2048, 8192, 32), + (2048, 4096, 32), + (1024, 4096, 16), + (512, 2048, 64), + (256, 1024, 32), + (128, 512, 16), + (1, 128000, 64), + (1, 512, 32), + (16, 8192, 32), + (32, 8192, 32), + (64, 8192, 32), + ] + + if args.verify: + print("# Correctness\n") + for dtype in dtypes: + name = _dtype_name(dtype) + for b, v, k in configs: + try: + verify_correctness(b, v, k, dtype) + verify_determinism(b, v, k, dtype) + print(f" PASS b={b} v={v} k={k} {name}") + except AssertionError as e: + print(f" FAIL b={b} v={v} k={k} 
{name}: {e}") + + if args.sweep: + print("# Sweep (speedup vs argsort)\n") + for dtype in dtypes: + sweep(dtype, verify=args.verify) + + if not args.verify and not args.sweep: + print("# Benchmark (speedup vs argsort)\n") + for dtype in dtypes: + name = _dtype_name(dtype) + print(f"\n**{name}**\n") + print("| config | argpartition | argsort | speedup |") + print("|--------|------------:|--------:|--------:|") + for b, v, k in configs: + try: + x = mx.random.uniform(shape=(b, v)).astype(dtype) + mx.eval(x) + ap = _bench(x, lambda a: mx.argpartition(a, kth=k, axis=-1), 3, 50) + ar = _bench(x, lambda a: mx.argsort(a, axis=-1), 3, 50) + print( + f"| b={b} v={v} k={k} | {ap:.3f}ms | {ar:.3f}ms | {ar/ap:.2f}x |" + ) + except Exception as e: + print(f"| b={b} v={v} k={k} | ERR | ERR | {e} |") + + +if __name__ == "__main__": + main() diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 1cee777bbe..eb2db6f5bb 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -39,6 +39,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu + ${CMAKE_CURRENT_SOURCE_DIR}/partition.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cu ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu diff --git a/mlx/backend/cuda/device/radix_select.cuh b/mlx/backend/cuda/device/radix_select.cuh new file mode 100644 index 0000000000..e5e0b6ae0e --- /dev/null +++ b/mlx/backend/cuda/device/radix_select.cuh @@ -0,0 +1,485 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include +#include +#include "mlx/backend/cuda/device/utils.cuh" + +namespace mlx::core::cu { + +/////////////////////////////////////////////////////////////////////////////// +// Radix Select Implementation for CUDA +// +// Multi-pass radix-based selection algorithm for partition operations. 
+// Uses IEEE 754 bit manipulation for correct floating-point ordering.
+///////////////////////////////////////////////////////////////////////////////
+
+// Radix configuration used by the small shared-memory kernel.
+constexpr int RADIX_BITS = 5;
+constexpr int RADIX_SIZE = 1 << RADIX_BITS; // 32 bins
+
+///////////////////////////////////////////////////////////////////////////////
+// Bit manipulation for radix sorting
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename T>
+struct RadixTraits;
+
+template <>
+struct RadixTraits<float> {
+  using UnsignedT = uint32_t;
+  static constexpr int BITS = 32;
+
+  __device__ __forceinline__ static UnsignedT to_radix(float val) {
+    if (cuda::std::isnan(val)) {
+      return ~UnsignedT(0);
+    }
+    uint32_t bits = __float_as_uint(val);
+    if ((bits << 1) == 0) {
+      bits = 0; // Canonicalize +/-0.0 to +0.0 for stable equal-value ties.
+    }
+    uint32_t mask = -int32_t(bits >> 31) | 0x80000000u;
+    return bits ^ mask;
+  }
+};
+
+template <>
+struct RadixTraits<double> {
+  using UnsignedT = uint64_t;
+  static constexpr int BITS = 64;
+
+  __device__ __forceinline__ static UnsignedT to_radix(double val) {
+    if (cuda::std::isnan(val)) {
+      return ~UnsignedT(0);
+    }
+    uint64_t bits = __double_as_longlong(val);
+    if ((bits << 1) == 0) {
+      bits = 0; // Canonicalize +/-0.0 to +0.0 for stable equal-value ties.
+    }
+    uint64_t mask = -int64_t(bits >> 63) | 0x8000000000000000ull;
+    return bits ^ mask;
+  }
+};
+
+template <>
+struct RadixTraits<complex64_t> {
+  using UnsignedT = uint64_t;
+  static constexpr int BITS = 64;
+
+  __device__ __forceinline__ static UnsignedT to_radix(complex64_t val) {
+    float real = val.real();
+    float imag = val.imag();
+    if (cuda::std::isnan(real) || cuda::std::isnan(imag)) {
+      return ~UnsignedT(0);
+    }
+
+    auto real_key = RadixTraits<float>::to_radix(real);
+    auto imag_key = RadixTraits<float>::to_radix(imag);
+    return (static_cast<uint64_t>(real_key) << 32) |
+        static_cast<uint64_t>(imag_key);
+  }
+};
+
+template <>
+struct RadixTraits<__half> {
+  using UnsignedT = uint16_t;
+  static constexpr int BITS = 16;
+
+  __device__ __forceinline__ static UnsignedT to_radix(__half val) {
+    if (cuda::std::isnan(val)) {
+      return ~UnsignedT(0);
+    }
+    uint16_t bits = __half_as_ushort(val);
+    if ((bits & 0x7FFFu) == 0) {
+      bits = 0; // Canonicalize +/-0 to +0 for stable equal-value ties.
+    }
+    uint16_t mask = -int16_t(bits >> 15) | 0x8000u;
+    return bits ^ mask;
+  }
+};
+
+template <>
+struct RadixTraits<__nv_bfloat16> {
+  using UnsignedT = uint16_t;
+  static constexpr int BITS = 16;
+
+  __device__ __forceinline__ static UnsignedT to_radix(__nv_bfloat16 val) {
+    if (cuda::std::isnan(val)) {
+      return ~UnsignedT(0);
+    }
+    uint16_t bits = __bfloat16_as_ushort(val);
+    if ((bits & 0x7FFFu) == 0) {
+      bits = 0; // Canonicalize +/-0 to +0 for stable equal-value ties.
+    }
+    uint16_t mask = -int16_t(bits >> 15) | 0x8000u;
+    return bits ^ mask;
+  }
+};
+
+template <>
+struct RadixTraits<int8_t> {
+  using UnsignedT = uint8_t;
+  static constexpr int BITS = 8;
+
+  __device__ __forceinline__ static UnsignedT to_radix(int8_t val) {
+    return static_cast<uint8_t>(val) ^ 0x80u;
+  }
+};
+
+template <>
+struct RadixTraits<int16_t> {
+  using UnsignedT = uint16_t;
+  static constexpr int BITS = 16;
+
+  __device__ __forceinline__ static UnsignedT to_radix(int16_t val) {
+    return static_cast<uint16_t>(val) ^ 0x8000u;
+  }
+};
+
+template <>
+struct RadixTraits<int32_t> {
+  using UnsignedT = uint32_t;
+  static constexpr int BITS = 32;
+
+  __device__ __forceinline__ static UnsignedT to_radix(int32_t val) {
+    return static_cast<uint32_t>(val) ^ 0x80000000u;
+  }
+};
+
+template <>
+struct RadixTraits<int64_t> {
+  using UnsignedT = uint64_t;
+  static constexpr int BITS = 64;
+
+  __device__ __forceinline__ static UnsignedT to_radix(int64_t val) {
+    return static_cast<uint64_t>(val) ^ 0x8000000000000000ull;
+  }
+};
+
+template <>
+struct RadixTraits<bool> {
+  using UnsignedT = uint8_t;
+  static constexpr int BITS = 8;
+
+  __device__ __forceinline__ static UnsignedT to_radix(bool val) {
+    return static_cast<uint8_t>(val);
+  }
+};
+
+template <>
+struct RadixTraits<uint8_t> {
+  using UnsignedT = uint8_t;
+  static constexpr int BITS = 8;
+
+  __device__ __forceinline__ static UnsignedT to_radix(uint8_t val) {
+    return val;
+  }
+};
+
+template <>
+struct RadixTraits<uint16_t> {
+  using UnsignedT = uint16_t;
+  static constexpr int BITS = 16;
+
+  __device__ __forceinline__ static UnsignedT to_radix(uint16_t val) {
+    return val;
+  }
+};
+
+template <>
+struct RadixTraits<uint32_t> {
+  using UnsignedT = uint32_t;
+  static constexpr int BITS = 32;
+
+  __device__ __forceinline__ static UnsignedT to_radix(uint32_t val) {
+    return val;
+  }
+};
+
+template <>
+struct RadixTraits<uint64_t> {
+  using UnsignedT = uint64_t;
+  static constexpr int BITS = 64;
+
+  __device__ __forceinline__ static UnsignedT to_radix(uint64_t val) {
+    return val;
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Block-level utilities
+///////////////////////////////////////////////////////////////////////////////
+
+namespace cg = cooperative_groups;
+
+template <int BLOCK_THREADS>
+__device__ __forceinline__ int block_exclusive_scan(
+    cg::thread_block& block,
+    int val,
+    int* shared_warp_sums,
+    int* block_total = nullptr) {
+  static_assert(BLOCK_THREADS % WARP_SIZE == 0);
+  constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE;
+
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+  int inclusive = cg::inclusive_scan(warp, val, cg::plus<int>());
+
+  if (warp.thread_rank() == WARP_SIZE - 1) {
+    shared_warp_sums[warp.meta_group_rank()] = inclusive;
+  }
+  block.sync();
+
+  if (warp.meta_group_rank() == 0) {
+    int warp_val = warp.thread_rank() < NUM_WARPS
+        ? shared_warp_sums[warp.thread_rank()]
+        : 0;
+    int warp_scan = cg::inclusive_scan(warp, warp_val, cg::plus<int>());
+
+    if (warp.thread_rank() < NUM_WARPS) {
+      shared_warp_sums[warp.thread_rank()] =
+          warp_scan - shared_warp_sums[warp.thread_rank()];
+    }
+    if (block_total != nullptr && warp.thread_rank() == NUM_WARPS - 1) {
+      *block_total = warp_scan;
+    }
+  }
+  block.sync();
+
+  return shared_warp_sums[warp.meta_group_rank()] + inclusive - val;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Radix Select for small arrays (fits in shared memory)
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename ValT,
+    typename OutT,
+    bool ARG_PARTITION,
+    bool USE_SIMPLE_STRIDE,
+    int BLOCK_THREADS,
+    int ITEMS_PER_THREAD>
+__global__ void radix_select_small_kernel(
+    const ValT* input,
+    OutT* output,
+    int kth,
+    int n,
+    int64_t in_stride,
+    int64_t out_stride,
+    int64_t in_segment_stride,
+    int64_t out_segment_stride,
+    const __grid_constant__ Shape nc_shape,
+    const __grid_constant__ Strides in_nc_strides,
+    const __grid_constant__ Strides out_nc_strides,
+    int nc_dim) {
+  using Traits =
RadixTraits<ValT>;
+  using UnsignedT = typename Traits::UnsignedT;
+
+  constexpr int TILE_SIZE = BLOCK_THREADS * ITEMS_PER_THREAD;
+  constexpr int NUM_PASSES = (Traits::BITS + RADIX_BITS - 1) / RADIX_BITS;
+  constexpr int NUM_WARPS = BLOCK_THREADS / WARP_SIZE;
+
+  auto block = cg::this_thread_block();
+  auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+  __shared__ UnsignedT shared_keys[TILE_SIZE];
+  __shared__ int shared_hist[RADIX_SIZE];
+  __shared__ int shared_count[2];
+  __shared__ int warp_less[NUM_WARPS];
+  __shared__ int warp_equal[NUM_WARPS];
+  __shared__ int warp_greater[NUM_WARPS];
+  __shared__ int iter_counts[3];
+  __shared__ int running_bases[3];
+
+  int row = blockIdx.y;
+
+  const ValT* row_input;
+  OutT* row_output;
+  if constexpr (USE_SIMPLE_STRIDE) {
+    row_input = input + row * in_segment_stride;
+    row_output = output + row * out_segment_stride;
+  } else {
+    int64_t in_block_idx = elem_to_loc(
+        int64_t(row), nc_shape.data(), in_nc_strides.data(), nc_dim);
+    int64_t out_block_idx = elem_to_loc(
+        int64_t(row), nc_shape.data(), out_nc_strides.data(), nc_dim);
+    row_input = input + in_block_idx;
+    row_output = output + out_block_idx;
+  }
+
+  int tile_n = min(n, TILE_SIZE);
+
+  for (int i = block.thread_rank(); i < TILE_SIZE; i += BLOCK_THREADS) {
+    shared_keys[i] = (i < tile_n) ?
Traits::to_radix(row_input[i * in_stride]) + : ~UnsignedT(0); + } + block.sync(); + + int k = kth + 1; + UnsignedT target_prefix = 0; + UnsignedT prefix_mask = 0; + + for (int pass = NUM_PASSES - 1; pass >= 0; pass--) { + int start_bit = pass * RADIX_BITS; + + for (int i = block.thread_rank(); i < RADIX_SIZE; i += BLOCK_THREADS) { + shared_hist[i] = 0; + } + block.sync(); + + for (int i = block.thread_rank(); i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if ((key & prefix_mask) == target_prefix) { + int digit = (key >> start_bit) & ((1 << RADIX_BITS) - 1); + atomicAdd(&shared_hist[digit], 1); + } + } + block.sync(); + + // Find target bin (single thread) + if (block.thread_rank() == 0) { + int cumsum = 0; + int target_bin = 0; + for (int bin = 0; bin < RADIX_SIZE; bin++) { + int count = shared_hist[bin]; + if (cumsum + count >= k) { + target_bin = bin; + k = k - cumsum; + break; + } + cumsum += count; + } + shared_count[0] = target_bin; + shared_count[1] = k; + } + block.sync(); + + int target_bin = shared_count[0]; + k = shared_count[1]; + + UnsignedT digit_mask = UnsignedT((1 << RADIX_BITS) - 1) << start_bit; + target_prefix |= UnsignedT(target_bin) << start_bit; + prefix_mask |= digit_mask; + + block.sync(); + } + + // Count per-thread bucket sizes once, then scatter in a single pass with + // deterministic per-thread offsets. + int local_less = 0; + int local_equal = 0; + for (int i = block.thread_rank(); i < tile_n; i += BLOCK_THREADS) { + UnsignedT key = shared_keys[i]; + if (key < target_prefix) { + local_less++; + } else if (key == target_prefix) { + local_equal++; + } + } + + (void)block_exclusive_scan( + block, local_less, shared_hist, &shared_count[0]); + (void)block_exclusive_scan( + block, local_equal, shared_hist, &shared_count[1]); + + int less_count = shared_count[0]; + int equal_count = shared_count[1]; + + // Scatter in increasing i order to keep tie behavior aligned with merge sort. 
+ if (block.thread_rank() == 0) { + running_bases[0] = 0; + running_bases[1] = less_count; + running_bases[2] = less_count + equal_count; + } + block.sync(); + + for (int base_i = 0; base_i < tile_n; base_i += BLOCK_THREADS) { + int i = base_i + block.thread_rank(); + bool active = i < tile_n; + + UnsignedT key = 0; + if (active) { + key = shared_keys[i]; + } + + bool is_less = active && (key < target_prefix); + bool is_equal = active && (key == target_prefix); + bool is_greater = active && !is_less && !is_equal; + + unsigned less_ballot = warp.ballot(is_less); + unsigned equal_ballot = warp.ballot(is_equal); + unsigned greater_ballot = warp.ballot(is_greater); + + unsigned lane_mask = (1u << warp.thread_rank()) - 1u; + int less_rank = __popc(less_ballot & lane_mask); + int equal_rank = __popc(equal_ballot & lane_mask); + int greater_rank = __popc(greater_ballot & lane_mask); + + if (warp.thread_rank() == 0) { + warp_less[warp.meta_group_rank()] = __popc(less_ballot); + warp_equal[warp.meta_group_rank()] = __popc(equal_ballot); + warp_greater[warp.meta_group_rank()] = __popc(greater_ballot); + } + block.sync(); + + if (block.thread_rank() == 0) { + int run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_less[w]; + warp_less[w] = run; + run += c; + } + iter_counts[0] = run; + + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_equal[w]; + warp_equal[w] = run; + run += c; + } + iter_counts[1] = run; + + run = 0; + for (int w = 0; w < NUM_WARPS; ++w) { + int c = warp_greater[w]; + warp_greater[w] = run; + run += c; + } + iter_counts[2] = run; + } + block.sync(); + + if (active) { + int pos; + if (is_less) { + pos = running_bases[0] + warp_less[warp.meta_group_rank()] + less_rank; + } else if (is_equal) { + pos = + running_bases[1] + warp_equal[warp.meta_group_rank()] + equal_rank; + } else { + pos = running_bases[2] + warp_greater[warp.meta_group_rank()] + + greater_rank; + } + + if constexpr (ARG_PARTITION) { + row_output[pos * out_stride] = 
i; + } else { + row_output[pos * out_stride] = row_input[i * in_stride]; + } + } + block.sync(); + + if (block.thread_rank() == 0) { + running_bases[0] += iter_counts[0]; + running_bases[1] += iter_counts[1]; + running_bases[2] += iter_counts[2]; + } + block.sync(); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/partition.cu b/mlx/backend/cuda/partition.cu new file mode 100644 index 0000000000..0032c1abb5 --- /dev/null +++ b/mlx/backend/cuda/partition.cu @@ -0,0 +1,217 @@ +// Copyright © 2026 Apple Inc. + +#include +#include +#include + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/radix_select.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype.h" +#include "mlx/dtype_utils.h" + +namespace mlx::core { + +namespace { + +// Upper bound for small-kernel tiling. Keep this aligned with the +// items-per-thread dispatch set and per-block shared-memory budget. +constexpr int MAX_RADIX_ITEMS_PER_THREAD = 64; +constexpr size_t RADIX_SMALL_SHARED_MEM_BUDGET_BYTES = 48 * 1024; + +template +void dispatch_radix_small_block_threads(int size_sorted_axis, F&& f) { + if (size_sorted_axis <= 256) { + f(std::integral_constant{}); + } else if (size_sorted_axis <= 512) { + f(std::integral_constant{}); + } else if (size_sorted_axis <= 1024) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); + } +} + +template +void dispatch_radix_items_per_thread( + int size_sorted_axis, + int block_threads, + F&& f) { + int items_per_thread = (size_sorted_axis + block_threads - 1) / block_threads; + if (items_per_thread <= 1) { + f(std::integral_constant{}); + } else if (items_per_thread <= 2) { + f(std::integral_constant{}); + } else if (items_per_thread <= 4) { + f(std::integral_constant{}); + } else if (items_per_thread <= 8) { + f(std::integral_constant{}); + } else if (items_per_thread <= 12) { + f(std::integral_constant{}); + } else if (items_per_thread <= 16) { + f(std::integral_constant{}); + } else if 
(items_per_thread <= 24) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); + } +} + +constexpr size_t radix_small_shared_mem_bytes( + size_t key_size, + size_t block_threads, + size_t items_per_thread) { + size_t tile_size = block_threads * items_per_thread; + size_t num_warps = block_threads / WARP_SIZE; + return tile_size * key_size + // shared_keys + cu::RADIX_SIZE * sizeof(int) + // shared_hist + (2 + 3 * num_warps + 6) * sizeof(int); // shared_count + scatter scratch +} + +} // namespace + +bool gpu_partition_small_fits(Dtype dtype, int size_sorted_axis) { + if (size_sorted_axis <= 0) { + return false; + } + + bool fits = false; + dispatch_radix_small_block_threads(size_sorted_axis, [&](auto block_dim_tag) { + constexpr int BLOCK_THREADS = block_dim_tag(); + int required_items = (size_sorted_axis + BLOCK_THREADS - 1) / BLOCK_THREADS; + if (required_items > MAX_RADIX_ITEMS_PER_THREAD) { + fits = false; + return; + } + + dispatch_radix_items_per_thread( + size_sorted_axis, BLOCK_THREADS, [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + fits = radix_small_shared_mem_bytes( + size_of(dtype), BLOCK_THREADS, ITEMS_PER_THREAD) <= + RADIX_SMALL_SHARED_MEM_BUDGET_BYTES; + }); + }); + return fits; +} + +void gpu_partition_small( + const Stream& s, + const array& in, + array& out, + int axis, + int kth, + bool arg_partition) { + int n_rows = in.size() / in.shape(axis); + + auto in_nc_str = in.strides(); + in_nc_str.erase(in_nc_str.begin() + axis); + + auto out_nc_str = out.strides(); + out_nc_str.erase(out_nc_str.begin() + axis); + + auto nc_shape = in.shape(); + nc_shape.erase(nc_shape.begin() + axis); + + int nc_dim = nc_shape.size(); + + int size_sorted_axis = in.shape(axis); + int64_t in_stride_sorted_axis = in.strides()[axis]; + int64_t out_stride_sorted_axis = out.strides()[axis]; + + bool contiguous = in.flags().contiguous; + auto check_strides = [](const array& x, int64_t sort_stride) { + 
int64_t min_stride = + *std::min_element(x.strides().begin(), x.strides().end()); + int64_t max_stride = + *std::max_element(x.strides().begin(), x.strides().end()); + return sort_stride == min_stride || sort_stride == max_stride; + }; + contiguous &= check_strides(in, in_stride_sorted_axis); + contiguous &= check_strides(out, out_stride_sorted_axis); + + auto& encoder = cu::get_command_encoder(s); + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + encoder.set_input_array(in); + encoder.set_output_array(out); + + auto nc_shape_param = const_param(nc_shape); + auto in_nc_strides_param = const_param(in_nc_str); + auto out_nc_strides_param = const_param(out_nc_str); + + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); + using ValT = cuda_type_t; + + dispatch_bool(arg_partition, [&](auto arg_tag) { + constexpr bool ARG_PARTITION = decltype(arg_tag)::value; + using OutT = std::conditional_t; + + int64_t in_stride_segment_axis = INT64_MAX; + int64_t out_stride_segment_axis = INT64_MAX; + if (contiguous) { + for (size_t i = 0; i < nc_shape.size(); i++) { + if (nc_shape[i] == 1) { + continue; + } + in_stride_segment_axis = + std::min(in_stride_segment_axis, in_nc_str[i]); + out_stride_segment_axis = + std::min(out_stride_segment_axis, out_nc_str[i]); + } + } + + dispatch_radix_small_block_threads( + size_sorted_axis, [&](auto block_dim_tag) { + constexpr int BLOCK_THREADS = block_dim_tag(); + dim3 grid(1, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_radix_items_per_thread( + size_sorted_axis, + BLOCK_THREADS, + [&](auto items_per_thread_tag) { + constexpr int ITEMS_PER_THREAD = items_per_thread_tag(); + + constexpr size_t SMEM = radix_small_shared_mem_bytes( + sizeof(typename cu::RadixTraits::UnsignedT), + BLOCK_THREADS, + ITEMS_PER_THREAD); + if constexpr (SMEM <= RADIX_SMALL_SHARED_MEM_BUDGET_BYTES) { + dispatch_bool(contiguous, [&](auto contiguous_tag) { + constexpr bool USE_SIMPLE_STRIDE = + 
decltype(contiguous_tag)::value; + + auto kernel = cu::radix_select_small_kernel< + ValT, + OutT, + ARG_PARTITION, + USE_SIMPLE_STRIDE, + BLOCK_THREADS, + ITEMS_PER_THREAD>; + + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(in), + gpu_ptr(out), + kth, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); + }); + } + }); + }); + }); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 43756f7078..271f0841c8 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -718,6 +718,16 @@ __global__ void mb_block_merge_kernel( } // namespace cu +bool gpu_partition_small_fits(Dtype dtype, int size_sorted_axis); + +void gpu_partition_small( + const Stream& s, + const array& in, + array& out, + int axis, + int kth, + bool arg_partition); + namespace { void single_block_sort( @@ -1064,14 +1074,42 @@ void Sort::eval_gpu(const std::vector& inputs, array& out) { gpu_sort(stream(), inputs[0], out, axis_, false); } +namespace { + +void gpu_partition( + const Stream& s, + const array& in, + array& out, + int axis_, + int kth_, + bool arg_partition) { + int axis = axis_ < 0 ? axis_ + in.ndim() : axis_; + int size_sorted_axis = in.shape(axis); + int kth = kth_ < 0 ? kth_ + size_sorted_axis : kth_; + + // Fixed-size const_param metadata is capped by MAX_NDIM. + if (in.ndim() > MAX_NDIM) { + return gpu_merge_sort(s, in, out, axis, arg_partition); + } + + // Dispatch based on whether the small kernel tile fits in shared memory. 
+  if (gpu_partition_small_fits(in.dtype(), size_sorted_axis)) {
+    return gpu_partition_small(s, in, out, axis, kth, arg_partition);
+  } else {
+    return gpu_merge_sort(s, in, out, axis, arg_partition);
+  }
+}
+
+} // namespace
+
 void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("ArgPartition::eval_gpu");
-  gpu_sort(stream(), inputs[0], out, axis_, true);
+  gpu_partition(stream(), inputs[0], out, axis_, kth_, true);
 }
 
 void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
   nvtx3::scoped_range r("Partition::eval_gpu");
-  gpu_sort(stream(), inputs[0], out, axis_, false);
+  gpu_partition(stream(), inputs[0], out, axis_, kth_, false);
 }
-} // namespace mlx::core
\ No newline at end of file
+} // namespace mlx::core