From 780e94b0ef4e47c24c52255f4e6fe37a238118d9 Mon Sep 17 00:00:00 2001 From: Sampurna Tuladhar Date: Wed, 3 Dec 2025 00:12:32 -0800 Subject: [PATCH 1/6] feat: working reflect pad --- mlx/ops.cpp | 143 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ef792cd6f4..28f479f008 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1400,6 +1400,147 @@ array edge_pad( return padded; } +array reflect_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const Shape& out_shape, + StreamOrDevice s /* = {}*/) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + // Copy over values from the unpadded array + array padded = slice_update(out, a, low_pad_size, stops, s); + for (int axis = 0; axis < a.ndim(); axis++) { + std::cout << "Processing axis=" << axis << " low_pad=" << low_pad_size[axis] + << " high_pad=" << high_pad_size[axis] << std::endl; + if (low_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + + starts[axis] = low_pad_size[axis] + 1; + stops[axis] = low_pad_size[axis] + n; + array forward = slice(padded, starts, stops, s); + + starts[axis] = low_pad_size[axis] + n - 2; + stops[axis] = low_pad_size[axis] - 1; + array backward = slice(padded, starts, stops, strides, s); + + array cycle = concatenate({forward, backward}, axis, s); + int cycle_len = cycle.shape(axis); // how many rows in cycle (4) + int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = low_pad_size[axis]; + + array padding = slice(tiled, slice_starts, slice_stops, s); + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; + + Shape rev_strides(a.ndim(), 1); + rev_strides[axis] = -1; + Shape rev_starts(a.ndim(), 0); + Shape rev_stops = padding.shape(); + rev_starts[axis] = padding.shape(axis) - 1; // start from last + rev_stops[axis] = -padding.shape(axis) - 1; // go before first + padding = slice(padding, rev_starts, rev_stops, rev_strides, s); + // Update values in the padded array + padded = slice_update(padded, padding, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + starts[axis] = padded.shape(axis) - high_pad_size[axis] - n + 1; + stops[axis] = padded.shape(axis) - high_pad_size[axis]; + // Edge + 1 values + array forward = slice(padded, starts, stops, s); + starts[axis] = padded.shape(axis) - high_pad_size[axis] - 2; + stops[axis] = padded.shape(axis) - high_pad_size[axis] - n - 1; + + array backward = slice(padded, starts, stops, strides, s); + array cycle = concatenate({backward, forward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = high_pad_size[axis]; + + array padding = slice(tiled, slice_starts, slice_stops, s); + + starts[axis] = padded.shape(axis) - high_pad_size[axis]; + stops[axis] = padded.shape(axis); + + // Update values in the padded array + padded = slice_update(padded, padding, starts, stops, s); + } + } + return padded; +} + +// array symmetric_pad( +// const array& a, +// const std::vector& axes, +// const Shape& low_pad_size, +// const Shape& high_pad_size, +// const Shape& out_shape, +// StreamOrDevice s /* = {}*/) { +// array out = zeros(out_shape, a.dtype(), s); +// auto stops = a.shape(); +// for (int i = 0; i < stops.size(); i++) { +// stops[i] += low_pad_size[i]; +// } +// // Copy over values from the unpadded array +// array padded = slice_update(out, a, low_pad_size, stops, s); +// +// for (int axis = 0; axis < a.ndim(); axis++) { +// if (low_pad_size[axis] > 0) { +// Shape starts(a.ndim(), 0); +// starts[axis] = low_pad_size[axis]; +// auto stops = out.shape(); +// stops[axis] = low_pad_size[axis] + 1; +// // Fetch edge values +// array edge_value = slice(padded, starts, stops, s); +// +// starts[axis] = 0; +// stops[axis] = low_pad_size[axis]; +// // Update edge values in the padded array +// padded = slice_update(padded, edge_value, starts, stops, s); +// } +// +// if (high_pad_size[axis] > 0) { +// Shape starts(a.ndim(), 0); +// starts[axis] = -high_pad_size[axis] - 1; +// auto stops = out.shape(); +// stops[axis] = -high_pad_size[axis]; +// array edge_value = slice(padded, starts, stops, s); +// +// starts[axis] = -high_pad_size[axis]; +// stops[axis] = out.shape(axis); +// padded = slice_update(padded, edge_value, starts, stops, s); +// } +// } +// return padded; +// } + /** Pad an array with a constant value */ array pad( const array& a, @@ -1447,6 +1588,8 @@ array pad( {a, astype(pad_value, a.dtype(), s)}); } else if (mode == "edge") { return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "reflect") { + return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); } else { std::ostringstream msg; msg << "Invalid padding mode (" << mode << ") passed to pad"; From a2f376241571b71fad59e8b134dee01cbfaf4ea8 Mon Sep 17 00:00:00 2001 From: Sampurna Tuladhar Date: Wed, 3 Dec 2025 01:22:31 -0800 Subject: [PATCH 2/6] feat: add symmetric pad --- mlx/ops.cpp | 173 +++++++++++++++++++++++++-------------- python/src/ops.cpp | 4 +- python/tests/test_ops.py | 20 ++++- tests/ops_tests.cpp | 79 ++++++++++++++++++ 4 files changed, 212 insertions(+), 64 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 28f479f008..5afb3ad6a9 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1412,11 +1412,9 @@ array reflect_pad( for (int i = 0; i < stops.size(); i++) { stops[i] += low_pad_size[i]; } - // Copy over values from the unpadded array array padded = slice_update(out, a, low_pad_size, stops, s); + for (int axis = 0; axis < a.ndim(); axis++) { - std::cout << "Processing axis=" << axis << " low_pad=" << low_pad_size[axis] - << " high_pad=" << high_pad_size[axis] << std::endl; if (low_pad_size[axis] > 0) { int n = a.shape(axis); Shape starts(a.ndim(), 0); @@ -1432,8 +1430,9 @@ array reflect_pad( stops[axis] = low_pad_size[axis] - 1; array backward = slice(padded, starts, stops, strides, s); + // build bounce pattern cycle array cycle = concatenate({forward, backward}, axis, s); - int cycle_len = cycle.shape(axis); // how many rows in cycle (4) + int cycle_len = cycle.shape(axis); int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1; std::vector reps(a.ndim(), 1); reps[axis] = reps_needed; @@ -1442,20 +1441,113 @@ array reflect_pad( Shape slice_starts(a.ndim(), 0); Shape slice_stops = tiled.shape(); slice_stops[axis] = low_pad_size[axis]; - array padding = slice(tiled, slice_starts, slice_stops, s); + // reverse padding for left side placement + Shape rev_strides(a.ndim(), 1); + rev_strides[axis] = -1; + Shape rev_starts(a.ndim(), 0); + Shape rev_stops = padding.shape(); + rev_starts[axis] = padding.shape(axis) - 1; + rev_stops[axis] = -padding.shape(axis) - 1; + padding = slice(padding, rev_starts, rev_stops, rev_strides, s); + starts[axis] = 0; stops[axis] = low_pad_size[axis]; + padded = slice_update(padded, padding, starts, stops, s); + } + + if (high_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + int orig_end = low_pad_size[axis] + n; + + starts[axis] = orig_end - n + 1; + stops[axis] = orig_end; + array forward = slice(padded, starts, stops, s); + + starts[axis] = orig_end - 2; + stops[axis] = -(padded.shape(axis) + 1); + array backward = slice(padded, starts, stops, strides, s); + // build bounce pattern cycle + array cycle = concatenate({backward, forward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = high_pad_size[axis]; + array padding = slice(tiled, slice_starts, slice_stops, s); + + starts[axis] = low_pad_size[axis] + n; + stops[axis] = padded.shape(axis); + padded = slice_update(padded, padding, starts, stops, s); + } + } + return padded; +} + +array symmetric_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const Shape& out_shape, + StreamOrDevice s /* = {}*/) { + array out = zeros(out_shape, a.dtype(), s); + auto stops = a.shape(); + for (int i = 0; i < stops.size(); i++) { + stops[i] += low_pad_size[i]; + } + array padded = slice_update(out, a, low_pad_size, stops, s); + + for (int axis = 0; axis < a.ndim(); axis++) { + if (low_pad_size[axis] > 0) { + int n = a.shape(axis); + Shape starts(a.ndim(), 0); + Shape stops = padded.shape(); + Shape strides(a.ndim(), 1); + strides[axis] = -1; + + starts[axis] = low_pad_size[axis]; + stops[axis] = low_pad_size[axis] + n; + array forward = slice(padded, starts, stops, s); + + starts[axis] = low_pad_size[axis] + n - 1; + stops[axis] = (padded.shape(axis) + 1); + array backward = slice(padded, starts, stops, strides, s); + + // build bounce pattern cycle + array cycle = concatenate({forward, backward}, axis, s); + int cycle_len = cycle.shape(axis); + int reps_needed = (low_pad_size[axis] + cycle_len - 1) / cycle_len + 1; + std::vector reps(a.ndim(), 1); + reps[axis] = reps_needed; + array tiled = tile(cycle, reps, s); + + Shape slice_starts(a.ndim(), 0); + Shape slice_stops = tiled.shape(); + slice_stops[axis] = low_pad_size[axis]; + array padding = slice(tiled, slice_starts, slice_stops, s); + + // reverse padding for left side placement Shape rev_strides(a.ndim(), 1); rev_strides[axis] = -1; Shape rev_starts(a.ndim(), 0); Shape rev_stops = padding.shape(); - rev_starts[axis] = padding.shape(axis) - 1; // start from last - rev_stops[axis] = -padding.shape(axis) - 1; // go before first + rev_starts[axis] = padding.shape(axis) - 1; + rev_stops[axis] = -padding.shape(axis) - 1; padding = slice(padding, rev_starts, rev_stops, rev_strides, s); - // Update values in the padded array + + starts[axis] = 0; + stops[axis] = low_pad_size[axis]; padded = slice_update(padded, padding, starts, stops, s); } @@ -1465,14 +1557,17 @@ array reflect_pad( Shape stops = padded.shape(); Shape strides(a.ndim(), 1); strides[axis] = -1; - starts[axis] = padded.shape(axis) - high_pad_size[axis] - n + 1; - stops[axis] = padded.shape(axis) - high_pad_size[axis]; - // Edge + 1 values + int orig_end = low_pad_size[axis] + n; + + starts[axis] = orig_end - n; + stops[axis] = orig_end; array forward = slice(padded, starts, stops, s); - starts[axis] = padded.shape(axis) - high_pad_size[axis] - 2; - stops[axis] = padded.shape(axis) - high_pad_size[axis] - n - 1; + starts[axis] = orig_end - 1; + stops[axis] = -(padded.shape(axis) + 1); array backward = slice(padded, starts, stops, strides, s); + + // build bounce pattern cycle array cycle = concatenate({backward, forward}, axis, s); int cycle_len = cycle.shape(axis); int reps_needed = (high_pad_size[axis] + cycle_len - 1) / cycle_len + 1; @@ -1483,64 +1578,16 @@ array reflect_pad( Shape slice_starts(a.ndim(), 0); Shape slice_stops = tiled.shape(); slice_stops[axis] = high_pad_size[axis]; - array padding = slice(tiled, slice_starts, slice_stops, s); - starts[axis] = padded.shape(axis) - high_pad_size[axis]; + starts[axis] = low_pad_size[axis] + n; stops[axis] = padded.shape(axis); - - // Update values in the padded array padded = slice_update(padded, padding, starts, stops, s); } } return padded; } -// array symmetric_pad( -// const array& a, -// const std::vector& axes, -// const Shape& low_pad_size, -// const Shape& high_pad_size, -// const Shape& out_shape, -// StreamOrDevice s /* = {}*/) { -// array out = zeros(out_shape, a.dtype(), s); -// auto stops = a.shape(); -// for (int i = 0; i < stops.size(); i++) { -// stops[i] += low_pad_size[i]; -// } -// // Copy over values from the unpadded array -// array padded = slice_update(out, a, low_pad_size, stops, s); -// -// for (int axis = 0; axis < a.ndim(); axis++) { -// if (low_pad_size[axis] > 0) { -// Shape starts(a.ndim(), 0); -// starts[axis] = low_pad_size[axis]; -// auto stops = out.shape(); -// stops[axis] = low_pad_size[axis] + 1; -// // Fetch edge values -// array edge_value = slice(padded, starts, stops, s); -// -// starts[axis] = 0; -// stops[axis] = low_pad_size[axis]; -// // Update edge values in the padded array -// padded = slice_update(padded, edge_value, starts, stops, s); -// } -// -// if (high_pad_size[axis] > 0) { -// Shape starts(a.ndim(), 0); -// starts[axis] = -high_pad_size[axis] - 1; -// auto stops = out.shape(); -// stops[axis] = -high_pad_size[axis]; -// array edge_value = slice(padded, starts, stops, s); -// -// starts[axis] = -high_pad_size[axis]; -// stops[axis] = out.shape(axis); -// padded = slice_update(padded, edge_value, starts, stops, s); -// } -// } -// return padded; -// } - /** Pad an array with a constant value */ array pad( const array& a, @@ -1590,6 +1637,8 @@ array pad( return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); } else if (mode == "reflect") { return reflect_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "symmetric") { + return symmetric_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); } else { std::ostringstream msg; msg << "Invalid padding mode (" << mode << ") passed to pad"; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a4ce55f8b3..6aa8062584 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3257,7 +3257,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'reflect', 'symmetric'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Pad an array with a constant value @@ -3272,6 +3272,8 @@ void init_ops(nb::module_& m) { mode: Padding mode. One of the following strings: "constant" (default): Pads with a constant value. "edge": Pads with the edge values of array. + "reflect": Pads with the reflection of the array mirrored along the edge, excluding the edge value. + "symmetric": Pads with the reflection of the array mirrored along the edge, including the edge value. constant_value (array or scalar, optional): Optional constant value to pad the edges of the array with. diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 4aff1eaddf..2ff17ce17b 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -5,10 +5,11 @@ import unittest from itertools import permutations, product -import mlx.core as mx import mlx_tests import numpy as np +import mlx.core as mx + def np_wrap_between(x, a): """Wraps `x` between `[-a, a]`.""" @@ -1960,6 +1961,23 @@ def test_pad(self): self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + b_npy = np.pad(a_npy, pw, mode="reflect") + b_mlx = mx.pad(a_mlx, pw, mode="reflect") + + self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + + b_npy = np.pad(a_npy, pw, mode="symmetric") + b_mlx = mx.pad(a_mlx, pw, mode="symmetric") + + if not np.allclose(b_npy, b_mlx, atol=1e-6): + print(f"\nSymmetric test failed for pad_width: {pw}") + print(f"NumPy result (first 20):", b_npy.flat[:20]) + print(f"MLX result (first 20):", np.array(b_mlx).flat[:20]) + + self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + a = mx.zeros((1, 1, 1)) self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3)) self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3)) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 1d05cb0c15..02ea47f374 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2840,6 +2840,85 @@ TEST_CASE("test pad") { CHECK(array_equal(padded_x, expected).item()); } +TEST_CASE("test pad reflect") { + auto x1d = array({1.0f, 2.0f, 3.0f}); + auto padded_1d = pad(x1d, {{2, 2}}, array(0), "reflect"); + auto expected_1d = array({3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f}); + CHECK(array_equal(padded_1d, expected_1d).item()); + + auto x2d = + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3}); + auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "reflect"); + auto expected_2d = array( + {6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 2.0f, 3.0f, + 2.0f, 1.0f, 6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f, 9.0f, 8.0f, 7.0f, + 8.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 5.0f, 6.0f, 5.0f, 4.0f}, + {5, 7}); + CHECK(array_equal(padded_2d, expected_2d).item()); + + auto x_small = array({1.0f, 2.0f, 3.0f}); + auto padded_large = pad(x_small, {{5, 5}}, array(0), "reflect"); + auto expected_large = array( + {2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f}); + CHECK(array_equal(padded_large, expected_large).item()); + + auto x_min = array({1.0f, 2.0f}); + auto padded_min = pad(x_min, {{1, 1}}, array(0), "reflect"); + auto expected_min = array({2.0f, 1.0f, 2.0f, 1.0f}); + CHECK(array_equal(padded_min, expected_min).item()); + + auto x_asym = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + auto padded_asym = pad(x_asym, {{0, 2}, {1, 0}}, array(0), "reflect"); + CHECK_EQ(padded_asym.shape(), Shape{4, 3}); +} + +TEST_CASE("test pad symmetric") { + auto x1d = array({1.0f, 2.0f, 3.0f}); + auto padded_1d = pad(x1d, {{2, 2}}, array(0), "symmetric"); + auto expected_1d = array({2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f}); + CHECK(array_equal(padded_1d, expected_1d).item()); + + auto x2d = + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3}); + auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "symmetric"); + auto expected_2d = array( + {5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, + 3.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f, 5.0f, 4.0f, 4.0f, + 5.0f, 6.0f, 6.0f, 5.0f, 8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f}, + {5, 7}); + CHECK(array_equal(padded_2d, expected_2d).item()); + + auto x_small = array({1.0f, 2.0f, 3.0f}); + auto padded_large = pad(x_small, {{5, 5}}, array(0), "symmetric"); + auto expected_large = array( + {3.0f, + 2.0f, + 1.0f, + 1.0f, + 2.0f, + 3.0f, + 3.0f, + 2.0f, + 1.0f, + 1.0f, + 2.0f, + 3.0f, + 3.0f}); + CHECK(array_equal(padded_large, expected_large).item()); +} + TEST_CASE("test power") { CHECK_EQ(power(array(1), array(2)).item(), 1); CHECK_EQ((power(array(-1), array(2))).item(), 1); From d4a461ac09a4f7dcddefbc69bf9675bfbaa6135f Mon Sep 17 00:00:00 2001 From: Sampurna Tuladhar Date: Wed, 3 Dec 2025 01:48:52 -0800 Subject: [PATCH 3/6] syntax: remove debug --- python/tests/test_ops.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 2ff17ce17b..55813b364f 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -5,11 +5,10 @@ import unittest from itertools import permutations, product +import mlx.core as mx import mlx_tests import numpy as np -import mlx.core as mx - def np_wrap_between(x, a): """Wraps `x` between `[-a, a]`.""" @@ -1970,11 +1969,6 @@ def test_pad(self): b_npy = np.pad(a_npy, pw, mode="symmetric") b_mlx = mx.pad(a_mlx, pw, mode="symmetric") - if not np.allclose(b_npy, b_mlx, atol=1e-6): - print(f"\nSymmetric test failed for pad_width: {pw}") - print(f"NumPy result (first 20):", b_npy.flat[:20]) - print(f"MLX result (first 20):", np.array(b_mlx).flat[:20]) - self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) From 0c4fad38069841cfdf357e15d5df1e54339542b3 Mon Sep 17 00:00:00 2001 From: Sampurna Tuladhar Date: Wed, 3 Dec 2025 01:51:36 -0800 Subject: [PATCH 4/6] docs: add ack --- ACKNOWLEDGMENTS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 186908f09c..c5c3b429ef 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -20,6 +20,7 @@ MLX was developed with contributions from the following individuals: - Paul Paczuski: Improved stability of BCE loss calculation - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function. +- Siddhartha Tuladhar: Added `reflect` and `symmetric` padding modes. From 0f0bc89a9919262631a50287227ff48140d95c97 Mon Sep 17 00:00:00 2001 From: sidtuladhar Date: Sun, 29 Mar 2026 17:40:23 +0100 Subject: [PATCH 5/6] fix: tests --- tests/ops_tests.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 02ea47f374..c2976dfdad 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2903,7 +2903,9 @@ TEST_CASE("test pad symmetric") { auto x_small = array({1.0f, 2.0f, 3.0f}); auto padded_large = pad(x_small, {{5, 5}}, array(0), "symmetric"); auto expected_large = array( - {3.0f, + {2.0f, + 3.0f, + 3.0f, 2.0f, 1.0f, 1.0f, @@ -2913,9 +2915,7 @@ TEST_CASE("test pad symmetric") { 2.0f, 1.0f, 1.0f, - 2.0f, - 3.0f, - 3.0f}); + 2.0f}); CHECK(array_equal(padded_large, expected_large).item()); } From 9f6e248aaffe2f46ab4548c5721e21d99c222592 Mon Sep 17 00:00:00 2001 From: sidtuladhar Date: Mon, 30 Mar 2026 07:29:22 +0100 Subject: [PATCH 6/6] fix: cpp test --- mlx/ops.cpp | 2 +- tests/ops_tests.cpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 5afb3ad6a9..c2361c2016 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1521,7 +1521,7 @@ array symmetric_pad( array forward = slice(padded, starts, stops, s); starts[axis] = low_pad_size[axis] + n - 1; - stops[axis] = (padded.shape(axis) + 1); + stops[axis] = low_pad_size[axis] - 1; array backward = slice(padded, starts, stops, strides, s); // build bounce pattern cycle diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c2976dfdad..0a5b97fc2f 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2894,9 +2894,9 @@ TEST_CASE("test pad symmetric") { array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}, {3, 3}); auto padded_2d = pad(x2d, {{1, 1}, {2, 2}}, array(0), "symmetric"); auto expected_2d = array( - {5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, - 3.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f, 5.0f, 4.0f, 4.0f, - 5.0f, 6.0f, 6.0f, 5.0f, 8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f}, + {2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 3.0f, 2.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, + 3.0f, 2.0f, 5.0f, 4.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 8.0f, 7.0f, 7.0f, + 8.0f, 9.0f, 9.0f, 8.0f, 8.0f, 7.0f, 7.0f, 8.0f, 9.0f, 9.0f, 8.0f}, {5, 7}); CHECK(array_equal(padded_2d, expected_2d).item());