diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index fa9e55700e..2193226838 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3231,7 +3231,32 @@ std::vector Pad::jvp( std::pair, std::vector> Pad::vmap( const std::vector& inputs, const std::vector& axes) { - throw std::runtime_error("Pad vmap is NYI."); + assert(inputs.size() == 2); + assert(axes.size() == 2); + + if (axes[1] >= 0) { + throw std::invalid_argument( + "[Pad::vmap] Vmap over padding value is not supported."); + } + + auto ax = axes[0]; + auto pad_axes = axes_; + if (ax >= 0) { + for (auto& pad_ax : pad_axes) { + pad_ax = (pad_ax >= ax) ? pad_ax + 1 : pad_ax; + } + } + + return { + {pad( + inputs[0], + pad_axes, + low_pad_size_, + high_pad_size_, + inputs[1], + "constant", + stream())}, + {ax}}; } bool Pad::is_equivalent(const Primitive& other) const { diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py index 7847a9a60e..30c85de00f 100644 --- a/python/tests/test_vmap.py +++ b/python/tests/test_vmap.py @@ -723,6 +723,44 @@ def gconv(x, w): out = mx.vmap(gconv, in_axes=(0, 0))(x, w) self.assertTrue(mx.allclose(expected, out)) + def test_vmap_pad(self): + def pad2d(x, value=0.0): + return mx.pad(x, ((1, 2), (0, 1)), constant_values=value) + + x = mx.arange(24, dtype=mx.float32).reshape(2, 3, 4) + + expected = mx.stack([pad2d(xi) for xi in x]) + out = mx.vmap(pad2d, in_axes=0)(x) + self.assertTrue(mx.array_equal(out, expected)) + + expected = mx.stack([pad2d(x[:, i, :]) for i in range(x.shape[1])]) + out = mx.vmap(pad2d, in_axes=1)(x) + self.assertTrue(mx.array_equal(out, expected)) + + expected = mx.stack([pad2d(x[:, :, i]) for i in range(x.shape[2])], axis=2) + out = mx.vmap(pad2d, in_axes=-1, out_axes=-1)(x) + self.assertTrue(mx.array_equal(out, expected)) + + nested = mx.vmap(mx.vmap(lambda y: mx.pad(y, (1, 1)))) + out = nested(x) + expected = mx.pad(x, ((0, 0), (0, 0), (1, 1))) + self.assertTrue(mx.array_equal(out, expected)) + + out = mx.vmap( + lambda a, v: mx.pad(a, ((1, 1), (1, 1)), constant_values=v), + in_axes=(0, None), + )(x, mx.array(5.0)) + expected = mx.stack( + [mx.pad(xi, ((1, 1), (1, 1)), constant_values=mx.array(5.0)) for xi in x] + ) + self.assertTrue(mx.array_equal(out, expected)) + + pad_values = mx.array([3.0, 4.0]) + with self.assertRaises(ValueError): + mx.vmap(lambda a, v: mx.pad(a, ((1, 1), (1, 1)), constant_values=v))( + x, pad_values + ) + def test_vmap_types(self): from typing import NamedTuple