From 387f52d9ae6bba9d176e604c45c8bbaa4cf253dd Mon Sep 17 00:00:00 2001 From: pHequals7 Date: Mon, 16 Mar 2026 10:19:02 +0530 Subject: [PATCH 1/6] Add output_shapes for AddMM, Slice, and CustomKernel primitives Enable `compile(shapeless: true)` for models that use: 1. **AddMM** (biased Linear layers): Most transformer models use Linear with bias=true, which dispatches AddMM. Without output_shapes, compile(shapeless:true) fails. The fix matches Matmul::output_shapes. 2. **Slice** (array subscripting): Any compiled function that slices arrays (e.g., `array[0..(total / resolved_shapes.size()); + for (size_t i = 0; i < resolved_shapes.size(); i++) { + for (size_t j = 0; j < resolved_shapes[i].size(); j++) { + if (resolved_shapes[i][j] == -1) resolved_shapes[i][j] = per_out; + } + } + } + return array::make_arrays( - std::move(output_shapes), + std::move(resolved_shapes), std::move(output_dtypes), std::make_shared( s, @@ -319,7 +341,8 @@ CustomKernelFunction metal_kernel( init_value, std::vector{}, false, - 0), + 0, + std::move(saved_shapes)), std::move(inputs)); }; } diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..84ca412a73 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -375,7 +375,8 @@ class CustomKernel : public Primitive { std::optional init_value, std::vector scalar_arguments, bool is_precompiled, - int shared_memory) + int shared_memory, + std::vector output_shapes = {}) : Primitive(stream), name_(std::move(name)), source_(std::move(source)), @@ -386,7 +387,8 @@ class CustomKernel : public Primitive { init_value_(init_value), scalar_arguments_(std::move(scalar_arguments)), is_precompiled_(is_precompiled), - shared_memory_(shared_memory) {} + shared_memory_(shared_memory), + output_shapes_(std::move(output_shapes)) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -397,6 +399,25 @@ class CustomKernel : public Primitive { override; DEFINE_NAME(CustomKernel); + std::vector 
output_shapes(const std::vector<array>& inputs) override {
+    // Check for dynamic sentinel (-1) in stored shapes
+    bool has_dynamic = false;
+    for (const auto& s : output_shapes_) {
+      for (auto d : s) { if (d == -1) { has_dynamic = true; break; } }
+      if (has_dynamic) break;
+    }
+    if (!has_dynamic) return output_shapes_;
+
+    // Dynamic: compute from runtime inputs (total_input_size / num_outputs)
+    size_t total = 0;
+    for (const auto& in : inputs) total += in.size();
+    auto per_out = static_cast<int>(total / output_shapes_.size());
+    auto result = output_shapes_;
+    for (auto& s : result) {
+      for (auto& d : s) { if (d == -1) d = per_out; }
+    }
+    return result;
+  }
   auto state() const {
     return std::make_tuple(
         name_,
@@ -422,6 +443,7 @@ class CustomKernel : public Primitive {
   std::vector<ScalarArg> scalar_arguments_;
   bool is_precompiled_;
   int shared_memory_;
+  std::vector<Shape> output_shapes_;
 };
 
 } // namespace mlx::core::fast
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 92e54f9991..dbec418103 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -345,6 +345,14 @@ bool AddMM::is_equivalent(const Primitive& other) const {
   return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_);
 }
 
+std::vector<Shape> AddMM::output_shapes(const std::vector<array>& inputs) {
+  // out = alpha * (A @ B) + beta * C
+  // Output shape matches C (inputs[0]), with last dim from B (inputs[2])
+  auto out_shape = inputs[0].shape();
+  out_shape.back() = inputs[2].shape(-1);
+  return {std::move(out_shape)};
+}
+
 std::pair<std::vector<array>, std::vector<array>> AddMM::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
@@ -4775,6 +4783,31 @@ bool Slice::is_equivalent(const Primitive& other) const {
       end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_);
 }
 
+std::vector<Shape> Slice::output_shapes(const std::vector<array>& inputs) {
+  // Re-normalize slice bounds against the runtime input shape.
+  // Works for constant-dimension slices; variable-dimension slices
+  // should use take()/DynamicSlice instead.
+  auto& in_shape = inputs[0].shape();
+  Shape out_shape(in_shape.size());
+  for (int i = 0; i < in_shape.size(); ++i) {
+    auto n = in_shape[i];
+    auto s = start_indices_[i] < 0 ? start_indices_[i] + n : start_indices_[i];
+    auto e = end_indices_[i] < 0 ? end_indices_[i] + n : end_indices_[i];
+    s = std::max(static_cast<ShapeElem>(0), std::min(s, n));
+    e = std::max(static_cast<ShapeElem>(0), std::min(e, n));
+    if (strides_[i] > 0) {
+      e = e < s ? s : e;
+      out_shape[i] = (e - s + strides_[i] - 1) / strides_[i];
+    } else {
+      auto st = std::min(s, n - 1);
+      auto ed = e > -1 ? e : -1;
+      ed = ed > st ? st : ed;
+      out_shape[i] = (st - ed + (-strides_[i]) - 1) / (-strides_[i]);
+    }
+  }
+  return {std::move(out_shape)};
+}
+
 std::pair<std::vector<array>, std::vector<array>> SliceUpdate::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 4091aafcfb..1d417c918e 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -200,6 +200,7 @@ class AddMM : public UnaryPrimitive {
   DEFINE_NAME(AddMM)
   bool is_equivalent(const Primitive& other) const override;
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
   std::pair<float, float> state() const {
     return {alpha_, beta_};
   };
 
@@ -2043,6 +2044,7 @@ class Slice : public UnaryPrimitive {
   DEFINE_GRADS()
   DEFINE_NAME(Slice)
   bool is_equivalent(const Primitive& other) const override;
+  std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
   auto state() const {
     return std::make_tuple(start_indices_, end_indices_, strides_);
   }

From 95382f97d5ba9d225b4b63af124789f4c0c75334 Mon Sep 17 00:00:00 2001
From: pHequals7
Date: Wed, 18 Mar 2026 00:23:09 +0530
Subject: [PATCH 2/6] Simplify AddMM::output_shapes per review feedback

C (inputs[0]) is already validated to match the output shape at
construction in ops.cpp, so we can return its shape directly instead of
recalculating from B's last dimension.
Co-Authored-By: Claude Opus 4.6 (1M context) --- mlx/primitives.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index dbec418103..2848d8a46b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -346,11 +346,9 @@ bool AddMM::is_equivalent(const Primitive& other) const { } std::vector AddMM::output_shapes(const std::vector& inputs) { - // out = alpha * (A @ B) + beta * C - // Output shape matches C (inputs[0]), with last dim from B (inputs[2]) - auto out_shape = inputs[0].shape(); - out_shape.back() = inputs[2].shape(-1); - return {std::move(out_shape)}; + // C (inputs[0]) is validated to match the output shape at construction + // in ops.cpp, so its shape is the output shape directly. + return {inputs[0].shape()}; } std::pair, std::vector> AddMM::vmap( From 858ac132400808759c640a388618447678b17a88 Mon Sep 17 00:00:00 2001 From: pHequals7 Date: Fri, 20 Mar 2026 02:48:43 +0530 Subject: [PATCH 3/6] Scope down to AddMM::output_shapes only MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop Slice and CustomKernel changes per review feedback: - Slice::output_shapes unnecessary for constant-dimension slices - CustomKernel -1 sentinel is not a general solution; proper approach is a shape inference function (separate discussion) Keeping only AddMM::output_shapes which is straightforward — C's shape is already validated to match the output at construction. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- mlx/backend/metal/custom_kernel.cpp | 27 ++------------------------- mlx/fast_primitives.h | 26 ++------------------------ mlx/primitives.cpp | 25 ------------------------- mlx/primitives.h | 1 - 4 files changed, 4 insertions(+), 75 deletions(-) diff --git a/mlx/backend/metal/custom_kernel.cpp b/mlx/backend/metal/custom_kernel.cpp index 5437822a1c..eec6645bdd 100644 --- a/mlx/backend/metal/custom_kernel.cpp +++ b/mlx/backend/metal/custom_kernel.cpp @@ -305,30 +305,8 @@ CustomKernelFunction metal_kernel( << "```" << std::endl; } - // Store original shapes (may contain -1 for dynamic dimensions) - auto saved_shapes = output_shapes; - - // Resolve -1 sentinel values for array creation (make_arrays needs real sizes). - // Convention: -1 → total_input_elements / num_outputs. - auto resolved_shapes = output_shapes; - bool has_dynamic = false; - for (const auto& s : resolved_shapes) { - for (auto d : s) { if (d == -1) { has_dynamic = true; break; } } - if (has_dynamic) break; - } - if (has_dynamic) { - size_t total = 0; - for (const auto& in : inputs) total += in.size(); - auto per_out = static_cast(total / resolved_shapes.size()); - for (size_t i = 0; i < resolved_shapes.size(); i++) { - for (size_t j = 0; j < resolved_shapes[i].size(); j++) { - if (resolved_shapes[i][j] == -1) resolved_shapes[i][j] = per_out; - } - } - } - return array::make_arrays( - std::move(resolved_shapes), + std::move(output_shapes), std::move(output_dtypes), std::make_shared( s, @@ -341,8 +319,7 @@ CustomKernelFunction metal_kernel( init_value, std::vector{}, false, - 0, - std::move(saved_shapes)), + 0), std::move(inputs)); }; } diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 84ca412a73..4434830875 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -375,8 +375,7 @@ class CustomKernel : public Primitive { std::optional init_value, std::vector scalar_arguments, bool is_precompiled, - int shared_memory, - std::vector 
output_shapes = {}) + int shared_memory) : Primitive(stream), name_(std::move(name)), source_(std::move(source)), @@ -387,8 +386,7 @@ class CustomKernel : public Primitive { init_value_(init_value), scalar_arguments_(std::move(scalar_arguments)), is_precompiled_(is_precompiled), - shared_memory_(shared_memory), - output_shapes_(std::move(output_shapes)) {} + shared_memory_(shared_memory) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override { @@ -399,25 +397,6 @@ class CustomKernel : public Primitive { override; DEFINE_NAME(CustomKernel); - std::vector output_shapes(const std::vector& inputs) override { - // Check for dynamic sentinel (-1) in stored shapes - bool has_dynamic = false; - for (const auto& s : output_shapes_) { - for (auto d : s) { if (d == -1) { has_dynamic = true; break; } } - if (has_dynamic) break; - } - if (!has_dynamic) return output_shapes_; - - // Dynamic: compute from runtime inputs (total_input_size / num_outputs) - size_t total = 0; - for (const auto& in : inputs) total += in.size(); - auto per_out = static_cast(total / output_shapes_.size()); - auto result = output_shapes_; - for (auto& s : result) { - for (auto& d : s) { if (d == -1) d = per_out; } - } - return result; - } auto state() const { return std::make_tuple( name_, @@ -443,7 +422,6 @@ class CustomKernel : public Primitive { std::vector scalar_arguments_; bool is_precompiled_; int shared_memory_; - std::vector output_shapes_; }; } // namespace mlx::core::fast diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 2848d8a46b..01ef6a2ffd 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -4781,31 +4781,6 @@ bool Slice::is_equivalent(const Primitive& other) const { end_indices_ == s_other.end_indices_ && strides_ == s_other.strides_); } -std::vector Slice::output_shapes(const std::vector& inputs) { - // Re-normalize slice bounds against the runtime input shape. 
- // Works for constant-dimension slices; variable-dimension slices - // should use take()/DynamicSlice instead. - auto& in_shape = inputs[0].shape(); - Shape out_shape(in_shape.size()); - for (int i = 0; i < in_shape.size(); ++i) { - auto n = in_shape[i]; - auto s = start_indices_[i] < 0 ? start_indices_[i] + n : start_indices_[i]; - auto e = end_indices_[i] < 0 ? end_indices_[i] + n : end_indices_[i]; - s = std::max(static_cast(0), std::min(s, n)); - e = std::max(static_cast(0), std::min(e, n)); - if (strides_[i] > 0) { - e = e < s ? s : e; - out_shape[i] = (e - s + strides_[i] - 1) / strides_[i]; - } else { - auto st = std::min(s, n - 1); - auto ed = e > -1 ? e : -1; - ed = ed > st ? st : ed; - out_shape[i] = (st - ed + (-strides_[i]) - 1) / (-strides_[i]); - } - } - return {std::move(out_shape)}; -} - std::pair, std::vector> SliceUpdate::vmap( const std::vector& inputs, const std::vector& axes) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 1d417c918e..2ce92b90f3 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2044,7 +2044,6 @@ class Slice : public UnaryPrimitive { DEFINE_GRADS() DEFINE_NAME(Slice) bool is_equivalent(const Primitive& other) const override; - std::vector output_shapes(const std::vector& inputs) override; auto state() const { return std::make_tuple(start_indices_, end_indices_, strides_); } From d68eafa234d76584312a209790e8ecedf280b925 Mon Sep 17 00:00:00 2001 From: pHequals7 Date: Tue, 24 Mar 2026 15:53:52 +0530 Subject: [PATCH 4/6] fix AddMM::output_shapes input ordering, add test the previous implementation returned inputs[0].shape() assuming inputs were ordered {c, a, b}, but AddMM receives them as {a, b, c} (see ops.cpp:5440). this meant the output shape was derived from a's shape without replacing the last dim with b's last dim. 
it happened to work in practice because c is validated to match the output shape at construction, so inputs[0].shape() (a's shape) and inputs[2].shape() (c's shape) only differ in the last dim, which was never tested with varying sizes under shapeless compile. the fix mirrors Matmul::output_shapes: take a's shape, replace the last dim with b.shape(-1). adds test_shapeless_compile_addmm in test_compile.py covering: - basic addmm inside compile(shapeless=True) - second call with different shapes (same ranks) - custom alpha/beta parameters Co-Authored-By: Claude Opus 4.6 (1M context) --- mlx/primitives.cpp | 7 ++++--- python/tests/test_compile.py | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 01ef6a2ffd..cd0da1875b 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -346,9 +346,10 @@ bool AddMM::is_equivalent(const Primitive& other) const { } std::vector AddMM::output_shapes(const std::vector& inputs) { - // C (inputs[0]) is validated to match the output shape at construction - // in ops.cpp, so its shape is the output shape directly. 
- return {inputs[0].shape()}; + // inputs are {a, b, c}, output shape is a's shape with last dim from b + auto out_shape = inputs[0].shape(); + out_shape.back() = inputs[1].shape(-1); + return {std::move(out_shape)}; } std::pair, std::vector> AddMM::vmap( diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 7db471cc03..eab7f965f2 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -884,6 +884,32 @@ def test_shapeless_compile_matmul(self): fun = mx.compile(lambda a, b: a @ b, shapeless=True) self.assertTrue(mx.allclose(fun(a, b), a @ b)) + def test_shapeless_compile_addmm(self): + def fun(c, a, b): + return mx.addmm(c, a, b) + + cfun = mx.compile(fun, shapeless=True) + + # First shape + c = mx.ones((2, 4)) + a = mx.ones((2, 3)) + b = mx.ones((3, 4)) + self.assertTrue(mx.allclose(cfun(c, a, b), fun(c, a, b))) + + # Different shape, same ranks — should not recompile + c = mx.ones((3, 5)) + a = mx.ones((3, 6)) + b = mx.ones((6, 5)) + self.assertTrue(mx.allclose(cfun(c, a, b), fun(c, a, b))) + + # With alpha and beta + fun2 = mx.compile(lambda c, a, b: mx.addmm(c, a, b, alpha=2.0, beta=3.0), shapeless=True) + c = mx.ones((2, 4)) + a = mx.ones((2, 3)) + b = mx.ones((3, 4)) + expected = 3.0 * c + 2.0 * (a @ b) + self.assertTrue(mx.allclose(fun2(c, a, b), expected)) + def test_shapeless_compile_slice_update(self): def fun(x): x[2] = mx.array([3.0]) From 575a52b1217e080dc00146bb12edfd5c6e73ecfc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 24 Mar 2026 15:22:31 -0700 Subject: [PATCH 5/6] Fix linter error --- python/tests/test_compile.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index eab7f965f2..2e694eba93 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -903,7 +903,10 @@ def fun(c, a, b): self.assertTrue(mx.allclose(cfun(c, a, b), fun(c, a, b))) # With alpha and beta - fun2 = 
mx.compile(lambda c, a, b: mx.addmm(c, a, b, alpha=2.0, beta=3.0), shapeless=True) + fun2 = mx.compile( + lambda c, a, b: mx.addmm(c, a, b, alpha=2.0, beta=3.0), + shapeless=True + ) c = mx.ones((2, 4)) a = mx.ones((2, 3)) b = mx.ones((3, 4)) From d69d59a2856c78a011f12d0e319e7f66e1bd910a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 24 Mar 2026 15:24:20 -0700 Subject: [PATCH 6/6] Fix linter error --- python/tests/test_compile.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 2e694eba93..20f1145223 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -904,8 +904,7 @@ def fun(c, a, b): # With alpha and beta fun2 = mx.compile( - lambda c, a, b: mx.addmm(c, a, b, alpha=2.0, beta=3.0), - shapeless=True + lambda c, a, b: mx.addmm(c, a, b, alpha=2.0, beta=3.0), shapeless=True ) c = mx.ones((2, 4)) a = mx.ones((2, 3))