Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 22 additions & 30 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1676,23 +1676,19 @@ std::vector<array> broadcast_arrays(
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
auto out_shape = check_and_get_shape(in);
if (in.shape() == out_shape) {
outputs.push_back(in);
} else {
// broadcasted array goes first followed by other stopgrad inputs
std::vector<array> p_inputs = {in};
for (int j = 0; j < inputs.size(); ++j) {
if (j == i) {
continue;
}
p_inputs.push_back(stop_grad_inputs[j]);
// broadcasted array goes first followed by other stopgrad inputs
std::vector<array> p_inputs = {in};
for (int j = 0; j < inputs.size(); ++j) {
if (j == i) {
continue;
}
outputs.push_back(array(
std::move(out_shape),
in.dtype(),
std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
std::move(p_inputs)));
p_inputs.push_back(stop_grad_inputs[j]);
}
outputs.push_back(array(
out_shape,
in.dtype(),
std::make_shared<BroadcastAxes>(to_stream(s), ignore_axes),
std::move(p_inputs)));
}
return outputs;
}
Expand Down Expand Up @@ -1727,23 +1723,19 @@ std::vector<array> broadcast_arrays(
}
for (int i = 0; i < inputs.size(); ++i) {
auto& in = inputs[i];
if (in.shape() == shape) {
outputs.push_back(in);
} else {
// broadcasted array goes first followed by other stopgrad inputs
std::vector<array> p_inputs = {in};
for (int j = 0; j < inputs.size(); ++j) {
if (j == i) {
continue;
}
p_inputs.push_back(stop_grad_inputs[j]);
// broadcasted array goes first followed by other stopgrad inputs
std::vector<array> p_inputs = {in};
for (int j = 0; j < inputs.size(); ++j) {
if (j == i) {
continue;
}
outputs.push_back(array(
shape,
in.dtype(),
std::make_shared<Broadcast>(to_stream(s), shape),
std::move(p_inputs)));
p_inputs.push_back(stop_grad_inputs[j]);
}
outputs.push_back(array(
shape,
in.dtype(),
std::make_shared<Broadcast>(to_stream(s), shape),
std::move(p_inputs)));
}
return outputs;
}
Expand Down
35 changes: 34 additions & 1 deletion mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,40 @@ std::vector<array> BroadcastAxes::jvp(
std::pair<std::vector<array>, std::vector<int>> BroadcastAxes::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  // Vmap rule for BroadcastAxes: bring every batched input up to a common
  // rank with its batch axis at position 0, then re-apply the primitive.
  // NOTE(review): this span previously interleaved the replaced
  // `throw std::invalid_argument("[BroadcastAxes] VMAP NYI");` line from the
  // old revision; the merged function below is the post-change version only.
  std::vector<array> new_inputs = inputs;
  std::vector<int> new_axes = axes;

  // Find the maximum rank among all inputs and whether any input actually
  // carries a vmapped (>= 0) batch axis.
  size_t ndim = 0;
  bool have_batch = false;
  for (int i = 0; i < inputs.size(); i++) {
    have_batch |= axes[i] >= 0;
    ndim = std::max(inputs[i].ndim(), ndim);
  }

  // Scratch list of leading axes to insert; reserve once, resize per input.
  std::vector<int> expand;
  expand.reserve(ndim);
  for (int i = 0; i < inputs.size(); i++) {
    int extra = ndim - inputs[i].ndim();
    if (axes[i] >= 0 && extra > 0) {
      // Prepend singleton dimensions so every batched input has rank `ndim`;
      // the batch-axis index shifts right by the number of inserted dims.
      new_axes[i] += extra;
      expand.resize(extra);
      std::iota(expand.begin(), expand.end(), 0);
      new_inputs[i] = expand_dims(new_inputs[i], expand, stream());
    }

    // Normalize: move the batch axis (if present) to the front.
    if (new_axes[i] > 0) {
      new_inputs[i] = moveaxis(new_inputs[i], new_axes[i], 0, stream());
    }
  }

  auto shape = output_shape(new_inputs, ignore_axes_);
  auto dtype = new_inputs[0].dtype();
  return {
      {array(
          shape,
          dtype,
          std::make_shared<BroadcastAxes>(stream(), ignore_axes_),
          std::move(new_inputs))},
      // The output is batched along axis 0 iff any input was batched.
      {have_batch ? 0 : -1}};
}

bool BroadcastAxes::is_equivalent(const Primitive& other) const {
Expand Down
45 changes: 45 additions & 0 deletions python/tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,51 @@ def scatter_fn(x, m, src):
out = double_scatter(a + 0, mask, src)
self.assertTrue(mx.array_equal(expected, out))

def test_broadcast_axes_vmap(self):
    # BroadcastAxes only shows up under shapeless compile, so we compare a
    # plain vmapped function against its shapeless-compiled twin and count
    # how many times the traced function is (re)executed.

    num_traces = [0]

    def batched_matmul(a, b):
        num_traces[0] += 1
        return mx.matmul(a, b)

    lhs_first = mx.random.normal((2, 3, 1, 4, 5))
    rhs_first = mx.random.normal((1, 2, 5, 6))
    lhs_second = mx.random.normal((3, 2, 1, 4, 5))
    rhs_second = mx.random.normal((2, 3, 5, 6))

    vfn = mx.vmap(batched_matmul, in_axes=(0, 1))
    compiled_vfn = mx.compile(vfn, shapeless=True)

    # One trace for the eager vmap, one for the shapeless compile.
    reference = vfn(lhs_first, rhs_first)
    result = compiled_vfn(lhs_first, rhs_first)
    self.assertTrue(mx.array_equal(reference, result))
    self.assertEqual(2, num_traces[0])

    # New shapes retrace the eager vmap but reuse the shapeless compile.
    reference = vfn(lhs_second, rhs_second)
    result = compiled_vfn(lhs_second, rhs_second)
    self.assertTrue(mx.array_equal(reference, result))
    self.assertEqual(3, num_traces[0])

    lhs_first = mx.random.normal((2, 3, 1, 4, 5))
    rhs_first = mx.random.normal((1, 2, 5, 6))
    lhs_second = mx.random.normal((2, 3, 1, 7, 2))
    rhs_second = mx.random.normal((1, 2, 2, 3))

    # Same check with an unbatched second argument.
    vfn = mx.vmap(batched_matmul, in_axes=(0, None))
    compiled_vfn = mx.compile(vfn, shapeless=True)

    reference = vfn(lhs_first, rhs_first)
    result = compiled_vfn(lhs_first, rhs_first)
    self.assertTrue(mx.array_equal(reference, result))
    self.assertEqual(5, num_traces[0])

    reference = vfn(lhs_second, rhs_second)
    result = compiled_vfn(lhs_second, rhs_second)
    self.assertTrue(mx.array_equal(reference, result))
    self.assertEqual(6, num_traces[0])


if __name__ == "__main__":
mlx_tests.MLXTestRunner()
Loading