From 705bf69e30e57e2a4516cc7a54fc1c2047b4e33a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 31 Mar 2026 02:52:50 -0700 Subject: [PATCH] Add vmap for BroadcastAxes --- mlx/ops.cpp | 52 +++++++++++++++++---------------------- mlx/primitives.cpp | 35 +++++++++++++++++++++++++- python/tests/test_vmap.py | 45 +++++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 31 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ef792cd6f4..36e5863aa6 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1676,23 +1676,19 @@ std::vector broadcast_arrays( for (int i = 0; i < inputs.size(); ++i) { auto& in = inputs[i]; auto out_shape = check_and_get_shape(in); - if (in.shape() == out_shape) { - outputs.push_back(in); - } else { - // broadcasted array goes first followed by other stopgrad inputs - std::vector p_inputs = {in}; - for (int j = 0; j < inputs.size(); ++j) { - if (j == i) { - continue; - } - p_inputs.push_back(stop_grad_inputs[j]); + // broadcasted array goes first followed by other stopgrad inputs + std::vector p_inputs = {in}; + for (int j = 0; j < inputs.size(); ++j) { + if (j == i) { + continue; } - outputs.push_back(array( - std::move(out_shape), - in.dtype(), - std::make_shared(to_stream(s), ignore_axes), - std::move(p_inputs))); + p_inputs.push_back(stop_grad_inputs[j]); } + outputs.push_back(array( + out_shape, + in.dtype(), + std::make_shared(to_stream(s), ignore_axes), + std::move(p_inputs))); } return outputs; } @@ -1727,23 +1723,19 @@ std::vector broadcast_arrays( } for (int i = 0; i < inputs.size(); ++i) { auto& in = inputs[i]; - if (in.shape() == shape) { - outputs.push_back(in); - } else { - // broadcasted array goes first followed by other stopgrad inputs - std::vector p_inputs = {in}; - for (int j = 0; j < inputs.size(); ++j) { - if (j == i) { - continue; - } - p_inputs.push_back(stop_grad_inputs[j]); + // broadcasted array goes first followed by other stopgrad inputs + std::vector p_inputs = {in}; + for (int j = 0; j < inputs.size(); ++j) { + if (j == i) { + continue; } - outputs.push_back(array( - shape, - in.dtype(), - std::make_shared(to_stream(s), shape), - std::move(p_inputs))); + p_inputs.push_back(stop_grad_inputs[j]); } + outputs.push_back(array( + shape, + in.dtype(), + std::make_shared(to_stream(s), shape), + std::move(p_inputs))); } return outputs; } diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 220b8bcf55..a861dba40b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -904,7 +904,40 @@ std::vector BroadcastAxes::jvp( std::pair, std::vector> BroadcastAxes::vmap( const std::vector& inputs, const std::vector& axes) { - throw std::invalid_argument("[BroadcastAxes] VMAP NYI"); + std::vector new_inputs = inputs; + std::vector new_axes = axes; + size_t ndim = 0; + bool have_batch = false; + for (int i = 0; i < inputs.size(); i++) { + have_batch |= axes[i] >= 0; + ndim = std::max(inputs[i].ndim(), ndim); + } + + std::vector expand; + expand.reserve(ndim); + for (int i = 0; i < inputs.size(); i++) { + int extra = ndim - inputs[i].ndim(); + if (axes[i] >= 0 && extra > 0) { + new_axes[i] += extra; + expand.resize(extra); + std::iota(expand.begin(), expand.end(), 0); + new_inputs[i] = expand_dims(new_inputs[i], expand, stream()); + } + + if (new_axes[i] > 0) { + new_inputs[i] = moveaxis(new_inputs[i], new_axes[i], 0, stream()); + } + } + + auto shape = output_shape(new_inputs, ignore_axes_); + auto dtype = new_inputs[0].dtype(); + return { + {array( + shape, + dtype, + std::make_shared(stream(), ignore_axes_), + std::move(new_inputs))}, + {have_batch ? 0 : -1}}; } bool BroadcastAxes::is_equivalent(const Primitive& other) const { diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 5d60c0e197..99c30a2dc2 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -899,6 +899,51 @@ def scatter_fn(x, m, src): out = double_scatter(a + 0, mask, src) self.assertTrue(mx.array_equal(expected, out)) + def test_broadcast_axes_vmap(self): + # Broadcast axes requires shapeless compile to properly test + + counter = [0] + + def fn(x, y): + counter[0] += 1 + return mx.matmul(x, y) + + x = mx.random.normal((2, 3, 1, 4, 5)) + y = mx.random.normal((1, 2, 5, 6)) + z = mx.random.normal((3, 2, 1, 4, 5)) + w = mx.random.normal((2, 3, 5, 6)) + + vmap_fn = mx.vmap(fn, in_axes=(0, 1)) + cvmap_fn = mx.compile(vmap_fn, shapeless=True) + + expected = vmap_fn(x, y) + out = cvmap_fn(x, y) + self.assertTrue(mx.array_equal(expected, out)) + self.assertEqual(2, counter[0]) + + expected = vmap_fn(z, w) + out = cvmap_fn(z, w) + self.assertTrue(mx.array_equal(expected, out)) + self.assertEqual(3, counter[0]) + + x = mx.random.normal((2, 3, 1, 4, 5)) + y = mx.random.normal((1, 2, 5, 6)) + z = mx.random.normal((2, 3, 1, 7, 2)) + w = mx.random.normal((1, 2, 2, 3)) + + vmap_fn = mx.vmap(fn, in_axes=(0, None)) + cvmap_fn = mx.compile(vmap_fn, shapeless=True) + + expected = vmap_fn(x, y) + out = cvmap_fn(x, y) + self.assertTrue(mx.array_equal(expected, out)) + self.assertEqual(5, counter[0]) + + expected = vmap_fn(z, w) + out = cvmap_fn(z, w) + self.assertTrue(mx.array_equal(expected, out)) + self.assertEqual(6, counter[0]) + if __name__ == "__main__": mlx_tests.MLXTestRunner()