Skip to content

Optimize matmuls involving block diagonal matrices #1493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pytensor.scalar import upcast
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import alloc, second
from pytensor.tensor.basic import alloc, join, second
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
Expand Down Expand Up @@ -2018,6 +2018,42 @@ def broadcast_with_others(a, others):
return brodacasted_vars


def concat_with_broadcast(tensor_list, axis=0):
"""
Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
"""
if not tensor_list:
raise ValueError("Cannot concatenate an empty list of tensors.")

ndim = tensor_list[0].ndim
if not all(t.ndim == ndim for t in tensor_list):
raise TypeError(
"Only tensors with the same number of dimensions can be concatenated. "
f"Input ndims were: {[x.ndim for x in tensor_list]}"
)

axis = normalize_axis_index(axis=axis, ndim=ndim)
non_concat_shape = [1 if i != axis else None for i in range(ndim)]

for tensor_inp in tensor_list:
for i, (bcast, sh) in enumerate(
zip(tensor_inp.type.broadcastable, tensor_inp.shape)
):
if bcast or i == axis:
continue
non_concat_shape[i] = sh

assert non_concat_shape.count(None) == 1

bcast_tensor_inputs = []
for tensor_inp in tensor_list:
# We modify the concat_axis in place, as we don't need the list anywhere else
non_concat_shape[axis] = tensor_inp.shape[axis]
bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

return join(axis, *bcast_tensor_inputs)


__all__ = [
"searchsorted",
"cumsum",
Expand All @@ -2035,6 +2071,7 @@ def broadcast_with_others(a, others):
"ravel_multi_index",
"broadcast_shape",
"broadcast_to",
"concat_with_broadcast",
"geomspace",
"logspace",
"linspace",
Expand Down
72 changes: 67 additions & 5 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,21 @@
moveaxis,
ones_like,
register_infer_shape,
split,
switch,
zeros,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
from pytensor.tensor.math import (
Dot,
Prod,
Sum,
_conj,
_dot,
_matmul,
add,
digamma,
Expand Down Expand Up @@ -96,6 +99,7 @@
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
complex_dtypes,
Expand Down Expand Up @@ -146,6 +150,68 @@ def local_0_dot_x(fgraph, node):
return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]


@register_stabilize
@node_rewriter([Blockwise])
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
r"""
Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``

BlockDiag results in the creation of a matrix of shape ``(n1 * n2, m1 * m2)``. Because dot has complexity
of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
a single dot on the larger matrix.
"""
if not isinstance(node.op.core_op, BlockDiagonal):
return

# Check that the BlockDiagonal is an input to a Dot node:
for client in itertools.chain.from_iterable(
get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
):
if client.op not in (_dot, _matmul):
continue

[blockdiag_result] = node.outputs
blockdiag_inputs = node.inputs

dot_op = client.op

try:
client_idx = client.inputs.index(blockdiag_result)
except ValueError:
# If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
# DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
# blockdiag result.
for ancestor in client.inputs:
if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
client_idx = client.inputs.index(ancestor)
break

other_input = client.inputs[1 - client_idx]

split_axis = -2 if client_idx == 0 else -1
split_size_axis = -1 if client_idx == 0 else -2

other_dot_input_split = split(
other_input,
splits_size=[
component.shape[split_size_axis] for component in blockdiag_inputs
],
n_splits=len(blockdiag_inputs),
axis=split_axis,
)

split_dot_results = [
dot_op(component, other_split)
if client_idx == 0
else dot_op(other_split, component)
for component, other_split in zip(blockdiag_inputs, other_dot_input_split)
]
new_output = concat_with_broadcast(split_dot_results, axis=split_axis)

copy_stack_trace(node.outputs[0], new_output)
return {client.outputs[0]: new_output}


@register_canonicalize
@node_rewriter([Dot, _matmul])
def local_lift_transpose_through_dot(fgraph, node):
Expand Down Expand Up @@ -2582,7 +2648,6 @@ def add_calculate(num, denum, aslist=False, out_type=None):
name="add_canonizer_group",
)


register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")


Expand Down Expand Up @@ -3720,7 +3785,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
)
register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")


