diff --git a/mlx/backend/cpu/sort.cpp b/mlx/backend/cpu/sort.cpp
index fcf12d7ad7..67b471f6ad 100644
--- a/mlx/backend/cpu/sort.cpp
+++ b/mlx/backend/cpu/sort.cpp
@@ -15,10 +15,14 @@ namespace mlx::core {
 
 namespace {
 
+template <typename T>
+inline constexpr bool is_floating_v = std::is_floating_point_v<T> ||
+    std::is_same_v<T, float16_t> || std::is_same_v<T, bfloat16_t>;
+
 // NaN-aware comparator that places NaNs at the end
 template <typename T>
 bool nan_aware_less(T a, T b) {
-  if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
+  if constexpr (is_floating_v<T> || std::is_same_v<T, complex64_t>) {
     if (std::isnan(a))
       return false;
     if (std::isnan(b))
@@ -198,7 +202,7 @@ void argsort(const array& in, array& out, int axis) {
         auto v2 = data_ptr[b * in_stride];
 
         // Handle NaNs (place them at the end)
-        if (std::is_floating_point<T>::value) {
+        if constexpr (is_floating_v<T>) {
           if (std::isnan(v1))
             return false;
           if (std::isnan(v2))
@@ -299,7 +303,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
         auto v2 = data_ptr[b * in_stride];
 
         // Handle NaNs (place them at the end)
-        if (std::is_floating_point<T>::value) {
+        if constexpr (is_floating_v<T>) {
           if (std::isnan(v1))
             return false;
           if (std::isnan(v2))
diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh
index 1507a0ae13..02868961f5 100644
--- a/mlx/backend/cuda/kernel_utils.cuh
+++ b/mlx/backend/cuda/kernel_utils.cuh
@@ -89,7 +89,8 @@ using cuda_type_t = typename CTypeToCudaType<T>::type;
 template <typename T>
 inline constexpr bool is_floating_v =
     cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
-    cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
+    cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t> ||
+    cuda::std::is_same_v<T, __half> || cuda::std::is_same_v<T, __nv_bfloat16>;
 
 // Type traits for detecting complex numbers.
 template <typename T>
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
index bc4524ae12..655462a2cf 100644
--- a/mlx/backend/cuda/sort.cu
+++ b/mlx/backend/cuda/sort.cu
@@ -52,7 +52,7 @@ struct InitValue {
 };
 
 template <typename T>
-struct InitValue<T, cuda::std::enable_if_t<cuda::std::is_floating_point_v<T>>> {
+struct InitValue<T, cuda::std::enable_if_t<is_floating_v<T>>> {
   __device__ __forceinline__ static T value() {
     return nan_value<T>();
   }
@@ -72,7 +72,7 @@ struct LessThan {
   }
 
   __device__ __forceinline__ bool operator()(T a, T b) const {
-    if constexpr (std::is_floating_point_v<T>) {
+    if constexpr (is_floating_v<T>) {
       bool an = cuda::std::isnan(a);
       bool bn = cuda::std::isnan(b);
       if (an | bn) {
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 790689e7f7..57443479a8 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -3311,11 +3311,23 @@ def test_broadcast_shapes(self):
             mx.broadcast_shapes()
 
     def test_sort_nan(self):
-        x = mx.array([3.0, mx.nan, 2.0, 0.0])
-        expected = mx.array([0.0, 2.0, 3.0, mx.nan])
-        self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))
+        for dtype in [mx.float32, mx.float16, mx.bfloat16]:
+            with self.subTest(dtype=dtype):
+                x = mx.array([3.0, mx.nan, 2.0, 0.0], dtype=dtype)
+                expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype)
+                self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))
+        x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)
 
+    def test_argsort_nan(self):
+        for dtype in [mx.float32, mx.float16, mx.bfloat16]:
+            with self.subTest(dtype=dtype):
+                x = mx.array([3.0, mx.nan, 2.0, 0.0], dtype=dtype)
+                expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype)
+                indices = mx.argsort(x)
+                sorted_x = mx.take(x, indices)
+                self.assertTrue(mx.array_equal(sorted_x, expected, equal_nan=True))
+
     def test_to_from_fp8(self):
         vals = mx.array(
             [448, 256, 192, 128, 96, 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 0.015625]