Conversation
Enable `compile(shapeless: true)` for models that use:

1. **AddMM** (biased Linear layers): most transformer models use Linear with `bias=true`, which dispatches AddMM. Without `output_shapes`, `compile(shapeless: true)` fails. The fix matches `Matmul::output_shapes`.
2. **Slice** (array subscripting): any compiled function that slices arrays (e.g., `array[0..<N]`) needs `Slice::output_shapes`. The implementation re-normalizes slice bounds against the runtime input shape. Limited to constant-dimension slices; variable-dimension slices should use `take()`/DynamicSlice.
3. **CustomKernel** (`metalKernel` API): custom Metal kernels created via the `metalKernel()` API can now work inside `compile(shapeless: true)`. Output shapes are stored at construction time and returned during compile-time shape inference. A `-1` sentinel in output shapes triggers dynamic computation from input sizes (`total_input_size / num_outputs`), enabling kernels with variable output sizes (e.g., KV cache append).

Discovered while porting mlx-whisper to Swift using mlx-swift. All three primitives are essential for compiled inference with custom fused kernels.
C (inputs[0]) is already validated to match the output shape at construction in ops.cpp, so we can return its shape directly instead of recalculating from B's last dimension. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mlx/backend/metal/custom_kernel.cpp
Outdated
        if (resolved_shapes[i][j] == -1) resolved_shapes[i][j] = per_out;
      }
    }
  }
I think this is not doing what you were thinking it was doing.
resolved_shapes.size() is the number of outputs. total is the sum of number of elements of all inputs. What is per_out supposed to represent?
Say, for instance, I write a custom kernel that adds two arrays. per_out would be 2 times the input size 🤷‍♂️
The only way to do this properly is to pass a function that computes the output shapes from the input shapes. If this function is passed then shapeless compilation of the custom kernel will be automatically enabled otherwise not.
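The proposed callback could look like the following pure-Python sketch. The function name and the `(B, H, S, D)` KV-cache layout are hypothetical, for illustration only; this is not an existing MLX or mlx-swift API.

```python
# Hypothetical sketch: a shape-inference callback a metalKernel()-style API
# could accept. If the caller provides it, shapeless compilation of the
# custom kernel can be enabled; otherwise it stays disabled.
def kv_append_output_shapes(input_shapes):
    # Assumed inputs: cache (B, H, S, D) and new entries (B, H, T, D).
    cache, new = input_shapes
    out = list(cache)
    out[2] = cache[2] + new[2]  # output grows along the sequence axis
    return [tuple(out)]

print(kv_append_output_shapes([(1, 8, 16, 64), (1, 8, 1, 64)]))  # [(1, 8, 17, 64)]
```

Unlike the `-1` sentinel, a callback like this handles any relationship between input and output sizes, not just "total input elements divided by number of outputs".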
yeah you're right, the -1 sentinel was a hack that only worked for my specific kv cache concat case. a shape inference function passed to metalKernel() makes way more sense as a general api.
stripped this pr down to just AddMM which is the straightforward one. happy to open a separate issue for the CustomKernel shape inference function if that's useful — or leave it for someone with better context on the compile internals.
mlx/primitives.cpp
Outdated
// Works for constant-dimension slices; variable-dimension slices
// should use take()/DynamicSlice instead.
Suggested change: delete these two comment lines.
done, dropped it. makes sense that constant slices don't need this.
Drop Slice and CustomKernel changes per review feedback:

- Slice::output_shapes unnecessary for constant-dimension slices
- CustomKernel -1 sentinel is not a general solution; the proper approach is a shape inference function (separate discussion)

Keeping only AddMM::output_shapes, which is straightforward: C's shape is already validated to match the output at construction.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
angeloskath
left a comment
Thanks! Could you add a test as well? One in python/tests/test_compile.py should be enough.
the previous implementation returned inputs[0].shape() assuming
inputs were ordered {c, a, b}, but AddMM receives them as {a, b, c}
(see ops.cpp:5440). this meant the output shape was derived from a's
shape without replacing the last dim with b's last dim.
it happened to work in practice because c is validated to match the
output shape at construction, so inputs[0].shape() (a's shape) and
inputs[2].shape() (c's shape) only differ in the last dim, which
was never tested with varying sizes under shapeless compile.
the fix mirrors Matmul::output_shapes: take a's shape, replace the
last dim with b.shape(-1).
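As a pure-Python mirror of that rule (the function name is illustrative, not MLX API):

```python
def addmm_output_shape(a_shape, b_shape):
    # out = alpha * (a @ b) + beta * c, with inputs ordered {a, b, c}.
    # Take a's shape and replace the last dim with b's last dim,
    # mirroring Matmul::output_shapes.
    out = list(a_shape)
    out[-1] = b_shape[-1]
    return tuple(out)

print(addmm_output_shape((4, 8), (8, 16)))  # (4, 16)
```

Deriving the shape from a and b, rather than from c, keeps the inference correct even when c is broadcast against the matmul result.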
adds test_shapeless_compile_addmm in test_compile.py covering:
- basic addmm inside compile(shapeless=True)
- second call with different shapes (same ranks)
- custom alpha/beta parameters
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
added a test in python/tests/test_compile.py. writing the test actually caught a bug in the previous implementation: it returned inputs[0].shape() assuming input order {c, a, b}, but AddMM receives {a, b, c}. fixed to mirror Matmul::output_shapes. the test covers basic addmm, a second call with different shapes (same ranks), and custom alpha/beta.
angeloskath
left a comment
Looks great! Thanks for adding the tests.
I am gonna leave this here for posterity, but there is a general usage bug currently with MLX shapeless compile due to the ops not actually producing the same graph every time. For instance, if a broadcast is a no-op then it is not added to the graph, which ends up causing errors down the road (silent ones as well).
Summary

Adds `output_shapes()` to `AddMM`, enabling `compile(shapeless=True)` for models with biased Linear layers.

Change

`AddMM::output_shapes` mirrors `Matmul::output_shapes`: it takes a's shape and replaces the last dimension with b's last dimension. (An earlier revision returned `inputs[0].shape()` directly, assuming the inputs were ordered {c, a, b}, but AddMM receives them as {a, b, c}; C's shape is still validated against the output at construction in `ops.cpp`.)

Context

Most transformer models use biased Linear layers, which dispatch through AddMM. Without this, `compile(shapeless=True)` throws "primitive does not have shape inference implemented". This follows the same pattern as #2601 (`Convolution::output_shapes`) and #1727 (shapeless SliceUpdate + Broadcast).

Discovered while porting mlx-whisper to Swift; whisper-small has 145 biased Linear layers that all fail in shapeless compile without this.
Files Changed

- mlx/primitives.h
- mlx/primitives.cpp

Previous scope

Originally included Slice and CustomKernel `output_shapes`; dropped per review feedback. CustomKernel shape inference via a callback function could be a separate discussion/PR.