diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 655462a2cf..43756f7078 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -44,6 +44,12 @@ __device__ __forceinline__ __nv_bfloat16 nan_value<__nv_bfloat16>() { return __float2bfloat16(cuda::std::numeric_limits::quiet_NaN()); } +template <> +__device__ __forceinline__ complex64_t nan_value() { + float qnan = cuda::std::numeric_limits::quiet_NaN(); + return complex64_t{qnan, qnan}; +} + template struct InitValue { __device__ __forceinline__ static T value() { @@ -52,7 +58,9 @@ struct InitValue { }; template -struct InitValue>> { +struct InitValue< + T, + cuda::std::enable_if_t || cu::is_complex_v>> { __device__ __forceinline__ static T value() { return nan_value(); } @@ -65,6 +73,15 @@ __device__ __forceinline__ void thread_swap(T& a, T& b) { b = w; } +template +__device__ __forceinline__ bool check_nan(T a) { + if constexpr (cu::is_complex_v) { + return cuda::std::isnan(a.real()) || cuda::std::isnan(a.imag()); + } else { + return cuda::std::isnan(a); + } +} + template struct LessThan { __device__ __forceinline__ static T init() { @@ -72,9 +89,9 @@ struct LessThan { } __device__ __forceinline__ bool operator()(T a, T b) const { - if constexpr (is_floating_v) { - bool an = cuda::std::isnan(a); - bool bn = cuda::std::isnan(b); + if constexpr (is_floating_v || cu::is_complex_v) { + bool an = check_nan(a); + bool bn = check_nan(b); if (an | bn) { return (!an) & bn; } @@ -745,82 +762,76 @@ void single_block_sort( dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using ValT = cuda_type_t; - dispatch_block_dim(bn, [&](auto block_dim) { - constexpr int BLOCK_THREADS = block_dim(); - if constexpr (BLOCK_THREADS < 1024) { - dim3 grid(1, n_rows, 1); - dim3 block(BLOCK_THREADS, 1, 1); - - dispatch_bool(argsort, [&](auto arg_tag) { - constexpr bool ARG_SORT = decltype(arg_tag)::value; - using OutT = std::conditional_t; - - if (contiguous) { - auto kernel = cu::block_sort_kernel< - ValT, - OutT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD>; - int64_t in_stride_segment_axis = INT64_MAX; - int64_t out_stride_segment_axis = INT64_MAX; - for (int i = 0; i < nc_shape.size(); i++) { - if (nc_shape[i] == 1) { - continue; - } - if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) { - throw std::runtime_error( - "[Sort::eval_gpu] Stride too large."); - } - in_stride_segment_axis = - std::min(in_stride_segment_axis, in_nc_str[i]); - out_stride_segment_axis = - std::min(out_stride_segment_axis, out_nc_str[i]); + using ValT = cuda_type_t; + dispatch_block_dim(bn, [&](auto block_dim) { + constexpr int BLOCK_THREADS = block_dim(); + if constexpr (BLOCK_THREADS < 1024) { + dim3 grid(1, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_bool(argsort, [&](auto arg_tag) { + constexpr bool ARG_SORT = decltype(arg_tag)::value; + using OutT = std::conditional_t; + + if (contiguous) { + auto kernel = cu::block_sort_kernel< + ValT, + OutT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + int64_t in_stride_segment_axis = INT64_MAX; + int64_t out_stride_segment_axis = INT64_MAX; + for (int i = 0; i < nc_shape.size(); i++) { + if (nc_shape[i] == 1) { + continue; } - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(in), - gpu_ptr(out), - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis); - } else { - auto kernel = cu::block_sort_nc_kernel< - ValT, - OutT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD>; - auto nc_shape_param = const_param(nc_shape); - auto in_nc_strides_param = const_param(in_nc_str); - auto out_nc_strides_param = const_param(out_nc_str); - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(in), - gpu_ptr(out), - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - nc_shape_param, - in_nc_strides_param, - out_nc_strides_param, - nc_dim); + if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) { + throw std::runtime_error("[Sort::eval_gpu] Stride too large."); + } + in_stride_segment_axis = + std::min(in_stride_segment_axis, in_nc_str[i]); + out_stride_segment_axis = + std::min(out_stride_segment_axis, out_nc_str[i]); } - }); - } - }); - } else { - throw std::runtime_error( - "CUDA backend does not support sorting complex numbers"); - } + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(in), + gpu_ptr(out), + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis); + } else { + auto kernel = cu::block_sort_nc_kernel< + ValT, + OutT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + auto nc_shape_param = const_param(nc_shape); + auto in_nc_strides_param = const_param(in_nc_str); + auto out_nc_strides_param = const_param(out_nc_str); + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(in), + gpu_ptr(out), + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + nc_shape_param, + in_nc_strides_param, + out_nc_strides_param, + nc_dim); + } + }); + } + }); }); } @@ -870,99 +881,94 @@ void multi_block_sort( dispatch_all_types(in.dtype(), [&](auto type_tag) { using CTYPE = MLX_GET_TYPE(type_tag); - if constexpr (!std::is_same_v) { - using ValT = cuda_type_t; - using IdxT = uint32_t; - constexpr int BLOCK_THREADS = sizeof(ValT) == 8 ? 256 : 512; - dim3 grid(n_blocks, n_rows, 1); - dim3 block(BLOCK_THREADS, 1, 1); - - dispatch_bool(argsort, [&](auto arg_tag) { - constexpr bool ARG_SORT = decltype(arg_tag)::value; - auto nc_shape_param = const_param(nc_shape); - auto nc_strides_param = const_param(nc_str); - - auto block_sort_kernel = cu::mb_block_sort_kernel< + using ValT = cuda_type_t; + using IdxT = uint32_t; + constexpr int BLOCK_THREADS = sizeof(ValT) == 8 ? 256 : 512; + dim3 grid(n_blocks, n_rows, 1); + dim3 block(BLOCK_THREADS, 1, 1); + + dispatch_bool(argsort, [&](auto arg_tag) { + constexpr bool ARG_SORT = decltype(arg_tag)::value; + auto nc_shape_param = const_param(nc_shape); + auto nc_strides_param = const_param(nc_str); + + auto block_sort_kernel = cu::mb_block_sort_kernel< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + encoder.set_input_array(in); + encoder.set_output_array(dev_vals_in); + encoder.set_output_array(dev_idxs_in); + encoder.add_kernel_node( + block_sort_kernel, + grid, + block, + gpu_ptr(in), + gpu_ptr(dev_vals_in), + gpu_ptr(dev_idxs_in), + size_sorted_axis, + stride_sorted_axis, + nc_shape_param, + nc_strides_param, + nc_dim); + + int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024; + + for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; + merge_tiles *= 2) { + auto partition_kernel = cu::mb_block_partition_kernel< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + encoder.set_input_array(dev_vals_in); + encoder.set_input_array(dev_idxs_in); + encoder.set_output_array(block_partitions); + + encoder.add_kernel_node( + partition_kernel, + dim3(1, n_rows, 1), + dim3(n_thr_per_group, 1, 1), + gpu_ptr(block_partitions), + gpu_ptr(dev_vals_in), + gpu_ptr(dev_idxs_in), + size_sorted_axis, + merge_tiles, + n_blocks); + + auto merge_kernel = cu::mb_block_merge_kernel< ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>; - encoder.set_input_array(in); - encoder.set_output_array(dev_vals_in); - encoder.set_output_array(dev_idxs_in); + + encoder.set_input_array(dev_vals_in); + encoder.set_input_array(dev_idxs_in); + encoder.set_input_array(block_partitions); + encoder.set_output_array(dev_vals_out); + encoder.set_output_array(dev_idxs_out); + encoder.add_kernel_node( - block_sort_kernel, - grid, - block, - gpu_ptr(in), + merge_kernel, + dim3(n_blocks, n_rows, 1), + dim3(BLOCK_THREADS, 1, 1), + gpu_ptr(block_partitions), gpu_ptr(dev_vals_in), gpu_ptr(dev_idxs_in), + gpu_ptr(dev_vals_out), + gpu_ptr(dev_idxs_out), size_sorted_axis, - stride_sorted_axis, - nc_shape_param, - nc_strides_param, - nc_dim); - - int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024; - - for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; - merge_tiles *= 2) { - auto partition_kernel = cu::mb_block_partition_kernel< - ValT, - IdxT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD>; - - encoder.set_input_array(dev_vals_in); - encoder.set_input_array(dev_idxs_in); - encoder.set_output_array(block_partitions); - - encoder.add_kernel_node( - partition_kernel, - dim3(1, n_rows, 1), - dim3(n_thr_per_group, 1, 1), - gpu_ptr(block_partitions), - gpu_ptr(dev_vals_in), - gpu_ptr(dev_idxs_in), - size_sorted_axis, - merge_tiles, - n_blocks); - - auto merge_kernel = cu::mb_block_merge_kernel< - ValT, - IdxT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD>; - - encoder.set_input_array(dev_vals_in); - encoder.set_input_array(dev_idxs_in); - encoder.set_input_array(block_partitions); - encoder.set_output_array(dev_vals_out); - encoder.set_output_array(dev_idxs_out); - - encoder.add_kernel_node( - merge_kernel, - dim3(n_blocks, n_rows, 1), - dim3(BLOCK_THREADS, 1, 1), - gpu_ptr(block_partitions), - gpu_ptr(dev_vals_in), - gpu_ptr(dev_idxs_in), - gpu_ptr(dev_vals_out), - gpu_ptr(dev_idxs_out), - size_sorted_axis, - merge_tiles, - n_blocks); - std::swap(dev_vals_in, dev_vals_out); - std::swap(dev_idxs_in, dev_idxs_out); - } - }); - } else { - throw std::runtime_error( - "CUDA backend does not support sorting complex numbers"); - } + merge_tiles, + n_blocks); + std::swap(dev_vals_in, dev_vals_out); + std::swap(dev_idxs_in, dev_idxs_out); + } + }); }); encoder.add_temporary(dev_vals_out); diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 57443479a8..4aff1eaddf 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2192,8 +2192,11 @@ def test_squeeze_expand(self): def test_sort(self): shape = (6, 4, 10) + dtypes = ["int32", "float32"] + if not mx.metal.is_available(): + dtypes.append("complex64") tests = product( - ("int32", "float32"), # type + dtypes, # type (None, 0, 1, 2), # axis (True, False), # strided ) @@ -2201,7 +2204,13 @@ def test_sort(self): with self.subTest(dtype=dtype, axis=axis, strided=strided): np.random.seed(0) np_dtype = getattr(np, dtype) - a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype) + if np.issubdtype(np_dtype, np.complexfloating): + a_np = ( + np.random.uniform(0, 100, size=shape) + + 1j * np.random.uniform(0, 100, size=shape) + ).astype(np_dtype) + else: + a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype) a_mx = mx.array(a_np) if strided: a_mx = a_mx[::2, :, ::2] @@ -3317,7 +3326,10 @@ def test_sort_nan(self): expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype) self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True)) - x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4) + if not mx.metal.is_available(): + x = mx.array([3.0 + 1j, mx.nan + 2j, 2.0 + 1j, 0.0 + 1j]) + expected = mx.array([0.0 + 1j, 2.0 + 1j, 3.0 + 1j, mx.nan + 2j]) + self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True)) def test_argsort_nan(self): for dtype in [mx.float32, mx.float16, mx.bfloat16]: