Skip to content

Rewrite more cases of Blockwise IncSubtensor #1560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 2 additions & 24 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
AdvancedIncSubtensor,
AdvancedSubtensor,
Subtensor,
indices_from_subtensor,
)


Expand Down Expand Up @@ -99,6 +98,8 @@ def local_blockwise_alloc(fgraph, node):
BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp(vector, vector), 10, 5)
BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp)(vector, alloc(scalar, 5), 10, 5)
BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)

This is critical to remove many unnecessary Blockwise, or to reduce the work done by it
"""

op: Blockwise = node.op # type: ignore
Expand Down Expand Up @@ -227,29 +228,6 @@ def local_blockwise_reshape(fgraph, node):
return [new_out]


@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_of_subtensor(fgraph, node):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.

Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
"""
if not isinstance(node.op.core_op, Subtensor):
return

x, *idxs = node.inputs
if not all(all(idx.type.broadcastable) for idx in idxs):
return

core_idxs = indices_from_subtensor(
[idx.squeeze() for idx in idxs], node.op.core_op.idx_list
)
# Add empty slices for the batch dims
none_slices = (slice(None),) * node.op.batch_ndim(node)
return [x[(*none_slices, *core_idxs)]]


class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
op = Blockwise

Expand Down
218 changes: 167 additions & 51 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ScalarFromTensor,
TensorFromScalar,
alloc,
arange,
cast,
concatenate,
expand_dims,
Expand All @@ -34,9 +35,10 @@
switch,
)
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import (
add,
and_,
Expand All @@ -58,6 +60,7 @@
)
from pytensor.tensor.shape import (
shape_padleft,
shape_padright,
shape_tuple,
)
from pytensor.tensor.sharedvar import TensorSharedVariable
Expand Down Expand Up @@ -1573,68 +1576,181 @@ def local_uint_constant_indices(fgraph, node):
)


@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_of_subtensor(fgraph, node):
    """Replace Blockwise of Subtensor by direct indexing when only the buffer is batched.

    Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x carries the batch
    dimensions and the index inputs a/b carry none.

    TODO: Handle batched indices like we do with blockwise of inc_subtensor
    TODO: Extend to AdvancedSubtensor
    """
    core_op = node.op.core_op
    if not isinstance(core_op, Subtensor):
        return None

    x, *index_inputs = node.inputs

    # Bail out if any index input has a real (non-broadcastable) dimension:
    # the rewrite only applies when x is the sole batched input.
    for idx in index_inputs:
        if not all(idx.type.broadcastable):
            return None

    # Squeeze away the dummy broadcastable dims Blockwise padded onto the
    # index inputs and rebuild the core (slice/scalar) indices from them.
    core_idxs = indices_from_subtensor(
        [idx.squeeze() for idx in index_inputs], core_op.idx_list
    )
    # Prepend one full slice per batch dimension so batch axes pass through.
    batch_slices = (slice(None),) * node.op.batch_ndim(node)
    return [x[(*batch_slices, *core_idxs)]]


@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_advanced_inc_subtensor(fgraph, node):
"""Rewrite blockwise advanced inc_subtensor without batched indexes as an inc_subtensor with prepended empty slices."""
if not isinstance(node.op.core_op, AdvancedIncSubtensor):
return None
def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwised inc_subtensors.

x, y, *idxs = node.inputs
Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites

