Skip to content

Fix moved-from shape bug in broadcast_arrays causing vmap bus error#3310

Merged
angeloskath merged 1 commit intoml-explore:mainfrom
Aristide021:vmap-take-along-axis-crash
Mar 25, 2026
Merged

Fix moved-from shape bug in broadcast_arrays causing vmap bus error#3310
angeloskath merged 1 commit intoml-explore:mainfrom
Aristide021:vmap-take-along-axis-crash

Conversation

@Aristide021
Copy link
Copy Markdown
Contributor

Fixes a moved-from shape bug in broadcast_arrays that caused a bus
error when using vmap with take_along_axis and in_axes=(None, 0)
on higher-rank index tensors.

Root cause

In mlx/ops.cpp, std::move(out_shape) was passed to the output array
constructor before out_shape was used to construct the Broadcast
primitive. Because argument evaluation order is unspecified in C++, the
primitive could be constructed with moved-from out_shape (observed as
shape_ = {}), causing undefined behavior under vmap.

Fix

Copy out_shape instead of moving it so both the output array and the
Broadcast primitive receive the correct shape.

Regression tests

  • C++: test vmap take_along_axis with unmapped input and mapped index
  • Python: extended test_vmap_take_along_axis with the higher-rank
    in_axes=(None, 0) case

Reproducer (verified on clean upstream/main before fix)

import mlx.core as mx
a = mx.arange(4*5*3).reshape(4,5,3)
idx = mx.zeros((2,2,1,3), dtype=mx.int32)
out = mx.vmap(lambda x,y: mx.take_along_axis(x,y,axis=0), in_axes=(None,0))(a, idx)
# Previously: Bus error (core dumped)
# After fix: out.shape == (2, 2, 5, 3)

Closes #3309.

Copy link
Copy Markdown
Member

@angeloskath angeloskath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch thanks!

@hnshah
Copy link
Copy Markdown

hnshah commented Mar 25, 2026

Validation on M3 Ultra 256GB

I've validated the vmap + take_along_axis fix on M3 Ultra:

Test Hardware:

  • Mac Studio M3 Ultra (256GB)
  • macOS 25.3.0 (Darwin 25.3.0)
  • MLX: 0.31.2.dev20260324+63b73e7a (from your vmap-take-along-axis-crash branch)

Reproducer Test (Issue #3309):

import mlx.core as mx
a = mx.arange(4*5*3).reshape(4,5,3)
idx = mx.zeros((2,2,1,3), dtype=mx.int32)
out = mx.vmap(lambda x,y: mx.take_along_axis(x,y,axis=0), in_axes=(None,0))(a, idx)

Before PR #3310:

Command aborted by signal SIGBUS

After PR #3310:

✅ No crash! Output shape: (2, 2, 5, 3)

Additional Tests:

Test Case Result
Original reproducer (in_axes=(None,0)) ✅ Pass
Both inputs mapped (in_axes=(0,0)) ✅ Pass

Key Finding:

The bus error is completely resolved. The moved-from shape bug in broadcast_arrays is fixed - both the output array and the Broadcast primitive now receive the correct shape.

Successfully fixes Issue #3309! 🎯

@angeloskath angeloskath merged commit 1b1c563 into ml-explore:main Mar 25, 2026
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Bus error in vmap + take_along_axis with in_axes=(None, 0) on higher-rank index

3 participants