diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index fa9e55700e..2e4767c468 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1789,7 +1789,9 @@ std::pair, std::vector> Divide::vmap( const std::vector& inputs, const std::vector& axes) { auto [a, b, to_ax] = vmap_binary_op(inputs, axes, stream()); - return {{divide(a, b, stream())}, {to_ax}}; + auto out = issubdtype(a.dtype(), integer) ? floor_divide(a, b, stream()) + : divide(a, b, stream()); + return {{out}, {to_ax}}; } std::vector Remainder::vjp( diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 2a2a285713..361982848c 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -545,3 +545,48 @@ TEST_CASE("test vmap dynamic slices") { CHECK(array_equal(out, array({0, 0, 1, 1}, {2, 2})).item()); } } + +TEST_CASE("test vmap floor_divide integer") { + // floor_divide with integer inputs should preserve integer dtype under vmap. + // Bug: Divide::vmap called divide() which promotes integers to float. + { + auto x = arange(0, 25, int32); + auto divisor = array(5, int32); + + // Without vmap: floor_divide returns int32 + auto expected = floor_divide(x, divisor); + CHECK_EQ(expected.dtype(), int32); + + // With vmap: should also return int32 + auto vfun = vmap([&divisor](array s) { return floor_divide(s, divisor); }); + auto result = vfun(x); + CHECK_EQ(result.dtype(), int32); + CHECK(array_equal(result, expected).item()); + } + + // Also check remainder preserves integer dtype under vmap + { + auto x = arange(0, 10, int32); + auto divisor = array(3, int32); + + auto expected = remainder(x, divisor); + auto vfun = vmap([&divisor](array s) { return remainder(s, divisor); }); + auto result = vfun(x); + CHECK_EQ(result.dtype(), int32); + CHECK(array_equal(result, expected).item()); + } + + // floor_divide + remainder: should reconstruct original + { + auto x = arange(0, 25, int32); + auto w = array(5, int32); + + auto vfun = vmap([&w](array s) { + auto q = floor_divide(s, w); + auto r = remainder(s, w); + return add(multiply(q, w), r); + }); + auto result = vfun(x); + CHECK(array_equal(result, x).item()); + } +}