From 2b7aec76b05c39728631e53b212e66f4db529641 Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Tue, 24 Mar 2026 16:37:36 -0400 Subject: [PATCH 1/3] Implement BroadcastAxes::vmap --- mlx/primitives.cpp | 76 ++++++++++++++++++++++++++++++++++++++- python/tests/test_vmap.py | 67 ++++++++++++++++++++++++++++++++++ tests/vmap_tests.cpp | 28 +++++++++++++++ 3 files changed, 170 insertions(+), 1 deletion(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index fa9e55700e..30fa595a43 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -897,7 +897,81 @@ std::vector BroadcastAxes::jvp( std::pair, std::vector> BroadcastAxes::vmap( const std::vector& inputs, const std::vector& axes) { - throw std::invalid_argument("[BroadcastAxes] VMAP NYI"); + assert(inputs.size() == axes.size()); + assert(!inputs.empty()); + + if (std::all_of(axes.begin(), axes.end(), [](int ax) { return ax == -1; })) { + return { + {array( + output_shape(inputs, ignore_axes_), + inputs[0].dtype(), + std::make_shared(stream(), ignore_axes_), + inputs)}, + {-1}}; + } + + int ndim = 0; + for (int i = 0; i < inputs.size(); ++i) { + ndim = std::max(ndim, static_cast(inputs[i].ndim()) + (axes[i] == -1)); + } + + auto expand_dims = [this, ndim](const array& in) { + auto shape = in.shape(); + shape.insert(shape.begin(), ndim - shape.size(), 1); + return reshape(in, std::move(shape), stream()); + }; + + auto aligned_inputs = inputs; + int to_ax = (ndim - static_cast(inputs[0].ndim())) + axes[0]; + if (to_ax < 0 || to_ax >= ndim) { + throw std::invalid_argument( + "[BroadcastAxes::vmap] Received invalid vmapped axis."); + } + for (int i = 0; i < aligned_inputs.size(); ++i) { + int from_ax = (ndim - static_cast(inputs[i].ndim())) + axes[i]; + if (from_ax < 0 || from_ax >= ndim) { + throw std::invalid_argument( + "[BroadcastAxes::vmap] Received invalid vmapped axis."); + } + aligned_inputs[i] = expand_dims(inputs[i]); + + if (from_ax != to_ax) { + std::vector tdims(aligned_inputs[i].ndim()); + std::iota(tdims.begin(), tdims.end(), 0); + tdims.erase(tdims.begin() + from_ax); + tdims.insert(tdims.begin() + to_ax, from_ax); + aligned_inputs[i] = transpose(aligned_inputs[i], tdims, stream()); + } + } + + int prefix = ndim - static_cast(inputs[0].ndim()); + int unbatched_ndim = static_cast(inputs[0].ndim()) - (axes[0] >= 0); + std::vector ignore_axes; + ignore_axes.reserve(ignore_axes_.size()); + for (auto ax : ignore_axes_) { + auto pos_ax = unbatched_ndim + ax; + if (pos_ax < 0 || pos_ax >= unbatched_ndim) { + throw std::invalid_argument( + "[BroadcastAxes::vmap] Invalid axis in ignore_axes."); + } + if (axes[0] >= 0 && pos_ax >= axes[0]) { + pos_ax++; + } + pos_ax += prefix; + if (pos_ax < 0 || pos_ax >= ndim) { + throw std::invalid_argument( + "[BroadcastAxes::vmap] Invalid axis in ignore_axes."); + } + ignore_axes.push_back(pos_ax - ndim); + } + + return { + {array( + output_shape(aligned_inputs, ignore_axes), + aligned_inputs[0].dtype(), + std::make_shared(stream(), ignore_axes), + std::move(aligned_inputs))}, + {to_ax}}; } bool BroadcastAxes::is_equivalent(const Primitive& other) const { diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 7847a9a60e..3f002bba16 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -242,6 +242,30 @@ def test_vmap_reduce(self): out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(2,))(a) self.assertTrue(mx.array_equal(out, mx.full((2,), 15))) + def test_vmap_broadcast_to(self): + x = mx.arange(2 * 3 * 1 * 5).reshape(2, 3, 1, 5) + + out = mx.vmap(lambda a: mx.broadcast_to(a, (3, 4, 5)), in_axes=0)(x) + expected = mx.stack( + [mx.broadcast_to(x[i], (3, 4, 5)) for i in range(x.shape[0])] + ) + self.assertTrue(mx.array_equal(out, expected)) + + out = mx.vmap(lambda a: mx.broadcast_to(a, (2, 4, 5)), in_axes=1)(x) + expected = mx.stack( + [mx.broadcast_to(x[:, i, :, :], (2, 4, 5)) for i in range(x.shape[1])] + ) + self.assertTrue(mx.array_equal(out, expected)) + + out = mx.vmap(lambda a: mx.broadcast_to(a, (2, 3, 4)), in_axes=-1, out_axes=-1)( + x + ) + expected = mx.stack( + [mx.broadcast_to(x[:, :, :, i], (2, 3, 4)) for i in range(x.shape[-1])], + axis=-1, + ) + self.assertTrue(mx.array_equal(out, expected)) + def test_vmap_argreduce(self): a = mx.array([[1, 2, 3], [2, 3, 1]]) out = mx.vmap(lambda x: mx.argmin(x))(a) @@ -595,6 +619,25 @@ def fun(a, idx): out = mx.vmap(fun, in_axes=(None, 0))(a, idx) self.assertEqual(out.shape, (4, 2, 1)) + a = mx.arange(3 * 2 * 5 * 4).reshape(3, 2, 5, 4) + idx = mx.zeros((3, 2, 1, 4), mx.int32) + out = mx.vmap(lambda x, y: mx.take_along_axis(x, y, axis=1), in_axes=(0, 0))( + a, idx + ) + expected = mx.stack( + [mx.take_along_axis(a[i], idx[i], axis=1) for i in range(a.shape[0])] + ) + self.assertTrue(mx.array_equal(out, expected)) + + idx = mx.zeros((3, 2, 5, 1), mx.int32) + out = mx.vmap(lambda x, y: mx.take_along_axis(x, y, axis=-1), in_axes=(0, 0))( + a, idx + ) + expected = mx.stack( + [mx.take_along_axis(a[i], idx[i], axis=-1) for i in range(a.shape[0])] + ) + self.assertTrue(mx.array_equal(out, expected)) + def test_vmap_put_along_axis(self): a = mx.zeros((4, 5, 1)) idx = mx.ones((2, 4, 1), mx.int32) @@ -621,6 +664,30 @@ def fun(a, idx, upd): out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd) self.assertEqual(out.shape, (4, 5, 1)) + a = mx.zeros((3, 2, 5, 4)) + idx = mx.zeros((3, 2, 1, 4), mx.int32) + upd = mx.ones((3, 2, 1, 4)) + out = mx.vmap( + lambda x, y, z: mx.put_along_axis(x, y, z, axis=1), in_axes=(0, 0, 0) + )(a, idx, upd) + expected = mx.stack( + [mx.put_along_axis(a[i], idx[i], upd[i], axis=1) for i in range(a.shape[0])] + ) + self.assertTrue(mx.array_equal(out, expected)) + + idx = mx.zeros((3, 2, 5, 1), mx.int32) + upd = mx.ones((3, 2, 5, 1)) + out = mx.vmap( + lambda x, y, z: mx.put_along_axis(x, y, z, axis=-1), in_axes=(0, 0, 0) + )(a, idx, upd) + expected = mx.stack( + [ + mx.put_along_axis(a[i], idx[i], upd[i], axis=-1) + for i in range(a.shape[0]) + ] + ) + self.assertTrue(mx.array_equal(out, expected)) + def test_vmap_split_vmap(self): def fun(x): a, b = mx.split(x, 2, 1) diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp index 2a2a285713..0e3f0c3f1e 100644 --- a/tests/vmap_tests.cpp +++ b/tests/vmap_tests.cpp @@ -3,6 +3,7 @@ #include "doctest/doctest.h" #include "mlx/mlx.h" +#include "mlx/primitives.h" using namespace mlx::core; @@ -237,6 +238,33 @@ TEST_CASE("test vmap with eval") { CHECK_THROWS(vmap(fun2)({x, y})); } +TEST_CASE("test vmap broadcast axes primitive") { + auto s = default_stream(default_device()); + { + auto p = BroadcastAxes(s, {-1}); + auto x = reshape(arange(2 * 3 * 1 * 5, float32, s), {2, 3, 1, 5}, s); + auto y = zeros({1, 2, 4, 5}, float32, s); + + auto [out, out_axes] = p.vmap({x, y}, {0, 1}); + auto expected = broadcast_to(x, {2, 3, 4, 5}, s); + CHECK_EQ(out_axes.size(), 1); + CHECK_EQ(out_axes[0], 0); + CHECK(array_equal(out[0], expected).item()); + } + + { + auto p = BroadcastAxes(s, {-1}); + auto x = reshape(arange(3 * 1 * 5, float32, s), {3, 1, 5}, s); + auto y = zeros({2, 1, 4, 5}, float32, s); + + auto [out, out_axes] = p.vmap({x, y}, {-1, 0}); + auto expected = broadcast_to(x, {2, 3, 4, 5}, s); + CHECK_EQ(out_axes.size(), 1); + CHECK_EQ(out_axes[0], 0); + CHECK(array_equal(out[0], expected).item()); + } +} + TEST_CASE("test vmap comparison ops") { // vmap equal { From ede24201dfc627a339651fcbd5ba483705a94672 Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Wed, 25 Mar 2026 09:00:31 -0400 Subject: [PATCH 2/3] Implement BroadcastAxes::vmap --- mlx/primitives.cpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 30fa595a43..7f52be9c2b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -923,16 +923,8 @@ std::pair, std::vector> BroadcastAxes::vmap( auto aligned_inputs = inputs; int to_ax = (ndim - static_cast(inputs[0].ndim())) + axes[0]; - if (to_ax < 0 || to_ax >= ndim) { - throw std::invalid_argument( - "[BroadcastAxes::vmap] Received invalid vmapped axis."); - } for (int i = 0; i < aligned_inputs.size(); ++i) { int from_ax = (ndim - static_cast(inputs[i].ndim())) + axes[i]; - if (from_ax < 0 || from_ax >= ndim) { - throw std::invalid_argument( - "[BroadcastAxes::vmap] Received invalid vmapped axis."); - } aligned_inputs[i] = expand_dims(inputs[i]); if (from_ax != to_ax) { @@ -950,18 +942,10 @@ std::pair, std::vector> BroadcastAxes::vmap( ignore_axes.reserve(ignore_axes_.size()); for (auto ax : ignore_axes_) { auto pos_ax = unbatched_ndim + ax; - if (pos_ax < 0 || pos_ax >= unbatched_ndim) { - throw std::invalid_argument( - "[BroadcastAxes::vmap] Invalid axis in ignore_axes."); - } if (axes[0] >= 0 && pos_ax >= axes[0]) { pos_ax++; } pos_ax += prefix; - if (pos_ax < 0 || pos_ax >= ndim) { - throw std::invalid_argument( - "[BroadcastAxes::vmap] Invalid axis in ignore_axes."); - } ignore_axes.push_back(pos_ax - ndim); } From 9edc6fc9025e4093062333fd1b77df4dc45cbab3 Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Thu, 26 Mar 2026 20:49:47 -0400 Subject: [PATCH 3/3] Add comment for BroadcastAxes ignore_axes remap --- mlx/primitives.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 4d12cf059e..62a8519c1e 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -947,6 +947,7 @@ std::pair, std::vector> BroadcastAxes::vmap( int unbatched_ndim = static_cast(inputs[0].ndim()) - (axes[0] >= 0); std::vector ignore_axes; ignore_axes.reserve(ignore_axes_.size()); + // Reexpress ignore_axes_ in the normalized batched layout. for (auto ax : ignore_axes_) { auto pos_ax = unbatched_ndim + ax; if (axes[0] >= 0 && pos_ax >= axes[0]) {