Conversation
zcbenz
left a comment
There was a problem hiding this comment.
Can you explain how you are handling broadcasting with vmap? The code looks very confusing.
The
After that, I can also add a few comments in the implementation to make that flow easier to follow. |
|
Well that is just a rephrasing of what the code does. To be honest I don't know if you actually understood the algorithm and explained it to me, or just pasted what AI said. I'm leaving this to @angeloskath to judge. |
Summary
Implements
BroadcastAxes::vmapinmlx/primitives.cpp, replacing the NYI stub and enabling vmapped behavior forBroadcastAxes-backed ops.What changed
BroadcastAxes::vmapin C++ core.ignore_axes_to the aligned representation.Tests
C++
tests/vmap_tests.cpp:test vmap broadcast axes primitive{0,1}and{-1,0}) and output-axis expectations.Python
python/tests/test_vmap.py:test_vmap_broadcast_toin_axes=0,in_axes=1, andin_axes=-1, out_axes=-1.take_along_axisandput_along_axisvmap cases.Closes the NYI stub at
mlx/primitives.cpp:907.