Skip to content

Conversation

@jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Oct 22, 2025

Description

PR adds MLX dispatches for:

  • CumOp
  • Repeat
  • SortOp
  • ArgSortOp
  • LogSoftmax

I went through the list of Ops in jax.linker.dispatch.extra_ops.py, looking for more stuff to add, but didn't find much in the mlx.core namespace. I checked for searchsorted, unique, bartlett, and diff, but didn't find anything. Maybe I looked in the wrong place?

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1684.org.readthedocs.build/en/1684/

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds MLX backend support for cumulative operations, repeat, and sorting operations. The changes enable PyTensor to compile these operations to Apple's MLX framework, expanding the set of operations available when using the MLX backend.

Key changes:

  • Implements MLX dispatches for CumOp (cumsum/cumprod), Repeat, SortOp, and ArgSortOp
  • Adds comprehensive test coverage for the new MLX dispatches
  • Registers the new dispatch modules in the MLX backend initialization

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file
File Description
pytensor/link/mlx/dispatch/extra_ops.py Implements MLX dispatches for cumulative operations and repeat
pytensor/link/mlx/dispatch/sort.py Implements MLX dispatches for sort and argsort operations
pytensor/link/mlx/dispatch/__init__.py Registers the new dispatch modules
tests/link/mlx/test_extra_ops.py Tests for cumsum, cumprod, and repeat with MLX backend
tests/link/mlx/test_sort.py Tests for sort and argsort with MLX backend

@jessegrabowski jessegrabowski changed the title MLX dispatch for extra_ops and sort Add more MLX dispatches Oct 22, 2025
@codecov
Copy link

codecov bot commented Oct 22, 2025

Codecov Report

❌ Patch coverage is 80.70175% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.64%. Comparing base (3082ed5) to head (cad28b6).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/mlx/dispatch/extra_ops.py 62.50% 8 Missing and 1 partial ⚠️
pytensor/link/mlx/dispatch/sort.py 90.00% 1 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (80.70%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1684      +/-   ##
==========================================
+ Coverage   81.61%   81.64%   +0.02%     
==========================================
  Files         242      244       +2     
  Lines       53533    53590      +57     
  Branches     9433     9438       +5     
==========================================
+ Hits        43691    43752      +61     
+ Misses       7366     7356      -10     
- Partials     2476     2482       +6     
Files with missing lines Coverage Δ
pytensor/link/mlx/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/mlx/dispatch/core.py 63.69% <100.00%> (+8.92%) ⬆️
pytensor/link/mlx/dispatch/elemwise.py 77.73% <100.00%> (+0.64%) ⬆️
pytensor/link/mlx/dispatch/sort.py 90.00% <90.00%> (ø)
pytensor/link/mlx/dispatch/extra_ops.py 62.50% <62.50%> (ø)

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

axis = op.axis

def repeat(x, repeats, axis=axis):
if not isinstance(repeats, int):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is known at dispatch time, raise then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repeats is a symbolic input. We only know axis as dispatch time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a weird mlx-specific limitation. There might be a work-around, but I don't want to do it in this PR. Just getting some common cases covered.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you know whether repeats is 0d or 1d at dispatch time? actually isn't the op always 1d now and we use alloc for 0d?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How? It's a symbolic input

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh by checking op.inputs[1] :fivehead:

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by checking node.inputs. Node is the second argument to all dispatches

@jessegrabowski jessegrabowski merged commit ab5037e into pymc-devs:main Oct 24, 2025
57 of 58 checks passed
@jessegrabowski jessegrabowski deleted the mlx-extra-ops branch October 24, 2025 01:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants