Skip to content

Commit a0225e1

Browse files
authored
Fix Blockwise vmap dispatch for no batch dimensions (#1705)
* Fix Blockwise vmap dispatch for no batch dimensions Updates comments in funcify_Blockwise to avoid confusion about behaviour. Adds tests to verify correct behavior for these cases. * pre-commit * Take our docs
1 parent 1f9a67b commit a0225e1

File tree

2 files changed

+53
-4
lines changed

2 files changed

+53
-4
lines changed

pytensor/link/mlx/dispatch/blockwise.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,37 @@
66

77
@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    """Compile a ``Blockwise`` op to an MLX-callable function.

    Returns the core function unchanged when there is nothing to vectorize;
    otherwise wraps it with ``mx.vmap`` mapping over the leading batch axis
    of each genuinely-batched input.
    """
    # Core (non-batched) implementation of the wrapped op.
    dummy_node = op._create_dummy_core_node(node.inputs)
    core_fn = mlx_funcify(op.core_op, dummy_node)

    # No batch dimensions on the output at all: the core function already
    # computes the right thing, so skip vmap entirely.
    if op.batch_ndim(node) == 0:
        return core_fn

    # Decide, per input, whether vmap should map over axis 0 (batched input)
    # or treat the input as static (None).
    axes: list[int | None] = []
    for inp, sig in zip(node.inputs, op.inputs_sig):
        extra_dims = inp.type.ndim - len(sig)
        if extra_dims == 0:
            # Input has no batch dimensions — keep it static.
            axes.append(None)
        elif all(inp.type.broadcastable[:extra_dims]):
            # Every batch dim is broadcastable (size 1) — also static.
            axes.append(None)
        else:
            # Vectorize over the leading batch dimension.
            axes.append(0)

    # mx.vmap raises "At least one of in_axes must be non-None" when every
    # axis is static; in that case there is nothing to vectorize anyway.
    if all(axis is None for axis in axes):
        return core_fn

    # Vectorize the core function over the batch dimensions.
    return mx.vmap(core_fn, in_axes=tuple(axes))

tests/link/mlx/test_blockwise.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,45 @@ def test_blockwise_conv1d():
2525

2626
# assert isinstance(out.owner.op, Blockwise)
2727
compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True)
28+
29+
30+
def test_blockwise_no_batch_dimensions():
    """Blockwise over purely-core inputs must dispatch to the core function.

    Exercises the vmap-dispatch fix: ``mx.vmap`` must not be invoked when
    there are no batch dimensions to vectorize over.
    """
    rng = np.random.default_rng(42)

    # Plain matrices match the core signature exactly — zero batch dims.
    lhs = pt.matrix("x")
    rhs = pt.matrix("y")
    out = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")(lhs, rhs)

    test_values = [rng.normal(size=(2, 3)), rng.normal(size=(3, 4))]
    compare_mlx_and_py([lhs, rhs], [out], test_values, must_be_device_array=True)
49+
50+
51+
def test_blockwise_all_broadcastable_batch_dims():
    """Blockwise with only size-1 batch dims must skip ``mx.vmap``.

    When every batch dimension is broadcastable there is no actual
    vectorization to perform, so the core function should be used directly.
    """
    rng = np.random.default_rng(43)

    # Leading batch dimension is size 1 (broadcastable) on both inputs.
    lhs = tensor("x", shape=(1, 2, 3))
    rhs = tensor("y", shape=(1, 3, 4))
    out = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")(lhs, rhs)

    test_values = [rng.normal(size=(1, 2, 3)), rng.normal(size=(1, 3, 4))]
    compare_mlx_and_py([lhs, rhs], [out], test_values, must_be_device_array=True)

0 commit comments

Comments
 (0)