Fix moved-from shape bug in broadcast_arrays causing vmap bus error#3310
Merged
angeloskath merged 1 commit intoml-explore:mainfrom Mar 25, 2026
Merged
Conversation
Validation on M3 Ultra 256GBI've validated the vmap + take_along_axis fix on M3 Ultra: Test Hardware:
Reproducer Test (Issue #3309): import mlx.core as mx
a = mx.arange(4*5*3).reshape(4,5,3)
idx = mx.zeros((2,2,1,3), dtype=mx.int32)
out = mx.vmap(lambda x,y: mx.take_along_axis(x,y,axis=0), in_axes=(None,0))(a, idx)Before PR #3310: After PR #3310: Additional Tests:
Key Finding:
Successfully fixes Issue #3309! 🎯 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes a moved-from shape bug in
broadcast_arraysthat caused a buserror when using
vmapwithtake_along_axisandin_axes=(None, 0)on higher-rank index tensors.
Root cause
In
mlx/ops.cpp,std::move(out_shape)was passed to the output arrayconstructor before
out_shapewas used to construct theBroadcastprimitive. Because argument evaluation order is unspecified in C++, the
primitive could be constructed with moved-from
out_shape(observed asshape_ = {}), causing undefined behavior under vmap.Fix
Copy
out_shapeinstead of moving it so both the output array and theBroadcastprimitive receive the correct shape.Regression tests
test vmap take_along_axis with unmapped input and mapped indextest_vmap_take_along_axiswith the higher-rankin_axes=(None, 0)caseReproducer (verified on clean upstream/main before fix)
Closes #3309.