From d278c2a5489da2d2344f5abe4ee221b0680870c4 Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Sat, 21 Mar 2026 21:31:21 +0100 Subject: [PATCH 1/4] Fix vmap + floor_divide: preserve integer dtype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Divide::vmap called divide() which promotes integers to float via at_least_float. But floor_divide creates a Divide primitive with integer dtype for integer division. The vmap of this primitive should preserve integer semantics. Fix: in Divide::vmap, check if inputs are integer type and construct the Divide primitive directly (preserving dtype) instead of calling the high-level divide() which promotes to float. Before: vmap(floor_divide(12, 5)) → 2.4 (float32) After: vmap(floor_divide(12, 5)) → 2 (int32) Co-Authored-By: Claude Opus 4.6 (1M context) --- mlx/primitives.cpp | 16 +++++++++++++++ tests/vmap_tests.cpp | 49 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index fa9e55700e..35e39605c0 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1789,6 +1789,22 @@ std::pair, std::vector> Divide::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); + // The high-level divide() promotes integers to float via at_least_float. + // But floor_divide() creates a Divide primitive with integer dtype + // for integer division. Preserve that: if inputs are integer, keep + // integer semantics by constructing the primitive directly. + auto dtype = promote_types(a.dtype(), b.dtype()); + if (issubdtype(dtype, integer)) { + auto bcast = broadcast_arrays( + {astype(a, dtype, stream()), astype(b, dtype, stream())}, stream()); + return { + {array( + bcast[0].shape(), + dtype, + std::make_shared(stream()), + std::move(bcast))}, + {to_ax}}; + } return {{divide(a, b, stream())}, {to_ax}}; } diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 2a2a285713..0b2d4fadb5 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -545,3 +545,52 @@ TEST_CASE("test vmap dynamic slices") { CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item()); } } + +TEST_CASE("test vmap floor_divide integer") { + // floor_divide with integer inputs should preserve integer dtype under vmap. + // Bug: Divide::vmap called divide() which promotes integers to float. + { + auto x = arange(0, 25, int32); + auto divisor = array(5, int32); + + // Without vmap: floor_divide returns int32 + auto expected = floor_divide(x, divisor); + CHECK_EQ(expected.dtype(), int32); + + // With vmap: should also return int32 + auto vfun = vmap([&divisor](array s) { + return floor_divide(s, divisor); + }); + auto result = vfun(x); + CHECK_EQ(result.dtype(), int32); + CHECK(array_equal(result, expected).item()); + } + + // Also check remainder preserves integer dtype under vmap + { + auto x = arange(0, 10, int32); + auto divisor = array(3, int32); + + auto expected = remainder(x, divisor); + auto vfun = vmap([&divisor](array s) { + return remainder(s, divisor); + }); + auto result = vfun(x); + CHECK_EQ(result.dtype(), int32); + CHECK(array_equal(result, expected).item()); + } + + // floor_divide + remainder: should reconstruct original + { + auto x = arange(0, 25, int32); + auto w = array(5, int32); + + auto vfun = vmap([&w](array s) { + auto q = floor_divide(s, w); + auto r = remainder(s, w); + return add(multiply(q, w), r); + }); + auto result = vfun(x); + CHECK(array_equal(result, x).item()); + } +} From 10eb2cede4b7c1a2d78fd6c0862b986510f28f0b Mon Sep 17 00:00:00 2001 From: Robert Johansson Date: Sun, 22 Mar 2026 14:38:58 +0100 Subject: [PATCH 2/4] Fix clang-format: single-line short lambdas in tests Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/vmap_tests.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 0b2d4fadb5..361982848c 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -558,9 +558,7 @@ TEST_CASE("test vmap floor_divide integer") { CHECK_EQ(expected.dtype(), int32); // With vmap: should also return int32 - auto vfun = vmap([&divisor](array s) { - return floor_divide(s, divisor); - }); + auto vfun = vmap([&divisor](array s) { return floor_divide(s, divisor); }); auto result = vfun(x); CHECK_EQ(result.dtype(), int32); CHECK(array_equal(result, expected).item()); @@ -572,9 +570,7 @@ TEST_CASE("test vmap floor_divide integer") { auto divisor = array(3, int32); auto expected = remainder(x, divisor); - auto vfun = vmap([&divisor](array s) { - return remainder(s, divisor); - }); + auto vfun = vmap([&divisor](array s) { return remainder(s, divisor); }); auto result = vfun(x); CHECK_EQ(result.dtype(), int32); CHECK(array_equal(result, expected).item()); From 7f89bfaa09bbca716c437023ae90692173210329 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 24 Mar 2026 15:06:32 -0700 Subject: [PATCH 3/4] Simplify vmap logic --- mlx/primitives.cpp | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 35e39605c0..09d492f1d9 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1793,19 +1793,9 @@ std::pair, std::vector> Divide::vmap( // But floor_divide() creates a Divide primitive with integer dtype // for integer division. Preserve that: if inputs are integer, keep // integer semantics by constructing the primitive directly. - auto dtype = promote_types(a.dtype(), b.dtype()); - if (issubdtype(dtype, integer)) { - auto bcast = broadcast_arrays( - {astype(a, dtype, stream()), astype(b, dtype, stream())}, stream()); - return { - {array( - bcast[0].shape(), - dtype, - std::make_shared(stream()), - std::move(bcast))}, - {to_ax}}; - } - return {{divide(a, b, stream())}, {to_ax}}; + auto out = issubdtype(a.dtype(), integer) ? floor_divide(a, b, stream()) + : divide(a, b, stream()); + return {{out}, {to_ax}}; } std::vector Remainder::vjp( From 162b79028113e3c73cb0e42ac2496c4bcdad6eac Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 24 Mar 2026 15:07:26 -0700 Subject: [PATCH 4/4] Clean up comments in Divide::vmap function --- mlx/primitives.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 09d492f1d9..2e4767c468 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1789,10 +1789,6 @@ std::pair, std::vector> Divide::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - // The high-level divide() promotes integers to float via at_least_float. - // But floor_divide() creates a Divide primitive with integer dtype - // for integer division. Preserve that: if inputs are integer, keep - // integer semantics by constructing the primitive directly. auto out = issubdtype(a.dtype(), integer) ? floor_divide(a, b, stream()) : divide(a, b, stream()); return {{out}, {to_ax}};