Rewrite more cases of Blockwise IncSubtensor #1560
@@ -24,6 +24,7 @@
    ScalarFromTensor,
    TensorFromScalar,
    alloc,
    arange,
    cast,
    concatenate,
    expand_dims,

@@ -34,9 +35,10 @@
    switch,
)
from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import (
    add,
    and_,

@@ -58,6 +60,7 @@
)
from pytensor.tensor.shape import (
    shape_padleft,
    shape_padright,
    shape_tuple,
)
from pytensor.tensor.sharedvar import TensorSharedVariable

@@ -1573,68 +1576,181 @@ def local_uint_constant_indices(fgraph, node):
    )

@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def local_blockwise_of_subtensor(fgraph, node):
    """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.

    Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none

    TODO: Handle batched indices like we do with blockwise of inc_subtensor
    TODO: Extend to AdvancedSubtensor
    """
    if not isinstance(node.op.core_op, Subtensor):
        return

    x, *idxs = node.inputs
    if not all(all(idx.type.broadcastable) for idx in idxs):
        return

    core_idxs = indices_from_subtensor(
        [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
    )
    # Add empty slices for the batch dims
    none_slices = (slice(None),) * node.op.batch_ndim(node)
    return [x[(*none_slices, *core_idxs)]]

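As an aside, the identity behind `local_blockwise_of_subtensor` can be checked with a minimal NumPy sketch (illustrative only, not part of the diff; shapes and bounds are made up): when the slice bounds are not batched, slicing every batch element of `x` is the same as indexing `x` with an extra leading empty slice.

```python
import numpy as np

x = np.arange(3 * 5 * 4).reshape(3, 5, 4)  # one batch dimension of size 3
a, b = 1, 4

# What Blockwise(Subtensor{a:b}) computes: apply the slice to each batch element
looped = np.stack([x_i[a:b] for x_i in x])

# What the rewrite produces: the same slice with an empty slice prepended for the batch dim
direct = x[:, a:b]

assert np.array_equal(looped, direct)
```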
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_advanced_inc_subtensor(fgraph, node):
    """Rewrite blockwise advanced inc_subtensor without batched indexes as an inc_subtensor with prepended empty slices."""
    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
        return None
def local_blockwise_inc_subtensor(fgraph, node):
    """Rewrite blockwise inc_subtensors.

    x, y, *idxs = node.inputs
    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
    is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites,

    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
    if any(
        (
            isinstance(idx, SliceType | NoneTypeT)
            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
        )
        for idx in idxs
    ):
    such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
    and can be safely rewritten without Blockwise.
    """
    core_op = node.op.core_op
    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
        return None

    op: Blockwise = node.op  # type: ignore
    batch_ndim = op.batch_ndim(node)

    new_idxs = []
    for idx in idxs:
        if all(idx.type.broadcastable[:batch_ndim]):
            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
        else:
            # Rewrite does not apply
    x, y, *idxs = node.inputs
    [out] = node.outputs
    if isinstance(node.op.core_op, AdvancedIncSubtensor):
        if any(
            (
                # Blockwise requires all inputs to be tensors so it is not possible
                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
                # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
                # are separated by basic indices
                isinstance(idx, SliceType | NoneTypeT)
                # Also get out if we have boolean indices as they cross dimension boundaries
                # / can't be safely broadcasted depending on their runtime content
                or (idx.type.dtype == "bool")
            )
            for idx in idxs
        ):
            return None

    x_batch_bcast = x.type.broadcastable[:batch_ndim]
    y_batch_bcast = y.type.broadcastable[:batch_ndim]
    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
        # Need to broadcast batch x dims
        batch_shape = tuple(
            x_dim if (not xb or yb) else y_dim
            for xb, x_dim, yb, y_dim in zip(
                x_batch_bcast,
    batch_ndim = node.op.batch_ndim(node)
    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
    max_idx_core_ndim = max(idxs_core_ndim, default=0)

    # Step 1. Broadcast buffer to batch_shape
    if x.type.broadcastable != out.type.broadcastable:
        batch_shape = [1] * batch_ndim
        for inp in node.inputs:
            for i, (broadcastable, batch_dim) in enumerate(
                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
            ):
                if broadcastable:
                    # This dimension is broadcastable, it doesn't provide shape information
                    continue
                if batch_shape[i] != 1:
                    # We already found a source of shape for this batch dimension
                    continue
                batch_shape[i] = batch_dim
        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
        assert x.type.broadcastable == out.type.broadcastable

    # Step 2. Massage indices so they respect blockwise semantics
    if isinstance(core_op, IncSubtensor):
        # For basic IncSubtensor there are two cases:
        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
        # in case we can end up with a basic IncSubtensor again
        core_idxs = []
        counter = 0
        for idx in core_op.idx_list:
            if isinstance(idx, slice):
                # Squeeze away dummy dimensions so we can convert to slice
                new_entries = [None, None, None]
                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
                    if entry is None:
                        continue
                    else:
                        new_entries[i] = new_entry = idxs[counter].squeeze()
                        counter += 1
                        if new_entry.ndim > 0:
                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice
                            # We could try to convert to equivalent integer indices, but nothing guarantees
                            # that the slice is "square".
                            return None
                core_idxs.append(slice(*new_entries))
            else:
                core_idxs.append(_squeeze_left(idxs[counter]))
                counter += 1
    else:
        # For AdvancedIncSubtensor we have tensor integer indices,
        # We need to expand batch indexes on the right, so they don't interact with core index dimensions

Review comment: It isn't clear to me why we're expanding right in this case. Is it because we're just reasoning about the batch indexes, and we want to broadcast them with the core indices in the end?

Reply: Batch indices by definition can't interact with core indices. In advanced indexing, interaction happens when indices are on the same axes. When they are on different axes they broadcast and act like orthogonal indices (i.e., you visit all combinations of index1 and index2 instead of iterating simultaneously over both).

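A small NumPy illustration of the point made in the reply (illustrative, not from the PR): integer indices placed on the same axes are iterated together, while indices placed on different axes broadcast and select all combinations.

```python
import numpy as np

x = np.arange(16).reshape(4, 4)
i = np.array([0, 1, 2])
j = np.array([1, 2, 3])

# Same axes: indices are matched elementwise -> picks (0, 1), (1, 2), (2, 3)
paired = x[i, j]                        # shape (3,)

# Different axes: padding one index with a new axis makes them broadcast,
# visiting every (i, j) combination -> "orthogonal" indexing
orthogonal = x[i[:, None], j[None, :]]  # shape (3, 3)

assert paired.shape == (3,)
assert orthogonal.shape == (3, 3)
```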
        # We still squeeze on the left in case that allows us to use simpler indices
        core_idxs = [
            _squeeze_left(
                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
                stop_at_dim=batch_ndim,
            )
            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
        ]

    # Step 3. Create new indices for the new batch dimension of x
    if not all(
        all(idx.type.broadcastable[:batch_ndim])
        for idx in idxs
        if not isinstance(idx, slice)
    ):
        # If indices have batch dimensions in the indices, they will interact with the new dimensions of x
        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
        # (i.e., they broadcast)

        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
        # we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
        # even if not all batch dimensions have corresponding batch indices.
        batch_slices = [
            shape_padright(arange(x_batch_shape, dtype="int64"), n)
            for (x_batch_shape, n) in zip(
                tuple(x.shape)[:batch_ndim],
                y_batch_bcast,
                tuple(y.shape)[:batch_ndim],
                strict=True,
                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
            )
            )
        core_shape = tuple(x.shape)[batch_ndim:]
        x = alloc(x, *batch_shape, *core_shape)

    new_idxs = [slice(None)] * batch_ndim + new_idxs
    x_view = x[tuple(new_idxs)]

    # We need to introduce any implicit expand_dims on core dimension of y
    y_core_ndim = y.type.ndim - batch_ndim
    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
        y = expand_dims(y, missing_axes)

    symbolic_idxs = x_view.owner.inputs[1:]
    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
    copy_stack_trace(node.outputs, new_out)
    return new_out
        ]
    else:
        # In the case we don't have batch indices,
        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
        batch_slices = [slice(None)] * batch_ndim

    new_idxs = (*batch_slices, *core_idxs)
    x_view = x[new_idxs]

    # Step 4. Introduce any implicit expand_dims on core dimension of y
    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)

Review comment: Is it possible for the squeeze_left to cancel some of the work done by expand_dims here?

Reply: I could call the rewrite that combines dimshuffles directly, but that increases the code complexity a bit. It will be called anyway later.

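To make the exchange above concrete, here is a shape-level NumPy sketch (illustrative; it assumes `_squeeze_left(..., stop_at_dim=batch_ndim)` only drops leading length-1 axes before the first `batch_ndim` dimensions): the `expand_dims` inserts core axes at positions at or after `batch_ndim`, so the left squeeze cannot remove them.

```python
import numpy as np

batch_ndim = 2
y = np.ones((1, 3, 5))  # (dummy batch, real batch, core)

# Implicit core axis introduced after the batch dims (what expand_dims does here)
y_expanded = np.expand_dims(y, axis=2)          # shape (1, 3, 1, 5)

# Left squeeze: drop only *leading* unit axes, never past the first batch_ndim dims
# (assumed behaviour of _squeeze_left(..., stop_at_dim=batch_ndim))
leading = []
for d in range(batch_ndim):
    if y_expanded.shape[d] != 1:
        break
    leading.append(d)
y_final = np.squeeze(y_expanded, axis=tuple(leading))  # shape (3, 1, 5)

# The core axis added at position >= batch_ndim survives the squeeze
assert y_final.shape == (3, 1, 5)
```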

    if isinstance(core_op, IncSubtensor):
        # Check if we can still use a basic IncSubtensor
        if isinstance(x_view.owner.op, Subtensor):
            new_props = core_op._props_dict()
            new_props["idx_list"] = x_view.owner.op.idx_list
            new_core_op = type(core_op)(**new_props)
            symbolic_idxs = x_view.owner.inputs[1:]
            new_out = new_core_op(x, y, *symbolic_idxs)
        else:
            # We need to use AdvancedSet/IncSubtensor
            if core_op.set_instead_of_inc:
                new_out = x[new_idxs].set(y)
            else:
                new_out = x[new_idxs].inc(y)
    else:
        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
        symbolic_idxs = x_view.owner.inputs[1:]
        new_out = core_op(x, y, *symbolic_idxs)

    copy_stack_trace(out, new_out)
    return [new_out]

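The Step 3 note about non-consecutive advanced indexing refers to a NumPy behaviour that is easy to trip over; a minimal sketch (illustrative, not from the PR): when advanced indices are separated by a basic slice, NumPy moves the broadcasted advanced-index dimensions to the front of the result, which is why the rewrite uses `arange` indices for every batch dimension instead of mixing them with `slice(None)`.

```python
import numpy as np

x = np.zeros((2, 3, 4))
idx = np.array([0, 2])          # a "core" advanced index on the last axis

# Consecutive case: the advanced index sits at the end, batch axes stay in front
a = x[:, :, idx]                # shape (2, 3, 2)

# Non-consecutive case: an advanced batch index and the core index are separated
# by a basic slice, so NumPy transposes the advanced result to the front
rows = np.arange(2)[:, None]    # batch index padded on the right
b = x[rows, :, idx]             # shape (2, 2, 3), not (2, 3, 2)

assert a.shape == (2, 3, 2)
assert b.shape == (2, 2, 3)
```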
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])

@@ -1433,7 +1433,6 @@ def _process(self, idxs, op_inputs, pstate):
pprint.assign(Subtensor, SubtensorPrinter())


# TODO: Implement similar vectorize for Inc/SetSubtensor
@_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
    """Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""

@@ -1756,41 +1755,30 @@ def make_node(self, x, y, *inputs):
    def decl_view(self):
        return "PyArrayObject * zview = NULL;"

    def perform(self, node, inputs, out_):
        (out,) = out_
        x, y = inputs[:2]
        indices = list(reversed(inputs[2:]))

        def _convert(entry):
            if isinstance(entry, Type):
                return indices.pop()
            elif isinstance(entry, slice):
                return slice(
                    _convert(entry.start), _convert(entry.stop), _convert(entry.step)
    def perform(self, node, inputs, output_storage):
        x, y, *flat_indices = inputs

        flat_indices_iterator = iter(flat_indices)
        indices = tuple(
            (
                next(flat_indices_iterator)
                if isinstance(entry, Type)

Review comment: What is `Type` here?

Reply: Stuff like TensorType / ScalarType (the type of a variable). They use those in this flat map thing. Actually it's always ScalarType, but I left it as it was.

                else slice(
                    None if entry.start is None else next(flat_indices_iterator),
                    None if entry.stop is None else next(flat_indices_iterator),
                    None if entry.step is None else next(flat_indices_iterator),
                )
            else:
                return entry
            )
            for entry in self.idx_list
        )

        cdata = tuple(map(_convert, self.idx_list))
        if len(cdata) == 1:
            cdata = cdata[0]
        if not self.inplace:
            x = x.copy()
        sub_x = x.__getitem__(cdata)
        if sub_x.shape:
            # we've sliced out an N-D tensor with N > 0
            if not self.set_instead_of_inc:
                sub_x += y
            else:
                # sub_x += -sub_x + y
                x.__setitem__(cdata, y)
        if self.set_instead_of_inc:
            x[indices] = y
        else:
            # scalar case
            if not self.set_instead_of_inc:
                x.__setitem__(cdata, sub_x + y)
            else:
                x.__setitem__(cdata, y)
        out[0] = x
            x[indices] += y
        output_storage[0][0] = x

    def c_code(self, node, name, inputs, outputs, sub):
        # This method delegates much of the work to helper

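As a standalone sketch of the index-reconstruction idea in the new `perform` (illustrative only: `ScalarSlot` and `rebuild_indices` are hypothetical stand-ins, and it assumes every non-None slice component is a runtime slot, mirroring the iterator-based code above):

```python
class ScalarSlot:
    """Hypothetical stand-in for the pytensor Type entries in idx_list."""


def rebuild_indices(idx_list, flat_indices):
    # Walk the idx_list "template", consuming one flat runtime value per slot
    flat_it = iter(flat_indices)
    return tuple(
        next(flat_it)
        if isinstance(entry, ScalarSlot)
        else slice(
            None if entry.start is None else next(flat_it),
            None if entry.stop is None else next(flat_it),
            None if entry.step is None else next(flat_it),
        )
        for entry in idx_list
    )


# x[a:b, c] with a, b, c supplied at runtime as the flat inputs (1, 4, 2)
template = [slice(ScalarSlot(), ScalarSlot(), None), ScalarSlot()]
assert rebuild_indices(template, [1, 4, 2]) == (slice(1, 4, None), 2)
```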

Review comment: Where are these dummy dimensions coming from?

Reply: Blockwise adds expand_left dims for all inputs to match the output batch dims, like Elemwise does. It makes other parts of the code easier to reason about.
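
A shape-level NumPy sketch of the convention described in the reply (illustrative, not from the PR): every Blockwise input is padded on the left with broadcastable dimensions until it has the output's number of batch dimensions, which is where the dummy dimensions come from.

```python
import numpy as np

batch_ndim = 2              # the output has two batch dimensions
stop = np.array(3)          # a 0-d slice bound with no batch dimensions of its own

# Left-pad the input with unit dimensions to match the output's batch ndim
stop_padded = stop.reshape((1,) * batch_ndim + stop.shape)

# These unit axes are the "dummy" dimensions the rewrite later squeezes away
assert stop_padded.shape == (1, 1)
```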