From efc8a501e4a6dabfb1236fed7f2c89f59b56d9ef Mon Sep 17 00:00:00 2001 From: AN Long Date: Sun, 5 Oct 2025 16:50:30 +0900 Subject: [PATCH 1/3] Check isnan in maximum / minimum with CPU backend --- mlx/backend/cpu/simd/accelerate_simd.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h index c89a104a04..fad1b717b5 100644 --- a/mlx/backend/cpu/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -217,13 +217,21 @@ Simd atan2(Simd a, Simd b) { template Simd maximum(Simd a, Simd b) { - // TODO add isnan + if constexpr (!std::is_integral_v) { + auto a_is_nan = isnan(a); + return select(a_is_nan, a, Simd(asd::max(a.value, b.value))); + } + return asd::max(a.value, b.value); } template Simd minimum(Simd a, Simd b) { - // TODO add isnan + if constexpr (!std::is_integral_v) { + auto a_is_nan = isnan(a); + return select(a_is_nan, a, Simd(asd::min(a.value, b.value))); + } + return asd::min(a.value, b.value); } From 96399d8b4a987f5624275403980200ee55214237 Mon Sep 17 00:00:00 2001 From: AN Long Date: Mon, 13 Oct 2025 15:21:57 +0900 Subject: [PATCH 2/3] Add tests --- tests/ops_tests.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c473b59c3a..3673e35096 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1059,6 +1059,30 @@ TEST_CASE("test reduction ops") { x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item()); CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item()); + + // Test maximum and minimum with NaN values + x = array({1.0f, NAN, 3.0f}); + auto y = array({NAN, 2.0f, 1.0f}); + auto max_result = maximum(x, y); + auto min_result = minimum(x, y); + CHECK(array_equal(max_result, array({NAN, NAN, 3.0f}), true).item()); + CHECK(array_equal(min_result, array({NAN, NAN, 1.0f}), true).item()); + + // Test with all NaN values + x = array({NAN, NAN, NAN}); + y = array({NAN, NAN, NAN}); + max_result = maximum(x, y); + min_result = minimum(x, y); + CHECK(array_equal(max_result, array({NAN, NAN, NAN}), true).item()); + CHECK(array_equal(min_result, array({NAN, NAN, NAN}), true).item()); + + // Test broadcasting with NaN + x = array({1.0f, NAN}); + y = array({2.0f}); + max_result = maximum(x, y); + min_result = minimum(x, y); + CHECK(array_equal(max_result, array({2.0f, NAN}), true).item()); + CHECK(array_equal(min_result, array({1.0f, NAN}), true).item()); } // Test logsumexp From 41ffce37643ca018d3292f86de098bee6740bde3 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 3 Nov 2025 07:12:45 -0800 Subject: [PATCH 3/3] fix --- mlx/backend/cpu/simd/accelerate_simd.h | 14 ++++---- tests/ops_tests.cpp | 45 ++++++++++++-------------- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/mlx/backend/cpu/simd/accelerate_simd.h b/mlx/backend/cpu/simd/accelerate_simd.h index fad1b717b5..f62c67d38b 100644 --- a/mlx/backend/cpu/simd/accelerate_simd.h +++ b/mlx/backend/cpu/simd/accelerate_simd.h @@ -217,22 +217,20 @@ Simd atan2(Simd a, Simd b) { template Simd maximum(Simd a, Simd b) { + auto out = Simd(asd::max(a.value, b.value)); if constexpr (!std::is_integral_v) { - auto a_is_nan = isnan(a); - return select(a_is_nan, a, Simd(asd::max(a.value, b.value))); + out = select(isnan(b), b, select(isnan(a), a, out)); } - - return asd::max(a.value, b.value); + return out; } template Simd minimum(Simd a, Simd b) { + auto out = Simd(asd::min(a.value, b.value)); if constexpr (!std::is_integral_v) { - auto a_is_nan = isnan(a); - return select(a_is_nan, a, Simd(asd::min(a.value, b.value))); + out = select(isnan(b), b, select(isnan(a), a, out)); } - - return asd::min(a.value, b.value); + return out; } template diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 3673e35096..1b95066228 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -1059,30 +1059,6 @@ TEST_CASE("test reduction ops") { x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item()); CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item()); - - // Test maximum and minimum with NaN values - x = array({1.0f, NAN, 3.0f}); - auto y = array({NAN, 2.0f, 1.0f}); - auto max_result = maximum(x, y); - auto min_result = minimum(x, y); - CHECK(array_equal(max_result, array({NAN, NAN, 3.0f}), true).item()); - CHECK(array_equal(min_result, array({NAN, NAN, 1.0f}), true).item()); - - // Test with all NaN values - x = array({NAN, NAN, NAN}); - y = array({NAN, NAN, NAN}); - max_result = maximum(x, y); - min_result = minimum(x, y); - CHECK(array_equal(max_result, array({NAN, NAN, NAN}), true).item()); - CHECK(array_equal(min_result, array({NAN, NAN, NAN}), true).item()); - - // Test broadcasting with NaN - x = array({1.0f, NAN}); - y = array({2.0f}); - max_result = maximum(x, y); - min_result = minimum(x, y); - CHECK(array_equal(max_result, array({2.0f, NAN}), true).item()); - CHECK(array_equal(min_result, array({1.0f, NAN}), true).item()); } // Test logsumexp @@ -4076,3 +4052,24 @@ TEST_CASE("test fp8 conversion") { auto expected = array({-448.0f, 448.0f}); CHECK(array_equal(out, expected, true).item()); } + +TEST_CASE("test max min with nan") { + // Test maximum and minimum with NaN values + auto x = array({0.0f, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto y = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto expected_max = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto expected_min = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); + auto max_result = maximum(x, y); + auto min_result = minimum(x, y); + CHECK(array_equal(max_result, expected_max, true).item()); + CHECK(array_equal(min_result, expected_min, true).item()); + + // Test with all NaN values + x = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + y = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + max_result = maximum(x, y); + min_result = minimum(x, y); + auto expected = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN}); + CHECK(array_equal(max_result, expected, true).item()); + CHECK(array_equal(min_result, expected, true).item()); +}