Add output_shapes for AddMM#3262

Merged
zcbenz merged 6 commits into ml-explore:main from pHequals7:fix/addmm-slice-customkernel-output-shapes
Mar 25, 2026

Conversation

Contributor

@pHequals7 pHequals7 commented Mar 16, 2026

Summary

Adds output_shapes() to AddMM, enabling compile(shapeless=True) for models with biased Linear layers.

Change

AddMM::output_shapes returns inputs[0].shape() (the C matrix shape), which is already validated to match the output shape at construction in ops.cpp.

Context

Most transformer models use biased Linear layers, which dispatch through AddMM. Without this, compile(shapeless=True) throws "primitive does not have shape inference implemented". This follows the same pattern as #2601 (Convolution::output_shapes) and #1727 (shapeless SliceUpdate + Broadcast).

Discovered while porting mlx-whisper to Swift — whisper-small has 145 biased Linear layers that all fail in shapeless compile without this.

Files Changed

| File | Lines | Change |
| --- | --- | --- |
| mlx/primitives.h | +1 | Declaration |
| mlx/primitives.cpp | +6 | Implementation |

Previous scope

Originally included Slice and CustomKernel output_shapes — dropped per review feedback. CustomKernel shape inference via a callback function could be a separate discussion/PR.

Enable `compile(shapeless: true)` for models that use:

1. **AddMM** (biased Linear layers): Most transformer models use Linear
   with bias=true, which dispatches AddMM. Without output_shapes,
   compile(shapeless:true) fails. The fix matches Matmul::output_shapes.

2. **Slice** (array subscripting): Any compiled function that slices
   arrays (e.g., `array[0..<N]`) needs Slice::output_shapes. The
   implementation re-normalizes slice bounds against runtime input shape.
   Limited to constant-dimension slices; variable-dimension slices
   should use take()/DynamicSlice.

3. **CustomKernel** (metalKernel API): Custom Metal kernels created via
   the metalKernel() API can now work inside compile(shapeless:true).
   Output shapes are stored at construction time and returned during
   compile-time shape inference. A -1 sentinel in output shapes triggers
   dynamic computation from input sizes (total_input_size / num_outputs),
   enabling kernels with variable output sizes (e.g., KV cache append).

Discovered while porting mlx-whisper to Swift using mlx-swift. All three
primitives are essential for compiled inference with custom fused kernels.
C (inputs[0]) is already validated to match the output shape at
construction in ops.cpp, so we can return its shape directly
instead of recalculating from B's last dimension.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
```cpp
if (resolved_shapes[i][j] == -1) resolved_shapes[i][j] = per_out;
      }
    }
  }
```
Member

I think this is not doing what you were thinking it was doing.

resolved_shapes.size() is the number of outputs. total is the sum of number of elements of all inputs. What is per_out supposed to represent?

Say for instance I write a custom kernel that adds two arrays. per_out would be 2 times the input size 🤷‍♂️

The only way to do this properly is to pass a function that computes the output shapes from the input shapes. If this function is passed then shapeless compilation of the custom kernel will be automatically enabled otherwise not.
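A minimal sketch of what such a callback-gated API could look like, in plain Python. Everything here (`make_kernel`, `output_shapes_fn`) is hypothetical illustration, not the actual metalKernel API: a kernel carries an optional shape-inference function, and shapeless compilation is only enabled when it is supplied.

```python
def make_kernel(source, output_shapes_fn=None):
    """Hypothetical wrapper: the kernel optionally carries a callback
    that maps a list of input shapes to a list of output shapes."""
    def infer_output_shapes(input_shapes):
        # Without the callback, shape inference (and hence shapeless
        # compilation) is simply unavailable for this kernel.
        if output_shapes_fn is None:
            raise ValueError("no shape-inference function; shapeless compile disabled")
        return output_shapes_fn(input_shapes)
    return infer_output_shapes

# An elementwise-add kernel: one output, shaped like either input.
# This is exactly the case the -1 sentinel heuristic got wrong, since
# summing input sizes would double-count the two operands.
add_kernel = make_kernel("...", output_shapes_fn=lambda shapes: [shapes[0]])
print(add_kernel([[4, 8], [4, 8]]))  # [[4, 8]]
```

The point of the design is that the user, who knows the kernel's semantics, supplies the rule; the compiler never has to guess output sizes from input element counts.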

Contributor Author

yeah you're right, the -1 sentinel was a hack that only worked for my specific kv cache concat case. a shape inference function passed to metalKernel() makes way more sense as a general api.

stripped this pr down to just AddMM which is the straightforward one. happy to open a separate issue for the CustomKernel shape inference function if that's useful — or leave it for someone with better context on the compile internals.

Comment on lines +4786 to +4787
// Works for constant-dimension slices; variable-dimension slices
// should use take()/DynamicSlice instead.
Member

Suggested change
// Works for constant-dimension slices; variable-dimension slices
// should use take()/DynamicSlice instead.

Contributor Author

done, dropped it. makes sense that constant slices don't need this.

Drop Slice and CustomKernel changes per review feedback:
- Slice::output_shapes unnecessary for constant-dimension slices
- CustomKernel -1 sentinel is not a general solution; proper
  approach is a shape inference function (separate discussion)

Keeping only AddMM::output_shapes which is straightforward —
C's shape is already validated to match the output at construction.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@pHequals7 pHequals7 changed the title Add output_shapes for AddMM, Slice, and CustomKernel Add output_shapes for AddMM Mar 19, 2026
Member

@angeloskath angeloskath left a comment

Thanks! Could you add a test as well? One in python/tests/test_compile.py should be enough.

the previous implementation returned inputs[0].shape() assuming
inputs were ordered {c, a, b}, but AddMM receives them as {a, b, c}
(see ops.cpp:5440). this meant the output shape was derived from a's
shape without replacing the last dim with b's last dim.

it happened to work in practice because c is validated to match the
output shape at construction, so inputs[0].shape() (a's shape) and
inputs[2].shape() (c's shape) only differ in the last dim, which
was never tested with varying sizes under shapeless compile.

the fix mirrors Matmul::output_shapes: take a's shape, replace the
last dim with b.shape(-1).

adds test_shapeless_compile_addmm in test_compile.py covering:
- basic addmm inside compile(shapeless=True)
- second call with different shapes (same ranks)
- custom alpha/beta parameters

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@pHequals7
Contributor Author

added a test in python/tests/test_compile.py (test_shapeless_compile_addmm).

writing the test actually caught a bug in the previous implementation: AddMM receives inputs as {a, b, c} (ops.cpp:5440), not {c, a, b}, so inputs[0].shape() was returning a's shape, not c's. it worked by coincidence because c is validated to match the output shape at construction, so a's shape and c's shape only differ in the last dim, which was never exercised with varying sizes.

fixed to mirror Matmul::output_shapes: take a's shape, replace the last dim with b.shape(-1).
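In Python terms, the corrected rule reads as follows (a sketch of the logic, not the actual C++ in mlx/primitives.cpp):

```python
def addmm_output_shape(a_shape, b_shape, c_shape):
    """Output shape for addmm with inputs ordered {a, b, c}:
    a's shape with the last dim replaced by b's last dim,
    mirroring Matmul::output_shapes. c_shape is unused because c
    is validated to broadcast against the output at construction."""
    out = list(a_shape)
    out[-1] = b_shape[-1]
    return out

# The buggy version returned a_shape unchanged (believing inputs[0]
# was c), which only coincides with the true output shape when
# a and b happen to share a last dimension.
print(addmm_output_shape([3, 5], [5, 7], [3, 7]))  # [3, 7]
```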

the test covers:

  • basic addmm inside compile(shapeless=True)
  • a second call with different shapes (same ranks) to verify shapeless reuse
  • custom alpha/beta parameters
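For reference, the semantics the alpha/beta case exercises — addmm(c, a, b) = alpha * (a @ b) + beta * c — can be written as a plain-Python check. This is a dependency-free sketch of the operation's math, not the MLX implementation or the actual test code:

```python
def matmul(a, b):
    # Naive 2-D matrix multiply over nested lists.
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner))
             for j in range(cols)] for i in range(rows)]

def addmm(c, a, b, alpha=1.0, beta=1.0):
    # alpha * (a @ b) + beta * c, elementwise over the product's shape.
    ab = matmul(a, b)
    return [[alpha * ab[i][j] + beta * c[i][j]
             for j in range(len(ab[0]))] for i in range(len(ab))]

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[1.0, 1.0], [1.0, 1.0]]
print(addmm(c, a, b, alpha=2.0, beta=0.5))
# [[38.5, 44.5], [86.5, 100.5]]
```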

Member

@angeloskath angeloskath left a comment

Looks great! Thanks for adding the tests.

I am gonna leave this here for posterity but there is a general usage bug currently with MLX shapeless compile due to the ops actually not producing the same graph every time. For instance if broadcast is a noop then it is not added to the graph which ends up causing errors down the road (silent ones as well).

@zcbenz zcbenz merged commit bd200d6 into ml-explore:main Mar 25, 2026
16 checks passed
