From 2f95afeff30ba7bb050ed990752f27232ea55346 Mon Sep 17 00:00:00 2001 From: ghstrider Date: Wed, 4 Mar 2026 18:44:27 +0530 Subject: [PATCH] Fix shapeless compile eliding reductions on size-1 dimensions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit During shapeless (dynamic) tracing, compute_reduce_shape() reported is_noop=true when a reduced dimension happened to be size 1 at trace time. This caused sum/mean/max/min/prod/all/any/argmin/argmax to skip creating the Reduce primitive entirely—returning just astype() or the input directly. On replay with a larger dimension, the missing Reduce meant the graph went straight from GatherAxis (correctly producing [n] elements) to Squeeze (dropping the dimension to scalar), and only the first element was visible in the output. The fix disables the is_noop optimisation when in_dynamic_tracing() is true, following the same pattern already used for no-op broadcast elision (lines 1489, 1543 of ops.cpp). Fixes #3201 Co-Authored-By: Claude Opus 4.6 --- mlx/ops.cpp | 7 ++++++ python/tests/test_compile.py | 49 ++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c7af8834fe..2d4c97111b 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -51,6 +51,13 @@ std::tuple<Shape, std::vector<int>, bool> compute_reduce_shape( is_noop &= (out_shape.back() == shape[i]); } std::vector<int> sorted_axes(axes_set.begin(), axes_set.end()); + // During dynamic (shapeless) tracing, dimensions that happen to be size 1 + // at trace time may have different sizes on replay. Never elide the + // reduction in that case, otherwise the Reduce primitive is missing from + // the traced graph and replays produce wrong results.
+ if (detail::in_dynamic_tracing()) { + is_noop = false; + } return {out_shape, sorted_axes, is_noop}; } diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 7db471cc03..2bdc8b24a9 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -989,6 +989,55 @@ def fun(args): self.assertEqual(out[0].shape, (3, 1, 4, 2)) self.assertEqual(out[1].shape, (2, 2, 5)) + def test_shapeless_compile_reduce_after_gather(self): + # Regression test: when the first call to a shapeless-compiled function + # has a size-1 reduced dimension, the reduction was elided as a no-op + # during tracing. On replay with larger sizes, the missing reduction + # caused stale (first-call) values to be returned. + buf = mx.array([10.0, 20.0, 30.0, 40.0, 50.0]) + + # take + sum + def fn_sum(buf, idx): + return mx.take(buf, idx, axis=0).sum() + + cfn = mx.compile(fn_sum, shapeless=True) + for n in [1, 2, 3, 4]: + idx = mx.arange(n) + expected = fn_sum(buf, idx) + result = cfn(buf, idx) + self.assertTrue( + mx.allclose(result, expected), + f"sum failed for n={n}: got {result.item()}, expected {expected.item()}", + ) + + # take + mean + def fn_mean(buf, idx): + return mx.take(buf, idx, axis=0).mean() + + cfn = mx.compile(fn_mean, shapeless=True) + for n in [1, 2, 3, 4]: + idx = mx.arange(n) + expected = fn_mean(buf, idx) + result = cfn(buf, idx) + self.assertTrue( + mx.allclose(result, expected), + f"mean failed for n={n}: got {result.item()}, expected {expected.item()}", + ) + + # take + max + def fn_max(buf, idx): + return mx.take(buf, idx, axis=0).max() + + cfn = mx.compile(fn_max, shapeless=True) + for n in [1, 2, 3, 4]: + idx = mx.arange(n) + expected = fn_max(buf, idx) + result = cfn(buf, idx) + self.assertTrue( + mx.allclose(result, expected), + f"max failed for n={n}: got {result.item()}, expected {expected.item()}", + ) + def test_leaks(self): gc.collect() if mx.metal.is_available():