# It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
if any(
(
isinstance(idx, SliceType | NoneTypeT)
or (idx.type.dtype == "bool" and idx.type.ndim > 0)
)
for idx in idxs
):
such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
and can be safely rewritten without Blockwise.
"""
core_op = node.op.core_op
if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
return None

op: Blockwise = node.op # type: ignore
batch_ndim = op.batch_ndim(node)

new_idxs = []
for idx in idxs:
if all(idx.type.broadcastable[:batch_ndim]):
new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
else:
# Rewrite does not apply
x, y, *idxs = node.inputs
[out] = node.outputs
if isinstance(node.op.core_op, AdvancedIncSubtensor):
if any(
(
# Blockwise requires all inputs to be tensors so it is not possible
# to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
# If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
# are separated by basic indices
isinstance(idx, SliceType | NoneTypeT)
# Also get out if we have boolean indices as they cross dimension boundaries
# / can't be safely broadcasted depending on their runtime content
or (idx.type.dtype == "bool")
)
for idx in idxs
):
return None

x_batch_bcast = x.type.broadcastable[:batch_ndim]
y_batch_bcast = y.type.broadcastable[:batch_ndim]
if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
# Need to broadcast batch x dims
batch_shape = tuple(
x_dim if (not xb or yb) else y_dim
for xb, x_dim, yb, y_dim in zip(
x_batch_bcast,
batch_ndim = node.op.batch_ndim(node)
idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
max_idx_core_ndim = max(idxs_core_ndim, default=0)

# Step 1. Broadcast buffer to batch_shape
if x.type.broadcastable != out.type.broadcastable:
batch_shape = [1] * batch_ndim
for inp in node.inputs:
for i, (broadcastable, batch_dim) in enumerate(
zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
):
if broadcastable:
# This dimension is broadcastable, it doesn't provide shape information
continue
if batch_shape[i] != 1:
# We already found a source of shape for this batch dimension
continue
batch_shape[i] = batch_dim
x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
assert x.type.broadcastable == out.type.broadcastable

# Step 2. Massage indices so they respect blockwise semantics
if isinstance(core_op, IncSubtensor):
# For basic IncSubtensor there are two cases:
# 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
# 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are these dummy dimensions coming from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blockwise adds expand_left dims for all inputs to match output batch dims, like Elemwise does. Makes other parts of the code easier to reason about

# in case we can end up with a basic IncSubtensor again
core_idxs = []
counter = 0
for idx in core_op.idx_list:
if isinstance(idx, slice):
# Squeeze away dummy dimensions so we can convert to slice
new_entries = [None, None, None]
for i, entry in enumerate((idx.start, idx.stop, idx.step)):
if entry is None:
continue
else:
new_entries[i] = new_entry = idxs[counter].squeeze()
counter += 1
if new_entry.ndim > 0:
# If the slice entry has dimensions after the squeeze we can't convert it to a slice
# We could try to convert to equivalent integer indices, but nothing guarantees
# that the slice is "square".
return None
core_idxs.append(slice(*new_entries))
else:
core_idxs.append(_squeeze_left(idxs[counter]))
counter += 1
else:
# For AdvancedIncSubtensor we have tensor integer indices,
# We need to expand batch indexes on the right, so they don't interact with core index dimensions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't clear to me why we're expanding right in this case. Is it because we're just reasoning about the batch indexes, and we want to broadcast them with the core indices in the end?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batch indices by definition can't interact with core indices.

In advanced indexing interaction happens when indexes are on the same axes. When they are on different axes they broadcast/act like orthogonal indices (i.e, you visit all combinations of index1 and index2, instead of iterating simultaneously over both)

# We still squeeze on the left in case that allows us to use simpler indices
core_idxs = [
_squeeze_left(
shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
stop_at_dim=batch_ndim,
)
for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
]

# Step 3. Create new indices for the new batch dimension of x
if not all(
all(idx.type.broadcastable[:batch_ndim])
for idx in idxs
if not isinstance(idx, slice)
):
# If indices have batch dimensions in the indices, they will interact with the new dimensions of x
# We build vectorized indexing with new arange indices that do not interact with core indices or each other
# (i.e., they broadcast)

# Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
# we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
# even if not all batch dimensions have corresponding batch indices.
batch_slices = [
shape_padright(arange(x_batch_shape, dtype="int64"), n)
for (x_batch_shape, n) in zip(
tuple(x.shape)[:batch_ndim],
y_batch_bcast,
tuple(y.shape)[:batch_ndim],
strict=True,
reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
)
)
core_shape = tuple(x.shape)[batch_ndim:]
x = alloc(x, *batch_shape, *core_shape)

new_idxs = [slice(None)] * batch_ndim + new_idxs
x_view = x[tuple(new_idxs)]

# We need to introduce any implicit expand_dims on core dimension of y
y_core_ndim = y.type.ndim - batch_ndim
if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
y = expand_dims(y, missing_axes)

symbolic_idxs = x_view.owner.inputs[1:]
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
copy_stack_trace(node.outputs, new_out)
return new_out
]
else:
# In the case we don't have batch indices,
# we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
batch_slices = [slice(None)] * batch_ndim

new_idxs = (*batch_slices, *core_idxs)
x_view = x[new_idxs]

# Step 4. Introduce any implicit expand_dims on core dimension of y
missing_y_core_ndim = x_view.type.ndim - y.type.ndim
implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for the squeeze_left to cancel some of the work done by expand_dims here?

Copy link
Member Author

@ricardoV94 ricardoV94 Aug 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could call the rewrite that combines dimshuffle directly but that increases a bit the code complexity. It will be called anyway later


if isinstance(core_op, IncSubtensor):
# Check if we can still use a basic IncSubtensor
if isinstance(x_view.owner.op, Subtensor):
new_props = core_op._props_dict()
new_props["idx_list"] = x_view.owner.op.idx_list
new_core_op = type(core_op)(**new_props)
symbolic_idxs = x_view.owner.inputs[1:]
new_out = new_core_op(x, y, *symbolic_idxs)
else:
# We need to use AdvancedSet/IncSubtensor
if core_op.set_instead_of_inc:
new_out = x[new_idxs].set(y)
else:
new_out = x[new_idxs].inc(y)
else:
# AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
symbolic_idxs = x_view.owner.inputs[1:]
new_out = core_op(x, y, *symbolic_idxs)

copy_stack_trace(out, new_out)
return [new_out]


@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
Expand Down
50 changes: 19 additions & 31 deletions pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,6 @@ def _process(self, idxs, op_inputs, pstate):
pprint.assign(Subtensor, SubtensorPrinter())


# TODO: Implement similar vectorize for Inc/SetSubtensor
@_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
Expand Down Expand Up @@ -1756,41 +1755,30 @@ def make_node(self, x, y, *inputs):
def decl_view(self):
return "PyArrayObject * zview = NULL;"

def perform(self, node, inputs, out_):
(out,) = out_
x, y = inputs[:2]
indices = list(reversed(inputs[2:]))

def _convert(entry):
if isinstance(entry, Type):
return indices.pop()
elif isinstance(entry, slice):
return slice(
_convert(entry.start), _convert(entry.stop), _convert(entry.step)
def perform(self, node, inputs, output_storage):
x, y, *flat_indices = inputs

flat_indices_iterator = iter(flat_indices)
indices = tuple(
(
next(flat_indices_iterator)
if isinstance(entry, Type)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is Type type?

Copy link
Member Author

@ricardoV94 ricardoV94 Aug 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stuff like TensorType/ ScalarType (the type of a variable).

They use those in this flat map thing. Actually it's always ScalarType but I left as it was

else slice(
None if entry.start is None else next(flat_indices_iterator),
None if entry.stop is None else next(flat_indices_iterator),
None if entry.step is None else next(flat_indices_iterator),
)
else:
return entry
)
for entry in self.idx_list
)

cdata = tuple(map(_convert, self.idx_list))
if len(cdata) == 1:
cdata = cdata[0]
if not self.inplace:
x = x.copy()
sub_x = x.__getitem__(cdata)
if sub_x.shape:
# we've sliced out an N-D tensor with N > 0
if not self.set_instead_of_inc:
sub_x += y
else:
# sub_x += -sub_x + y
x.__setitem__(cdata, y)
if self.set_instead_of_inc:
x[indices] = y
else:
# scalar case
if not self.set_instead_of_inc:
x.__setitem__(cdata, sub_x + y)
else:
x.__setitem__(cdata, y)
out[0] = x
x[indices] += y
output_storage[0][0] = x

def c_code(self, node, name, inputs, outputs, sub):
# This method delegates much of the work to helper
Expand Down
Loading