diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index 0d357333ac..7d7fab0f9a 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -29,14 +29,26 @@ struct Init>> { static constexpr constant T v = metal::numeric_limits::quiet_NaN(); }; +template <> +struct Init { + static constexpr constant complex64_t v = complex64_t( + metal::numeric_limits::quiet_NaN(), + metal::numeric_limits::quiet_NaN()); +}; + template struct LessThan { static constexpr constant T init = Init::v; METAL_FUNC bool operator()(T a, T b) const { - if constexpr ( - metal::is_floating_point_v || metal::is_same_v) { - bool an = isnan(a); - bool bn = isnan(b); + if constexpr (metal::is_floating_point_v) { + bool an = metal::isnan(a); + bool bn = metal::isnan(b); + if (an | bn) { + return (!an) & bn; + } + } else if constexpr (metal::is_same_v) { + bool an = metal::isnan(a.real) || metal::isnan(a.imag); + bool bn = metal::isnan(b.real) || metal::isnan(b.imag); if (an | bn) { return (!an) & bn; } diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal index 7f198d55d3..d9bf20476f 100644 --- a/mlx/backend/metal/kernels/sort.metal +++ b/mlx/backend/metal/kernels/sort.metal @@ -50,6 +50,7 @@ instantiate_block_sort_bn(bfloat16, bfloat16_t) instantiate_block_sort_long(uint64, uint64_t) instantiate_block_sort_long(int64, int64_t) +instantiate_block_sort_long(complex64, complex64_t) #define instantiate_multi_block_sort( \ vtname, vtype, itname, itype, arg_sort, bn, tn) \ @@ -77,4 +78,5 @@ instantiate_multi_block_sort_base(bfloat16, bfloat16_t) instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 4) instantiate_multi_block_sort_long(uint64, uint64_t) -instantiate_multi_block_sort_long(int64, int64_t) // clang-format on +instantiate_multi_block_sort_long(int64, int64_t) +instantiate_multi_block_sort_long(complex64, complex64_t) // clang-format on diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 4aff1eaddf..12b5d3f394 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2192,11 +2192,8 @@ def test_squeeze_expand(self): def test_sort(self): shape = (6, 4, 10) - dtypes = ["int32", "float32"] - if not mx.metal.is_available(): - dtypes.append("complex64") tests = product( - dtypes, # type + ("int32", "float32", "complex64"), # type (None, 0, 1, 2), # axis (True, False), # strided ) @@ -3326,10 +3323,9 @@ def test_sort_nan(self): expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype) self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True)) - if not mx.metal.is_available(): - x = mx.array([3.0 + 1j, mx.nan + 2j, 2.0 + 1j, 0.0 + 1j]) - expected = mx.array([0.0 + 1j, 2.0 + 1j, 3.0 + 1j, mx.nan + 2j]) - self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True)) + x = mx.array([3.0 + 1j, mx.nan + 2j, 2.0 + 1j, 0.0 + 1j]) + expected = mx.array([0.0 + 1j, 2.0 + 1j, 3.0 + 1j, mx.nan + 2j]) + self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True)) def test_argsort_nan(self): for dtype in [mx.float32, mx.float16, mx.bfloat16]: