
Implement BroadcastAxes::vmap #3319

Open

Aristide021 wants to merge 4 commits into ml-explore:main from Aristide021:broadcastaxes-vmap

Conversation

@Aristide021
Contributor

Summary

Implements BroadcastAxes::vmap in mlx/primitives.cpp, replacing the NYI stub and enabling vmapped behavior for BroadcastAxes-backed ops.

What changed

  • Implemented BroadcastAxes::vmap in C++ core.
  • Handles mixed vmapped/unmapped inputs by:
    • expanding inputs to a common rank,
    • aligning vmapped axes via transpose,
    • remapping ignore_axes_ to the aligned representation.
  • Returns correctly positioned output vmap axis.
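As context for the change, the semantics a vmapped broadcast must satisfy can be illustrated with a short NumPy sketch (NumPy stands in for MLX here; shapes are illustrative): broadcasting under vmap with the mapped axis at position 0 must be equivalent to broadcasting each batch slice independently and stacking the results.

```python
import numpy as np

# Hypothetical illustration (NumPy stands in for MLX): vmapping a
# broadcast over in_axes=0 must equal broadcasting each batch slice
# independently and stacking the results along axis 0.
batch, target_shape = 3, (2, 4)
x = np.arange(batch * 4).reshape(batch, 1, 4)  # batched input, mapped axis 0

# Reference semantics: per-slice broadcast, then stack.
reference = np.stack([np.broadcast_to(x[i], target_shape) for i in range(batch)])

# Batched equivalent: broadcast directly to (batch, *target_shape).
batched = np.broadcast_to(x, (batch, *target_shape))

assert np.array_equal(reference, batched)
```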

Tests

C++

  • Added direct primitive coverage in tests/vmap_tests.cpp:
    • test vmap broadcast axes primitive
    • covers representative axis patterns ({0,1} and {-1,0}) and output-axis expectations.

Python

  • Added direct API coverage in python/tests/test_vmap.py:
    • test_vmap_broadcast_to
    • covers in_axes=0, in_axes=1, and in_axes=-1, out_axes=-1.
  • Expanded related vmap indexing coverage:
    • additional higher-rank take_along_axis and put_along_axis vmap cases.
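The kind of equivalence the expanded indexing tests check can be sketched in NumPy (a hypothetical stand-in for the MLX Python tests, not the tests themselves): a higher-rank take_along_axis vmapped over the leading axis must match applying it to each batch slice.

```python
import numpy as np

# Hypothetical sketch of the property the new vmap indexing tests
# exercise: batched take_along_axis must agree with a per-slice loop.
rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4, 5))        # batch of (4, 5) arrays
idx = rng.integers(0, 5, size=(3, 4, 2))  # batch of (4, 2) index arrays

# Reference: loop over the mapped (leading) axis.
looped = np.stack(
    [np.take_along_axis(a[i], idx[i], axis=-1) for i in range(3)]
)

# Batched: take_along_axis applied once on the full arrays.
batched = np.take_along_axis(a, idx, axis=-1)

assert np.array_equal(looped, batched)
```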

Closes the NYI stub at mlx/primitives.cpp:907.

Collaborator

@zcbenz left a comment


Can you explain how you are handling broadcasting with vmap? The code looks very confusing.

@Aristide021
Contributor Author

Can you explain how you are handling broadcasting with vmap? The code looks very confusing.

The vmap logic normalizes all inputs into a canonical batched layout, then reuses the existing output_shape() logic unchanged.

  1. If all inputs are unbatched (axes == -1), it preserves the old path exactly.
  2. Otherwise it computes ndim as the maximum over inputs of input.ndim() + (axis == -1), so unbatched inputs reserve one extra dimension for the mapped axis.
  3. Each input is left-padded with leading size-1 dimensions to reach ndim.
  4. It then aligns the mapped axis of every input to a single physical position to_ax:
    • if an input is already vmapped, its existing mapped axis is moved to to_ax
    • if an input is unvmapped, the inserted singleton dimension serves as its mapped axis
  5. ignore_axes_ is defined relative to the original unbatched layout, so after inserting/aligning the mapped axis it has to be remapped into the normalized layout. The code converts each ignored axis to a positive index in the unbatched view, shifts it if the mapped axis was inserted before it, applies the left-padding offset, then converts it back to the negative indexing convention expected by output_shape().

After that, output_shape(aligned_inputs, remapped_ignore_axes) computes the same broadcast result as before, just on the normalized batched representation, and the output mapped axis is to_ax.
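The steps above can be sketched in NumPy (an illustrative stand-in, not the MLX C++ implementation; the helper name and the choice of to_ax = 0 are assumptions for the sketch):

```python
import numpy as np

# Hypothetical NumPy sketch of the normalization described above.
# normalize() and to_ax = 0 are illustrative choices, not MLX source.
def normalize(inputs, axes, ignore_axes):
    # Step 2: unbatched inputs (axis == -1) reserve one extra dim.
    ndim = max(x.ndim + (ax == -1) for x, ax in zip(inputs, axes))
    to_ax = 0  # single physical position for every mapped axis
    aligned = []
    for x, ax in zip(inputs, axes):
        # Step 3: left-pad with leading size-1 dims up to ndim.
        pad = ndim - x.ndim - (ax == -1)
        x = x.reshape((1,) * pad + x.shape)
        if ax == -1:
            # Step 4b: insert a singleton to serve as the mapped axis.
            x = np.expand_dims(x, to_ax)
        else:
            # Step 4a: move the existing mapped axis to to_ax.
            x = np.moveaxis(x, pad + ax, to_ax)
        aligned.append(x)
    # Step 5 (simplified): with to_ax == 0 and left-only padding,
    # ignore axes in the negative indexing convention count from the
    # end and are unchanged by the normalization, so they carry over.
    return aligned, list(ignore_axes), to_ax
```

For example, normalizing a batched (3, 2, 4) input mapped on axis 0 together with an unbatched (4,) input yields aligned shapes (3, 2, 4) and (1, 1, 4), with the mapped axis at position 0 for both.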

I can also add a few comments in the implementation to make that flow easier to follow.

@zcbenz
Collaborator

zcbenz commented Mar 26, 2026

Well that is just a rephrasing of what the code does. To be honest I don't know if you actually understood the algorithm and explained it to me, or just pasted what AI said. I'm leaving this to @angeloskath to judge.
