-
Couldn't load subscription status.
- Fork 146
Add more MLX dispatches #1684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add more MLX dispatches #1684
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds MLX backend support for cumulative operations, repeat, and sorting operations. The changes enable PyTensor to compile these operations to Apple's MLX framework, expanding the set of operations available when using the MLX backend.
Key changes:
- Implements MLX dispatches for
CumOp(cumsum/cumprod),Repeat,SortOp, andArgSortOp - Adds comprehensive test coverage for the new MLX dispatches
- Registers the new dispatch modules in the MLX backend initialization
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
pytensor/link/mlx/dispatch/extra_ops.py |
Implements MLX dispatches for cumulative operations and repeat |
pytensor/link/mlx/dispatch/sort.py |
Implements MLX dispatches for sort and argsort operations |
pytensor/link/mlx/dispatch/__init__.py |
Registers the new dispatch modules |
tests/link/mlx/test_extra_ops.py |
Tests for cumsum, cumprod, and repeat with MLX backend |
tests/link/mlx/test_sort.py |
Tests for sort and argsort with MLX backend |
extra_ops and sort
Codecov Report❌ Patch coverage is
❌ Your patch status has failed because the patch coverage (80.70%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1684 +/- ##
==========================================
+ Coverage 81.61% 81.64% +0.02%
==========================================
Files 242 244 +2
Lines 53533 53590 +57
Branches 9433 9438 +5
==========================================
+ Hits 43691 43752 +61
+ Misses 7366 7356 -10
- Partials 2476 2482 +6
🚀 New features to boost your workflow:
|
| axis = op.axis | ||
|
|
||
| def repeat(x, repeats, axis=axis): | ||
| if not isinstance(repeats, int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is known at dispatch time, raise then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
repeats is a symbolic input. We only know axis as dispatch time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a weird mlx-specific limitation. There might be a work-around, but I don't want to do it in this PR. Just getting some common cases covered.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you know whether repeats is 0d or 1d at dispatch time? actually isn't the op always 1d now and we use alloc for 0d?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How? It's a symbolic input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh by checking op.inputs[1] :fivehead:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
by checking node.inputs. Node is the second argument to all dispatches
Description
PR adds MLX dispatches for:
I went through the list of Ops in
jax.linker.dispatch.extra_ops.py, looking for more stuff to add, but didn't find much in themlx.corenamespace. I checked forsearchsorted,unique,bartlett, anddiff, but didn't find anything. Maybe I looked in the wrong place?Related Issue
mlx#1350Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1684.org.readthedocs.build/en/1684/