Skip to content

Slice update with operation#3266

Merged
angeloskath merged 19 commits intomainfrom
slice-update
Mar 18, 2026
Merged

Slice update with operation#3266
angeloskath merged 19 commits intomainfrom
slice-update

Conversation

@angeloskath
Copy link
Copy Markdown
Member

@angeloskath angeloskath commented Mar 16, 2026

Adds slice_update_op variants. Allow for faster implementation of slice updates that don't fall back to scatter. The CPU and CUDA implementations are still missing (will add them before merging).

There are still optimizations to be done but the first numbers are (M3U gpu)

Dtype        Dst Shape                 Update Shape       MLX^- (ms)   MLX (ms)    Torch (ms)   MLX^- GB/s   MLX GB/s     Torch GB/s
-------------------------------------------------------------------------------------------------------------------------------------
float32      (10000000,)               (1000000,)         0.870        0.699       1.953        413.85       515.34       184.32
float32      (100000,)                 (10000,)           0.466        0.454       1.804        7.73         7.94         2.00
float32      (1000, 64)                (100, 64)          0.482        0.456       1.812        4.78         5.05         1.27
float32      (100, 100, 64)            (20, 100, 64)      0.541        0.478       1.857        85.17        96.46        24.81
float32      (2048, 2048, 128)         (1000, 1000, 64)   37.590       33.338      77.250       612.93       691.10       298.25
float32      (2048, 2048, 128)         (50, 100, 64)      1.301        1.218       1.080        88.54        94.55        106.70
float32      (2048, 2048, 128)         (10, 10, 64)       1.276        1.187       0.698        1.81         1.94         3.30
bfloat16     (10000000,)               (1000000,)         1.203        0.634       1.971        149.60       283.77       91.32
bfloat16     (100000,)                 (10000,)           0.524        0.489       1.935        3.44         3.68         0.93
bfloat16     (1000, 64)                (100, 64)          0.526        0.453       1.908        2.19         2.54         0.60
bfloat16     (100, 100, 64)            (20, 100, 64)      0.603        0.497       1.936        38.21        46.35        11.90
bfloat16     (2048, 2048, 128)         (1000, 1000, 64)   36.939       17.693      76.855       311.87       651.10       149.89
bfloat16     (2048, 2048, 128)         (50, 100, 64)      1.407        1.214       1.098        40.94        47.46        52.46
bfloat16     (2048, 2048, 128)         (10, 10, 64)       1.287        1.209       0.737        0.89         0.95         1.56

where MLX^- means before this PR so it converts the slices to index arrays and uses scatter.

One of the main benefits of this PR is that changing code like x[idx] += 2 to x = x.at[idx].add(2) will almost certainly be significantly more efficient now since it will allow donating x.

The CPU version gets a pretty big boost as it is much simpler to implement (and I added a small SIMD optimization). M3 Ultra numbers below:

Dtype        Dst Shape                 Update Shape         MLX^- (ms)   MLX (ms)   Torch (ms)   MLX^- GB/s   MLX GB/s   Torch GB/s
------------------------------------------------------------------------------------------------------------------------------------
float32      (10000000,)               (1000000,)           54.328       2.653      6.468        6.63         135.67     55.66
float32      (100000,)                 (10000,)             0.542        0.078      0.099        6.65         45.98      36.18
float32      (1000, 64)                (100, 64)            0.353        0.075      0.076        6.53         30.85      30.16
float32      (100, 100, 64)            (20, 100, 64)        5.633        0.376      6.129        8.18         122.47     7.52
bfloat16     (10000000,)               (1000000,)           52.641       4.251      6.469        3.42         42.35      27.82
bfloat16     (100000,)                 (10000,)             0.441        0.101      0.179        4.08         17.89      10.05
bfloat16     (1000, 64)                (100, 64)            0.308        0.086      0.134        3.74         13.33      8.57
bfloat16     (100, 100, 64)            (20, 100, 64)        4.720        0.596      6.117        4.88         38.68      3.77

@angeloskath angeloskath requested a review from nastya236 March 16, 2026 11:32
@angeloskath
Copy link
Copy Markdown
Member Author

Fixed CUDA. This is the benchmark on H100 now, pretty clearly faster than PT

Dtype        Dst Shape                 Update Shape         MLX (ms)     MLX GB/s     Torch (ms)   Torch GB/s
--------------------------------------------------------------------------------------------------------------
float32      (10000000,)               (1000000,)           0.504        952.71       0.899        534.16
float32      (100000,)                 (10000,)             0.333        14.39        0.793        6.05
float32      (1000, 64)                (100, 64)            0.328        9.36         0.802        3.83
float32      (100, 100, 64)            (20, 100, 64)        0.342        179.84       0.811        75.73
float32      (2048, 2048, 128)         (1000, 1000, 64)     16.224       1893.50      18.634       1648.64
float32      (2048, 2048, 128)         (50, 100, 64)        0.437        351.49       0.935        164.31
float32      (2048, 2048, 128)         (10, 10, 64)         0.376        8.16         0.924        3.33
bfloat16     (10000000,)               (1000000,)           0.379        634.03       0.853        281.47
bfloat16     (100000,)                 (10000,)             0.307        7.83         0.799        3.00
bfloat16     (1000, 64)                (100, 64)            0.299        5.14         0.805        1.91
bfloat16     (100, 100, 64)            (20, 100, 64)        0.307        100.21       0.807        38.09
bfloat16     (2048, 2048, 128)         (1000, 1000, 64)     9.292        1653.09      13.671       1123.53
bfloat16     (2048, 2048, 128)         (50, 100, 64)        1.900        40.42        0.883        86.94
bfloat16     (2048, 2048, 128)         (10, 10, 64)         0.325        4.72         0.886        1.73

Copy link
Copy Markdown
Collaborator

@zcbenz zcbenz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful change

@angeloskath angeloskath merged commit 7bc61cc into main Mar 18, 2026
16 checks passed
@angeloskath angeloskath deleted the slice-update branch March 18, 2026 13:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants