Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mlx/backend/cpu/sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ namespace mlx::core {

namespace {

// Extends std::is_floating_point_v to also accept the MLX half-precision
// types float16_t and bfloat16_t, which the standard trait does not cover.
template <typename T>
inline constexpr bool is_floating_v = std::is_same_v<T, float16_t> ||
    std::is_same_v<T, bfloat16_t> || std::is_floating_point_v<T>;

// NaN-aware comparator that places NaNs at the end
template <typename T>
bool nan_aware_less(T a, T b) {
if constexpr (std::is_floating_point_v<T> || std::is_same_v<T, complex64_t>) {
if constexpr (is_floating_v<T> || std::is_same_v<T, complex64_t>) {
if (std::isnan(a))
return false;
if (std::isnan(b))
Expand Down Expand Up @@ -198,7 +202,7 @@ void argsort(const array& in, array& out, int axis) {
auto v2 = data_ptr[b * in_stride];

// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if constexpr (is_floating_v<T>) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
Expand Down Expand Up @@ -299,7 +303,7 @@ void argpartition(const array& in, array& out, int axis, int kth) {
auto v2 = data_ptr[b * in_stride];

// Handle NaNs (place them at the end)
if (std::is_floating_point<T>::value) {
if constexpr (is_floating_v<T>) {
if (std::isnan(v1))
return false;
if (std::isnan(v2))
Expand Down
3 changes: 2 additions & 1 deletion mlx/backend/cuda/kernel_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ using cuda_type_t = typename CTypeToCudaType<T>::type;
template <typename T>
inline constexpr bool is_floating_v =
cuda::std::is_same_v<T, float> || cuda::std::is_same_v<T, double> ||
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t>;
cuda::std::is_same_v<T, float16_t> || cuda::std::is_same_v<T, bfloat16_t> ||
cuda::std::is_same_v<T, __half> || cuda::std::is_same_v<T, __nv_bfloat16>;

// Type traits for detecting complex numbers.
template <typename T>
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/cuda/sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct InitValue {
};

template <typename T>
struct InitValue<T, cuda::std::enable_if_t<std::is_floating_point_v<T>>> {
struct InitValue<T, cuda::std::enable_if_t<is_floating_v<T>>> {
__device__ __forceinline__ static T value() {
return nan_value<T>();
}
Expand All @@ -72,7 +72,7 @@ struct LessThan {
}

__device__ __forceinline__ bool operator()(T a, T b) const {
if constexpr (std::is_floating_point_v<T>) {
if constexpr (is_floating_v<T>) {
bool an = cuda::std::isnan(a);
bool bn = cuda::std::isnan(b);
if (an | bn) {
Expand Down
18 changes: 15 additions & 3 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3311,11 +3311,23 @@ def test_broadcast_shapes(self):
mx.broadcast_shapes()

def test_sort_nan(self):
x = mx.array([3.0, mx.nan, 2.0, 0.0])
expected = mx.array([0.0, 2.0, 3.0, mx.nan])
self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))
for dtype in [mx.float32, mx.float16, mx.bfloat16]:
with self.subTest(dtype=dtype):
x = mx.array([3.0, mx.nan, 2.0, 0.0], dtype=dtype)
expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype)
self.assertTrue(mx.array_equal(mx.sort(x), expected, equal_nan=True))

x = mx.array([3.0, mx.nan, 2.0, 0.0]) + 1j * mx.array([1.0] * 4)

def test_argsort_nan(self):
for dtype in [mx.float32, mx.float16, mx.bfloat16]:
with self.subTest(dtype=dtype):
x = mx.array([3.0, mx.nan, 2.0, 0.0], dtype=dtype)
expected = mx.array([0.0, 2.0, 3.0, mx.nan], dtype=dtype)
indices = mx.argsort(x)
sorted_x = mx.take(x, indices)
self.assertTrue(mx.array_equal(sorted_x, expected, equal_nan=True))

def test_to_from_fp8(self):
vals = mx.array(
[448, 256, 192, 128, 96, 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 0.015625]
Expand Down
Loading