Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,66 @@ std::vector<array> BroadcastAxes::jvp(
std::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
throw std::invalid_argument("[BroadcastAxes] VMAP NYI");
assert(inputs.size() == axes.size());
assert(!inputs.empty());

if (std::all_of(axes.begin(), axes.end(), [](int ax) { return ax == -1; })) {
return {
{array(
output_shape(inputs, ignore_axes_),
inputs[0].dtype(),
std::make_shared<BroadcastAxes>(stream(), ignore_axes_),
inputs)},
{-1}};
}

int ndim = 0;
for (int i = 0; i < inputs.size(); ++i) {
ndim = std::max(ndim, static_cast<int>(inputs[i].ndim()) + (axes[i] == -1));
}

auto expand_dims = [this, ndim](const array& in) {
auto shape = in.shape();
shape.insert(shape.begin(), ndim - shape.size(), 1);
return reshape(in, std::move(shape), stream());
};

auto aligned_inputs = inputs;
int to_ax = (ndim - static_cast<int>(inputs[0].ndim())) + axes[0];
for (int i = 0; i < aligned_inputs.size(); ++i) {
int from_ax = (ndim - static_cast<int>(inputs[i].ndim())) + axes[i];
aligned_inputs[i] = expand_dims(inputs[i]);

if (from_ax != to_ax) {
std::vector<int> tdims(aligned_inputs[i].ndim());
std::iota(tdims.begin(), tdims.end(), 0);
tdims.erase(tdims.begin() + from_ax);
tdims.insert(tdims.begin() + to_ax, from_ax);
aligned_inputs[i] = transpose(aligned_inputs[i], tdims, stream());
}
}

int prefix = ndim - static_cast<int>(inputs[0].ndim());
int unbatched_ndim = static_cast<int>(inputs[0].ndim()) - (axes[0] >= 0);
std::vector<int> ignore_axes;
ignore_axes.reserve(ignore_axes_.size());
// Reexpress ignore_axes_ in the normalized batched layout.
for (auto ax : ignore_axes_) {
auto pos_ax = unbatched_ndim + ax;
if (axes[0] >= 0 && pos_ax >= axes[0]) {
pos_ax++;
}
pos_ax += prefix;
ignore_axes.push_back(pos_ax - ndim);
}

return {
{array(
output_shape(aligned_inputs, ignore_axes),
aligned_inputs[0].dtype(),
std::make_shared<BroadcastAxes>(stream(), ignore_axes),
std::move(aligned_inputs))},
{to_ax}};
}

bool BroadcastAxes::is_equivalent(const Primitive& other) const {
Expand Down
67 changes: 67 additions & 0 deletions python/tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,30 @@ def test_vmap_reduce(self):
out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(2,))(a)
self.assertTrue(mx.array_equal(out, mx.full((2,), 15)))

def test_vmap_broadcast_to(self):
x = mx.arange(2 * 3 * 1 * 5).reshape(2, 3, 1, 5)

out = mx.vmap(lambda a: mx.broadcast_to(a, (3, 4, 5)), in_axes=0)(x)
expected = mx.stack(
[mx.broadcast_to(x[i], (3, 4, 5)) for i in range(x.shape[0])]
)
self.assertTrue(mx.array_equal(out, expected))

out = mx.vmap(lambda a: mx.broadcast_to(a, (2, 4, 5)), in_axes=1)(x)
expected = mx.stack(
[mx.broadcast_to(x[:, i, :, :], (2, 4, 5)) for i in range(x.shape[1])]
)
self.assertTrue(mx.array_equal(out, expected))

out = mx.vmap(lambda a: mx.broadcast_to(a, (2, 3, 4)), in_axes=-1, out_axes=-1)(
x
)
expected = mx.stack(
[mx.broadcast_to(x[:, :, :, i], (2, 3, 4)) for i in range(x.shape[-1])],
axis=-1,
)
self.assertTrue(mx.array_equal(out, expected))

def test_vmap_argreduce(self):
a = mx.array([[1, 2, 3], [2, 3, 1]])
out = mx.vmap(lambda x: mx.argmin(x))(a)
Expand Down Expand Up @@ -595,6 +619,25 @@ def fun(a, idx):
out = mx.vmap(fun, in_axes=(None, 0))(a, idx)
self.assertEqual(out.shape, (4, 2, 1))