# log(sigmoid(x) / (1 - sigmoid(x))) -> x
# i.e logit(sigmoid(x)) -> x
local_logit_sigmoid = PatternNodeRewriter(
Expand All @@ -3734,7 +3798,6 @@ def logmexpm1_to_log1mexp(fgraph, node):
register_canonicalize(local_logit_sigmoid)
register_specialize(local_logit_sigmoid)


# sigmoid(log(x / (1-x)) -> x
# i.e., sigmoid(logit(x)) -> x
local_sigmoid_logit = PatternNodeRewriter(
Expand Down Expand Up @@ -3775,7 +3838,6 @@ def local_useless_conj(fgraph, node):

register_specialize(local_polygamma_to_tri_gamma)


local_log_kv = PatternNodeRewriter(
# Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
# During stabilize -x is converted to -1.0 * x
Expand Down
25 changes: 2 additions & 23 deletions pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from pytensor.graph import node_rewriter
from pytensor.tensor import (
broadcast_to,
concat_with_broadcast,
expand_dims,
join,
moveaxis,
specify_shape,
squeeze,
Expand Down Expand Up @@ -74,28 +74,7 @@ def lower_concat(fgraph, node):

# Convert input XTensors to Tensors and align batch dimensions
tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

# Broadcast non-concatenated dimensions of each input
non_concat_shape = [None] * len(out_dims)
for tensor_inp in tensor_inputs:
# TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
# I'm running this as "shape_unsafe" to simplify the logic / returned graph
for i, (bcast, sh) in enumerate(
zip(tensor_inp.type.broadcastable, tensor_inp.shape)
):
if bcast or i == concat_axis or non_concat_shape[i] is not None:
continue
non_concat_shape[i] = sh

assert non_concat_shape.count(None) == 1

bcast_tensor_inputs = []
for tensor_inp in tensor_inputs:
# We modify the concat_axis in place, as we don't need the list anywhere else
non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

joined_tensor = join(concat_axis, *bcast_tensor_inputs)
joined_tensor = concat_with_broadcast(tensor_inputs, axis=concat_axis)
new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
return [new_out]

Expand Down
119 changes: 119 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
simplify_mul,
)
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.tensor.type import (
TensorType,
cmatrix,
Expand Down Expand Up @@ -4745,3 +4746,121 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
out.eval({a: a_test, b: b_test}, mode=test_mode),
rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
)


@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
@pytest.mark.parametrize(
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
)
@pytest.mark.parametrize(
"batch_other", [True, False], ids=["batched_other", "unbatched_other"]
)
def test_local_block_diag_dot_to_dot_block_diag(
left_multiply, batch_blockdiag, batch_other
):
"""
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
"""

def has_blockdiag(graph):
return any(
(
var.owner
and (
isinstance(var.owner.op, BlockDiagonal)
or (
isinstance(var.owner.op, Blockwise)
and isinstance(var.owner.op.core_op, BlockDiagonal)
)
)
)
for var in ancestors([graph])
)

a = tensor("a", shape=(4, 2))
b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
c = tensor("c", shape=(4, 4))
x = pt.linalg.block_diag(a, b, c)

d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))

# Test multiple clients are all rewritten
if left_multiply:
out = x @ d
else:
out = d @ x

assert has_blockdiag(out)
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
assert not has_blockdiag(fn.maker.fgraph.outputs[0])

n_dots_rewrite = sum(
isinstance(node.op, Dot | Dot22)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
for node in fn.maker.fgraph.apply_nodes
)
assert n_dots_rewrite == 3

fn_expected = pytensor.function(
[a, b, c, d],
out,
mode=Mode(linker="py", optimizer=None),
)
assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])

n_dots_no_rewrite = sum(
isinstance(node.op, Dot | Dot22)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
for node in fn_expected.maker.fgraph.apply_nodes
)
assert n_dots_no_rewrite == 1

rng = np.random.default_rng()
a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)

rewrite_out = fn(a_val, b_val, c_val, d_val)
expected_out = fn_expected(a_val, b_val, c_val, d_val)
np.testing.assert_allclose(
rewrite_out,
expected_out,
atol=1e-6 if config.floatX == "float32" else 1e-12,
rtol=1e-6 if config.floatX == "float32" else 1e-12,
)


@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
rng = np.random.default_rng()
a_size = int(rng.uniform(1, int(0.8 * size)))
b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
c_size = size - a_size - b_size

a = tensor("a", shape=(a_size, a_size))
b = tensor("b", shape=(b_size, b_size))
c = tensor("c", shape=(c_size, c_size))
d = tensor("d", shape=(size,))

x = pt.linalg.block_diag(a, b, c)
out = x @ d

mode = get_default_mode()
if not rewrite:
mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
fn = pytensor.function([a, b, c, d], out, mode=mode)

a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)

benchmark(
fn,
a_val,
b_val,
c_val,
d_val,
)
Loading