From 63b73e7a78598ef7e1e1c30e8500590f495db6f8 Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Tue, 24 Mar 2026 17:41:28 -0400 Subject: [PATCH] Fix moved-from shape bug in broadcast_arrays causing vmap bus error --- mlx/ops.cpp | 2 +- python/tests/test_vmap.py | 6 ++++++ tests/vmap_tests.cpp | 12 ++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 935bbd6efe..ef792cd6f4 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1659,7 +1659,7 @@ std::vector broadcast_arrays( outputs.push_back(in); } else { outputs.push_back(array( - std::move(out_shape), + out_shape, in.dtype(), std::make_shared(to_stream(s), out_shape), {in})); diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 7847a9a60e..30e1bd8fb2 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -595,6 +595,12 @@ def fun(a, idx): out = mx.vmap(fun, in_axes=(None, 0))(a, idx) self.assertEqual(out.shape, (4, 2, 1)) + a = mx.zeros((4, 5, 3)) + idx = mx.zeros((2, 2, 1, 3), mx.int32) + + out = mx.vmap(fun, in_axes=(None, 0))(a, idx) + self.assertEqual(out.shape, (2, 2, 5, 3)) + def test_vmap_put_along_axis(self): a = mx.zeros((4, 5, 1)) idx = mx.ones((2, 4, 1), mx.int32) diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 2a2a285713..a38b3a3c68 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -392,6 +392,18 @@ TEST_CASE("test vmap gather") { } } +TEST_CASE("test vmap take_along_axis with unmapped input and mapped index") { + auto fun = [](std::vector inputs) { + return std::vector{take_along_axis(inputs[0], inputs[1], 0)}; + }; + + auto a = reshape(arange(60), {4, 5, 3}); + auto idx = zeros({2, 2, 1, 3}, int32); + + auto out = vmap(fun, {-1, 0})({a, idx})[0]; + CHECK_EQ(out.shape(), Shape{2, 2, 5, 3}); +} + TEST_CASE("test vmap scatter") { auto make_scatter_fn = [](const std::vector& indices, const array& updates,