From 80fd15f93184522c2241b947be5623c4cc80d47f Mon Sep 17 00:00:00 2001 From: Sheldon Aristide Date: Tue, 24 Mar 2026 01:25:26 -0400 Subject: [PATCH 1/2] Implement Pad vmap and add coverage --- mlx/primitives.cpp | 38 +++++++++++++++++++++++++++++++++++++- python/tests/test_vmap.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index fa9e55700e..e985a29e46 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3231,7 +3231,43 @@ std::vector Pad::jvp( std::pair, std::vector> Pad::vmap( const std::vector& inputs, const std::vector& axes) { - throw std::runtime_error("Pad vmap is NYI."); + assert(inputs.size() == 2); + assert(axes.size() == 2); + + if (axes[1] >= 0) { + throw std::invalid_argument( + "[Pad::vmap] Vmap over padding value is not supported."); + } + + auto ax = axes[0]; + auto& in = inputs[0]; + auto pad_axes = axes_; + if (ax >= 0) { + auto unbatched_ndim = static_cast(in.ndim()) - 1; + pad_axes.clear(); + pad_axes.reserve(axes_.size()); + for (auto pad_ax : axes_) { + auto normalized_pad_ax = pad_ax < 0 ? pad_ax + unbatched_ndim : pad_ax; + if (normalized_pad_ax < 0 || normalized_pad_ax >= unbatched_ndim) { + throw std::invalid_argument("[Pad::vmap] Invalid padding axis."); + } + pad_axes.push_back( + normalized_pad_ax >= ax ? normalized_pad_ax + 1 : normalized_pad_ax); + } + } + + auto pad_val = inputs[1]; + return { + { + pad(in, + pad_axes, + low_pad_size_, + high_pad_size_, + pad_val, + "constant", + stream()), + }, + {ax}}; } bool Pad::is_equivalent(const Primitive& other) const { diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 7847a9a60e..30c85de00f 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -723,6 +723,44 @@ def gconv(x, w): out = mx.vmap(gconv, in_axes=(0, 0))(x, w) self.assertTrue(mx.allclose(expected, out)) + def test_vmap_pad(self): + def pad2d(x, value=0.0): + return mx.pad(x, ((1, 2), (0, 1)), constant_values=value) + + x = mx.arange(24, dtype=mx.float32).reshape(2, 3, 4) + + expected = mx.stack([pad2d(xi) for xi in x]) + out = mx.vmap(pad2d, in_axes=0)(x) + self.assertTrue(mx.array_equal(out, expected)) + + expected = mx.stack([pad2d(x[:, i, :]) for i in range(x.shape[1])]) + out = mx.vmap(pad2d, in_axes=1)(x) + self.assertTrue(mx.array_equal(out, expected)) + + expected = mx.stack([pad2d(x[:, :, i]) for i in range(x.shape[2])], axis=2) + out = mx.vmap(pad2d, in_axes=-1, out_axes=-1)(x) + self.assertTrue(mx.array_equal(out, expected)) + + nested = mx.vmap(mx.vmap(lambda y: mx.pad(y, (1, 1)))) + out = nested(x) + expected = mx.pad(x, ((0, 0), (0, 0), (1, 1))) + self.assertTrue(mx.array_equal(out, expected)) + + out = mx.vmap( + lambda a, v: mx.pad(a, ((1, 1), (1, 1)), constant_values=v), + in_axes=(0, None), + )(x, mx.array(5.0)) + expected = mx.stack( + [mx.pad(xi, ((1, 1), (1, 1)), constant_values=mx.array(5.0)) for xi in x] + ) + self.assertTrue(mx.array_equal(out, expected)) + + pad_values = mx.array([3.0, 4.0]) + with self.assertRaises(ValueError): + mx.vmap(lambda a, v: mx.pad(a, ((1, 1), (1, 1)), constant_values=v))( + x, pad_values + ) + def test_vmap_types(self): from typing import NamedTuple From b35855491240b3cbe1812b69403e9d7fc6393b1e Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 24 Mar 2026 13:52:09 -0700 Subject: [PATCH 2/2] Simplify Pad::vmap --- mlx/primitives.cpp | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index e985a29e46..2193226838 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3240,33 +3240,22 @@ std::pair, std::vector> Pad::vmap( } auto ax = axes[0]; - auto& in = inputs[0]; auto pad_axes = axes_; if (ax >= 0) { - auto unbatched_ndim = static_cast(in.ndim()) - 1; - pad_axes.clear(); - pad_axes.reserve(axes_.size()); - for (auto pad_ax : axes_) { - auto normalized_pad_ax = pad_ax < 0 ? pad_ax + unbatched_ndim : pad_ax; - if (normalized_pad_ax < 0 || normalized_pad_ax >= unbatched_ndim) { - throw std::invalid_argument("[Pad::vmap] Invalid padding axis."); - } - pad_axes.push_back( - normalized_pad_ax >= ax ? normalized_pad_ax + 1 : normalized_pad_ax); + for (auto& pad_ax : pad_axes) { + pad_ax = (pad_ax >= ax) ? pad_ax + 1 : pad_ax; } } - auto pad_val = inputs[1]; return { - { - pad(in, - pad_axes, - low_pad_size_, - high_pad_size_, - pad_val, - "constant", - stream()), - }, + {pad( + inputs[0], + pad_axes, + low_pad_size_, + high_pad_size_, + inputs[1], + "constant", + stream())}, {ax}}; }