a = mx.arange(3 * 2 * 5 * 4).reshape(3, 2, 5, 4)
idx = mx.zeros((3, 2, 1, 4), mx.int32)
out = mx.vmap(lambda x, y: mx.take_along_axis(x, y, axis=1), in_axes=(0, 0))(
a, idx
)
expected = mx.stack(
[mx.take_along_axis(a[i], idx[i], axis=1) for i in range(a.shape[0])]
)
self.assertTrue(mx.array_equal(out, expected))

idx = mx.zeros((3, 2, 5, 1), mx.int32)
out = mx.vmap(lambda x, y: mx.take_along_axis(x, y, axis=-1), in_axes=(0, 0))(
a, idx
)
expected = mx.stack(
[mx.take_along_axis(a[i], idx[i], axis=-1) for i in range(a.shape[0])]
)
self.assertTrue(mx.array_equal(out, expected))

a = mx.zeros((4, 5, 3))
idx = mx.zeros((2, 2, 1, 3), mx.int32)

Expand Down Expand Up @@ -627,6 +670,30 @@ def fun(a, idx, upd):
out = mx.vmap(fun, in_axes=(None, 1, 1))(a, idx, upd)
self.assertEqual(out.shape, (4, 5, 1))

a = mx.zeros((3, 2, 5, 4))
idx = mx.zeros((3, 2, 1, 4), mx.int32)
upd = mx.ones((3, 2, 1, 4))
out = mx.vmap(
lambda x, y, z: mx.put_along_axis(x, y, z, axis=1), in_axes=(0, 0, 0)
)(a, idx, upd)
expected = mx.stack(
[mx.put_along_axis(a[i], idx[i], upd[i], axis=1) for i in range(a.shape[0])]
)
self.assertTrue(mx.array_equal(out, expected))

idx = mx.zeros((3, 2, 5, 1), mx.int32)
upd = mx.ones((3, 2, 5, 1))
out = mx.vmap(
lambda x, y, z: mx.put_along_axis(x, y, z, axis=-1), in_axes=(0, 0, 0)
)(a, idx, upd)
expected = mx.stack(
[
mx.put_along_axis(a[i], idx[i], upd[i], axis=-1)
for i in range(a.shape[0])
]
)
self.assertTrue(mx.array_equal(out, expected))

def test_vmap_split_vmap(self):
def fun(x):
a, b = mx.split(x, 2, 1)
Expand Down
28 changes: 28 additions & 0 deletions tests/vmap_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "doctest/doctest.h"

#include "mlx/mlx.h"
#include "mlx/primitives.h"

using namespace mlx::core;

Expand Down Expand Up @@ -237,6 +238,33 @@ TEST_CASE("test vmap with eval") {
CHECK_THROWS(vmap(fun2)({x, y}));
}

TEST_CASE("test vmap broadcast axes primitive") {
auto s = default_stream(default_device());
{
auto p = BroadcastAxes(s, {-1});
auto x = reshape(arange(2 * 3 * 1 * 5, float32, s), {2, 3, 1, 5}, s);
auto y = zeros({1, 2, 4, 5}, float32, s);

auto [out, out_axes] = p.vmap({x, y}, {0, 1});
auto expected = broadcast_to(x, {2, 3, 4, 5}, s);
CHECK_EQ(out_axes.size(), 1);
CHECK_EQ(out_axes[0], 0);
CHECK(array_equal(out[0], expected).item<bool>());
}

{
auto p = BroadcastAxes(s, {-1});
auto x = reshape(arange(3 * 1 * 5, float32, s), {3, 1, 5}, s);
auto y = zeros({2, 1, 4, 5}, float32, s);

auto [out, out_axes] = p.vmap({x, y}, {-1, 0});
auto expected = broadcast_to(x, {2, 3, 4, 5}, s);
CHECK_EQ(out_axes.size(), 1);
CHECK_EQ(out_axes[0], 0);
CHECK(array_equal(out[0], expected).item<bool>());
}
}

TEST_CASE("test vmap comparison ops") {
// vmap equal
{
Expand Down