From df9a4996088aedd25b37ac472f3fd0414cc05bb6 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 20 Mar 2026 21:37:38 +0800 Subject: [PATCH 1/2] [Metal] Support sorting complex numbers --- mlx/backend/metal/kernels/complex.h | 4 ++++ mlx/backend/metal/kernels/sort.h | 7 +++++++ mlx/backend/metal/kernels/sort.metal | 4 +++- python/tests/test_ops.py | 12 ++++-------- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index 6e391483d3..419d9ee0a5 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -97,6 +97,10 @@ constexpr bool operator<(complex64_t a, complex64_t b) { return operator>(b, a); } +constexpr bool isnan(complex64_t x) { + return isnan(x.real) || isnan(x.imag); +} + constexpr bool operator==(complex64_t a, complex64_t b) { return a.real == b.real && a.imag == b.imag; } diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index 0d357333ac..c8de634d5b 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -29,6 +29,13 @@ struct Init>> { static constexpr constant T v = metal::numeric_limits::quiet_NaN(); }; +template <> +struct Init { + static constexpr constant complex64_t v = complex64_t( + metal::numeric_limits::quiet_NaN(), + metal::numeric_limits::quiet_NaN()); +}; + template struct LessThan { static constexpr constant T init = Init::v; diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index 7f198d55d3..d9bf20476f 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -50,6 +50,7 @@ instantiate_block_sort_bn(bfloat16, bfloat16_t) instantiate_block_sort_long(uint64, uint64_t) instantiate_block_sort_long(int64, int64_t) +instantiate_block_sort_long(complex64, complex64_t) #define instantiate_multi_block_sort( \ vtname, vtype, itname, itype, arg_sort, bn, tn) \ @@ -77,4 +78,5 @@ instantiate_multi_block_sort_base(bfloat16, bfloat16_t) instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 4) instantiate_multi_block_sort_long(uint64, uint64_t) -instantiate_multi_block_sort_long(int64, int64_t) // clang-format on +instantiate_multi_block_sort_long(int64, int64_t) +instantiate_multi_block_sort_long(complex64, complex64_t) // clang-format on diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 4aff1eaddf..12b5d3f394 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2192,11 +2192,8 @@ def test_squeeze_expand(self): def test_sort(self): shape = (6, 4, 10) - dtypes = ["int32", "float32"] - if not mx.metal.is_available(): - dtypes.append("complex64") tests = product( - dtypes, # type + ("int32", "float32", "complex64"), # type (None, 0, 1, 2), # axis (True, False), # strided ) @@ -3326,10 +3323,9 @@ def test_sort_nan(self): expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype) self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True)) - if not mx.metal.is_available(): - x = mx.array([3.0 + 1j, mx.nan + 2j, 2.0 + 1j, 0.0 + 1j]) - expected = mx.array([0.0 + 1j, 2.0 + 1j, 3.0 + 1j, mx.nan + 2j]) - self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True)) + x = mx.array([3.0 + 1j, mx.nan + 2j, 2.0 + 1j, 0.0 + 1j]) + expected = mx.array([0.0 + 1j, 2.0 + 1j, 3.0 + 1j, mx.nan + 2j]) + self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True)) def test_argsort_nan(self): for dtype in [mx.float32, mx.float16, mx.bfloat16]: From 894427377c0a68497f97cd1715d59ff61efd6083 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 25 Mar 2026 20:28:16 +0800 Subject: [PATCH 2/2] inline complex isnan --- mlx/backend/metal/kernels/complex.h | 4 ---- mlx/backend/metal/kernels/sort.h | 13 +++++++++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h index 419d9ee0a5..6e391483d3 100644 --- a/mlx/backend/metal/kernels/complex.h +++ b/mlx/backend/metal/kernels/complex.h @@ -97,10 +97,6 @@ constexpr bool operator<(complex64_t a, complex64_t b) { return operator>(b, a); } -constexpr bool isnan(complex64_t x) { - return isnan(x.real) || isnan(x.imag); -} - constexpr bool operator==(complex64_t a, complex64_t b) { return a.real == b.real && a.imag == b.imag; } diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index c8de634d5b..7d7fab0f9a 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -40,10 +40,15 @@ template struct LessThan { static constexpr constant T init = Init::v; METAL_FUNC bool operator()(T a, T b) const { - if constexpr ( - metal::is_floating_point_v || metal::is_same_v) { - bool an = isnan(a); - bool bn = isnan(b); + if constexpr (metal::is_floating_point_v) { + bool an = metal::isnan(a); + bool bn = metal::isnan(b); + if (an | bn) { + return (!an) & bn; + } + } else if constexpr (metal::is_same_v) { + bool an = metal::isnan(a.real) || metal::isnan(a.imag); + bool bn = metal::isnan(b.real) || metal::isnan(b.imag); if (an | bn) { return (!an) & bn; }