From 2b643b54b9852d23736a325a86be983022060361 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 28 Jul 2025 15:41:08 +0200
Subject: [PATCH 1/4] Fix python implementation of IncSubtensor

---
 pytensor/tensor/subtensor.py | 49 ++++++++++++++----------------------
 1 file changed, 19 insertions(+), 30 deletions(-)

diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index 334382d132..84bf8137db 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -1756,41 +1756,30 @@ def make_node(self, x, y, *inputs):
     def decl_view(self):
         return "PyArrayObject * zview = NULL;"
 
-    def perform(self, node, inputs, out_):
-        (out,) = out_
-        x, y = inputs[:2]
-        indices = list(reversed(inputs[2:]))
-
-        def _convert(entry):
-            if isinstance(entry, Type):
-                return indices.pop()
-            elif isinstance(entry, slice):
-                return slice(
-                    _convert(entry.start), _convert(entry.stop), _convert(entry.step)
+    def perform(self, node, inputs, output_storage):
+        x, y, *flat_indices = inputs
+
+        flat_indices_iterator = iter(flat_indices)
+        indices = tuple(
+            (
+                next(flat_indices_iterator)
+                if isinstance(entry, Type)
+                else slice(
+                    None if entry.start is None else next(flat_indices_iterator),
+                    None if entry.stop is None else next(flat_indices_iterator),
+                    None if entry.step is None else next(flat_indices_iterator),
                 )
-            else:
-                return entry
+            )
+            for entry in self.idx_list
+        )
 
-        cdata = tuple(map(_convert, self.idx_list))
-        if len(cdata) == 1:
-            cdata = cdata[0]
         if not self.inplace:
             x = x.copy()
-        sub_x = x.__getitem__(cdata)
-        if sub_x.shape:
-            # we've sliced out an N-D tensor with N > 0
-            if not self.set_instead_of_inc:
-                sub_x += y
-            else:
-                # sub_x += -sub_x + y
-                x.__setitem__(cdata, y)
+        if self.set_instead_of_inc:
+            x[indices] = y
         else:
-            # scalar case
-            if not self.set_instead_of_inc:
-                x.__setitem__(cdata, sub_x + y)
-            else:
-                x.__setitem__(cdata, y)
-        out[0] = x
+            x[indices] += y
+        output_storage[0][0] = x
 
     def c_code(self, node, name, inputs, outputs, sub):
         # This method delegates much of the work to helper
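For context on the new `perform` above: `idx_list` is a static template (slices whose start/stop/step are either None or a scalar Type placeholder, plus bare scalar Type entries), and the runtime index values arrive as a flat list of extra inputs. A minimal standalone sketch of that reconstruction in plain Python (the `ScalarPlaceholder` class and the example values are illustrative only, not part of the patch):

    class ScalarPlaceholder:
        # stands in for the pytensor scalar Type entries found in idx_list
        pass

    def rebuild_indices(idx_list, flat_values):
        # slice entries consume one runtime value per non-None field,
        # scalar entries consume exactly one runtime value
        values = iter(flat_values)
        rebuilt = []
        for entry in idx_list:
            if isinstance(entry, slice):
                rebuilt.append(
                    slice(
                        None if entry.start is None else next(values),
                        None if entry.stop is None else next(values),
                        None if entry.step is None else next(values),
                    )
                )
            else:
                rebuilt.append(next(values))
        return tuple(rebuilt)

    # template for x[a:b, c] with runtime values a=1, b=5, c=2
    template = [slice(ScalarPlaceholder(), ScalarPlaceholder(), None), ScalarPlaceholder()]
    print(rebuild_indices(template, [1, 5, 2]))  # (slice(1, 5, None), 2)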
From d89c9cdcd8f30c91c2416a6206a6b180b9c1025f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 28 Jul 2025 15:43:20 +0200
Subject: [PATCH 2/4] Add developer note to rewrite

---
 pytensor/tensor/rewriting/blockwise.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py
index 4879f86a72..7eaca262ad 100644
--- a/pytensor/tensor/rewriting/blockwise.py
+++ b/pytensor/tensor/rewriting/blockwise.py
@@ -99,6 +99,8 @@ def local_blockwise_alloc(fgraph, node):
     BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp)(vector, vector), 10, 5)
     BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp)(vector, alloc(scalar, 5), 10, 5)
     BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
+
+    This is critical for removing many unnecessary Blockwise Ops, or for reducing the work they have to do.
     """
     op: Blockwise = node.op  # type: ignore

From ed37630928cf62c12adfe5cbcec591e1094e1c62 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 28 Jul 2025 16:08:19 +0200
Subject: [PATCH 3/4] Move subtensor blockwise rewrite

---
 pytensor/tensor/rewriting/blockwise.py | 24 ------------------------
 pytensor/tensor/rewriting/subtensor.py | 23 +++++++++++++++++++++++
 2 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py
index 7eaca262ad..7b70bf8860 100644
--- a/pytensor/tensor/rewriting/blockwise.py
+++ b/pytensor/tensor/rewriting/blockwise.py
@@ -17,7 +17,6 @@
     AdvancedIncSubtensor,
     AdvancedSubtensor,
     Subtensor,
-    indices_from_subtensor,
 )
@@ -229,29 +228,6 @@ def local_blockwise_reshape(fgraph, node):
     return [new_out]
 
 
-@register_stabilize
-@register_specialize
-@node_rewriter([Blockwise])
-def local_blockwise_of_subtensor(fgraph, node):
-    """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
-
-    Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
-    """
-    if not isinstance(node.op.core_op, Subtensor):
-        return
-
-    x, *idxs = node.inputs
-    if not all(all(idx.type.broadcastable) for idx in idxs):
-        return
-
-    core_idxs = indices_from_subtensor(
-        [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
-    )
-    # Add empty slices for the batch dims
-    none_slices = (slice(None),) * node.op.batch_ndim(node)
-    return [x[(*none_slices, *core_idxs)]]
-
-
 class InplaceBlockwiseOptimizer(InplaceGraphOptimizer):
     op = Blockwise
 
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 31b8bfd2bd..5006ae5cd5 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -1573,6 +1573,29 @@ def local_uint_constant_indices(fgraph, node):
     )
 
 
+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def local_blockwise_of_subtensor(fgraph, node):
+    """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
+
+    Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+    """
+    if not isinstance(node.op.core_op, Subtensor):
+        return
+
+    x, *idxs = node.inputs
+    if not all(all(idx.type.broadcastable) for idx in idxs):
+        return
+
+    core_idxs = indices_from_subtensor(
+        [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
+    )
+    # Add empty slices for the batch dims
+    none_slices = (slice(None),) * node.op.batch_ndim(node)
+    return [x[(*none_slices, *core_idxs)]]
+
+
 @register_canonicalize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
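The moved rewrite relies on a simple indexing identity: applying the same core slice to every batch entry is the same as indexing the batched tensor with an extra leading full slice. A small standalone numpy check of that identity (the shapes and bounds are illustrative only, not part of the patch):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 10))  # one batch dimension of size 4
    a, b = 2, 7                   # unbatched slice bounds

    # "blockwise" application: slice each batch entry separately ...
    looped = np.stack([x_i[a:b] for x_i in x])
    # ... equals a single indexing with an empty slice prepended for the batch dim,
    # which is what Blockwise(Subtensor{a:b})(x, a, b) is rewritten to
    direct = x[:, a:b]

    assert np.array_equal(looped, direct)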
From 9e9d8ed342a2af16942e7257dd857e89a6fc7b7f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 28 Jul 2025 15:43:06 +0200
Subject: [PATCH 4/4] Rewrite Blockwise IncSubtensor

Also cover cases of AdvancedIncSubtensor with batch indices that were not
supported before
---
 pytensor/tensor/rewriting/subtensor.py   | 195 +++++++++++-----
 pytensor/tensor/subtensor.py             |   1 -
 tests/tensor/rewriting/test_subtensor.py | 275 ++++++++++++++++-------
 3 files changed, 333 insertions(+), 138 deletions(-)

diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 5006ae5cd5..13c50ee489 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -24,6 +24,7 @@
     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
@@ -34,9 +35,10 @@
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
@@ -58,6 +60,7 @@
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1580,6 +1583,9 @@ def local_blockwise_of_subtensor(fgraph, node):
     """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
 
     Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
+
+    TODO: Handle batched indices like we do with blockwise of inc_subtensor
+    TODO: Extend to AdvancedSubtensor
     """
     if not isinstance(node.op.core_op, Subtensor):
         return
@@ -1600,64 +1606,151 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite Blockwise inc_subtensors.
 
-    x, y, *idxs = node.inputs
+
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions from allocs of shapes/reshapes that can be removed by rewrites,
-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), which will later be rewritten as x[:w.shape[1]].set(y)
+    and can then be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None
 
-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
+                # If this is ever supported we need to pay attention to the special behavior of numpy when advanced indices
+                # are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also bail out if we have boolean indices, as they cross dimension boundaries
+                # / can't be safely broadcast depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None
 
-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert back to a slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry still has dimensions after the squeeze we can't convert it to a slice
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices.
+        # We need to expand batch indices on the right, so they don't interact with core index dimensions
+        # We still squeeze on the left in case that allows us to use simpler indices
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new dimensions of x
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None) and arange() indices for the new batch dimensions,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # If we don't have batch indices,
+        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]
 
 
 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index 84bf8137db..c1bc4158d2 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -1433,7 +1433,6 @@ def _process(self, idxs, op_inputs, pstate):
 pprint.assign(Subtensor, SubtensorPrinter())
 
 
-# TODO: Implement similar vectorize for Inc/SetSubtensor
 @_vectorize_node.register(Subtensor)
 def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
     """Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py
index 4cb2b0f4cd..95f84790d9 100644
--- a/tests/tensor/rewriting/test_subtensor.py
+++ b/tests/tensor/rewriting/test_subtensor.py
@@ -1790,101 +1790,204 @@ def test_local_uint_constant_indices():
     assert new_index.type.dtype == "uint8"
 
 
-@pytest.mark.parametrize("core_y_implicitly_batched", (False, True))
-@pytest.mark.parametrize("set_instead_of_inc", (True, False))
-def test_local_blockwise_advanced_inc_subtensor(
-    set_instead_of_inc, core_y_implicitly_batched
-):
-    rng = np.random.default_rng([1764, set_instead_of_inc, core_y_implicitly_batched])
-
-    def np_inplace_f(x, idx, y):
-        if core_y_implicitly_batched:
-            y = y[..., None]
-        if set_instead_of_inc:
-            x[idx] = y
-        else:
-            x[idx] += y
-
-    core_y_shape = () if core_y_implicitly_batched else (3,)
-    core_x = tensor("x", shape=(6,))
-    core_y = tensor("y", shape=core_y_shape, dtype=int)
-    core_idxs = [0, 2, 4]
-    if set_instead_of_inc:
-        core_graph = set_subtensor(core_x[core_idxs], core_y)
-    else:
-        core_graph = inc_subtensor(core_x[core_idxs], core_y)
+class TestBlockwiseIncSubtensor:
+    @staticmethod
+    def compile_fn_and_ref(*args, **kwargs):
+        fn = pytensor.function(*args, **kwargs, mode="FAST_RUN")
+        ref_fn = pytensor.function(
+            *args, **kwargs, mode=Mode(linker="py", optimizer=None)
+        )
+        return fn, ref_fn
 
-    # Only x is batched
-    x = tensor("x", shape=(5, 2, 6))
-    y = tensor("y", shape=core_y_shape, dtype=int)
-    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
-    assert isinstance(out.owner.op, Blockwise)
+    @staticmethod
+    def has_blockwise(fn):
+        return any(
+            isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+        )
 
-    fn = pytensor.function([x, y], out, mode="FAST_RUN")
-    assert not any(
-        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    @pytest.mark.parametrize(
+        "core_y_implicitly_batched", (False, True), ids=["y_explicit", "y_implicit"]
     )
+    @pytest.mark.parametrize("set_instead_of_inc", (True, False), ids=["set", "inc"])
+    @pytest.mark.parametrize("basic_idx", (True, False), ids=["basic_idx", "adv_idx"])
+    def test_idxs_not_vectorized(
+        self, basic_idx, set_instead_of_inc, core_y_implicitly_batched
+    ):
+        rng = np.random.default_rng(
+            [1764, set_instead_of_inc, core_y_implicitly_batched, basic_idx]
+        )
 
-    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
-    test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
-    expected_out = test_x.copy()
-    np_inplace_f(expected_out, np.s_[:, :, core_idxs], test_y)
-    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
-
-    # Only y is batched
-    x = tensor("y", shape=(6,))
-    y = tensor("y", shape=(2, *core_y_shape), dtype=int)
-    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
-    assert isinstance(out.owner.op, Blockwise)
-
-    fn = pytensor.function([x, y], out, mode="FAST_RUN")
-    assert not any(
-        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
-    )
+        core_y_shape = () if core_y_implicitly_batched else (3,)
+        core_x = tensor("x", shape=(6, 6))
+        core_y = tensor("y", shape=core_y_shape, dtype=int)
+        core_idxs = (-1, slice(None, 3)) if basic_idx else (-1, [0, 2, 4])
+        if set_instead_of_inc:
+            core_graph = set_subtensor(core_x[core_idxs], core_y)
+        else:
+            core_graph = inc_subtensor(core_x[core_idxs], core_y)
+        assert isinstance(
+            core_graph.owner.op, IncSubtensor if basic_idx else AdvancedIncSubtensor
+        )
 
-    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
-    test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
-    expected_out = np.ones((2, *x.type.shape))
-    np_inplace_f(expected_out, np.s_[:, core_idxs], test_y)
-    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
-
-    # Both x and y are batched, and do not need to be broadcasted
-    x = tensor("y", shape=(2, 6))
-    y = tensor("y", shape=(2, *core_y_shape), dtype=int)
-    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
-    assert isinstance(out.owner.op, Blockwise)
-
-    fn = pytensor.function([x, y], out, mode="FAST_RUN")
-    assert not any(
-        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+        # Only x is batched
+        x = tensor("x", shape=(5, 2, 6, 6))
+        y = tensor("y", shape=core_y_shape, dtype=int)
+        out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
+        fn, ref_fn = self.compile_fn_and_ref([x, y], out)
+        assert self.has_blockwise(ref_fn)
+        assert not self.has_blockwise(fn)
+        test_x = np.ones(x.type.shape, dtype=x.type.dtype)
+        test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
+        np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
+
+        # Only y is batched
+        x = tensor("x", shape=(6, 6))
+        y = tensor("y", shape=(2, *core_y_shape), dtype=int)
+        out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
+        fn, ref_fn = self.compile_fn_and_ref([x, y], out)
+        assert self.has_blockwise(ref_fn)
+        assert not self.has_blockwise(fn)
+        test_x = np.ones(x.type.shape, dtype=x.type.dtype)
+        test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
+        np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
+
+        # Both x and y are batched, and do not need to be broadcasted
+        x = tensor("x", shape=(2, 6, 6))
+        y = tensor("y", shape=(2, *core_y_shape), dtype=int)
+        out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
+        fn, ref_fn = self.compile_fn_and_ref([x, y], out)
+        assert self.has_blockwise(ref_fn)
+        assert not self.has_blockwise(fn)
+        test_x = np.ones(x.type.shape, dtype=x.type.dtype)
+        test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
+        np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
+
+        # Both x and y are batched, but must be broadcasted
+        x = tensor("x", shape=(5, 1, 6, 6))
+        y = tensor("y", shape=(1, 2, *core_y_shape), dtype=int)
+        out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
+        fn, ref_fn = self.compile_fn_and_ref([x, y], out)
+        assert self.has_blockwise(ref_fn)
+        assert not self.has_blockwise(fn)
+        test_x = np.ones(x.type.shape, dtype=x.type.dtype)
+        test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
+        np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
+
+    @pytest.mark.parametrize("basic_idx", (True, False), ids=["basic_idx", "adv_idx"])
+    @pytest.mark.parametrize(
+        "batched_y", (False, True), ids=("unbatched_y", "batched_y")
     )
-
-    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
-    test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
-    expected_out = test_x.copy()
-    np_inplace_f(expected_out, np.s_[:, core_idxs], test_y)
-    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
-
-    # Both x and y are batched, but must be broadcasted
-    x = tensor("y", shape=(5, 1, 6))
-    y = tensor("y", shape=(1, 2, *core_y_shape), dtype=int)
-    out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
-    assert isinstance(out.owner.op, Blockwise)
-
-    fn = pytensor.function([x, y], out, mode="FAST_RUN")
-    assert not any(
-        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
+    @pytest.mark.parametrize(
+        "batched_x", (False, True), ids=("unbatched_x", "batched_x")
     )
+    def test_vectorized_idxs(
+        self,
+        basic_idx,
+        batched_y,
+        batched_x,
+    ):
+        rng = np.random.default_rng([1874, basic_idx, batched_y, batched_x])
+
+        core_x = tensor("x", shape=(6, 6))
+        core_y = tensor("y", shape=(), dtype=int)
+        scalar_idx = scalar("scalar_idx", dtype="int64")
+        vector_idx = vector("vector_idx", dtype="int64")
+        core_idxs = (
+            (slice(None, 3), scalar_idx) if basic_idx else (scalar_idx, vector_idx)
+        )
+        core_graph = inc_subtensor(core_x[core_idxs], core_y)
+        assert isinstance(
+            core_graph.owner.op, IncSubtensor if basic_idx else AdvancedIncSubtensor
+        )
+
+        # Indices don't broadcast with each other
+        x = pt.tensor("x", shape=(4, 1, *core_x.type.shape)) if batched_x else core_x
+        y = pt.tensor("y", shape=(2,), dtype=int) if batched_y else core_y
+        out = vectorize_graph(
+            core_graph,
+            replace={
+                scalar_idx: pt.constant([0, -1]),
+                vector_idx: pt.constant([[0, 2, 4], [1, 3, 5]]),
+                core_x: x,
+                core_y: y,
+            },
+        )
+        fn, ref_fn = self.compile_fn_and_ref([x, y], out)
+        assert self.has_blockwise(ref_fn)
+        assert not self.has_blockwise(fn)
+        test_x = np.ones(x.type.shape, dtype=core_x.type.dtype)
+        test_y = rng.integers(1, 10, size=y.type.shape)
+        np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
+
+        # Indices broadcast with each other
+        x = core_x
+        y = pt.tensor("y", shape=(2,), dtype=int) if batched_y else core_y
+        out = vectorize_graph(
+            core_graph,
+            replace={
+                scalar_idx: pt.constant([0, -1, 0, -1])[:, None],
+                vector_idx: pt.constant([[0, 2, 4], [1, 3, 5]])[None, :],
+                core_x: x,
+                core_y: y,
+            },
+        )
+        fn, ref_fn = self.compile_fn_and_ref([x, y], out)
+        assert self.has_blockwise(ref_fn)
+        assert not self.has_blockwise(fn)
+        test_x = np.ones(core_x.type.shape, dtype=x.type.dtype)
+        test_y = rng.integers(1, 10, size=y.type.shape)
+        np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
 
-    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
-    test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
-    final_shape = (
-        *np.broadcast_shapes(x.type.shape[:2], y.type.shape[:2]),
-        x.type.shape[-1],
+    @pytest.mark.parametrize(
+        "basic_idx",
+        [
+            True,
+            pytest.param(
+                False,
+                marks=pytest.mark.xfail(
+                    reason="AdvancedIncSubtensor with slices can't be blockwise"
+                ),
+            ),
+        ],
+        ids=["basic_idx", "adv_idx"],
+    )
+    @pytest.mark.parametrize(
+        "vectorize_idx", (False, True), ids=lambda x: f"vectorize_idx={x}"
     )
-    expected_out = np.broadcast_to(test_x, final_shape).copy()
-    np_inplace_f(expected_out, np.s_[:, :, core_idxs], test_y)
-    np.testing.assert_allclose(fn(test_x, test_y), expected_out)
+    def test_non_consecutive_integer_indices(self, vectorize_idx, basic_idx):
+        """Test numpy's special behavior of transposing non-consecutive advanced indices to the front.
+
+        Either in the original graph (if adv_idx) or in the induced graph after the rewrite.
+        """
+
+        core_a = pt.tensor("a", shape=(4, 3, 2))
+        core_v = pt.tensor("v", dtype="float64", shape=(3,) if basic_idx else (2, 3))
+        core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,))
+
+        # The empty slice before core_idx will lead to a transposition of the advanced view
+        # once it is paired with a new arange index on the batched dimensions.
+        # That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
+        core_out = core_a[0, :, core_idx].set(core_v)
+
+        vec_a = pt.tensor(shape=(2, 2, 4, 3, 2))
+        vec_idx = pt.constant([0, -1]) if vectorize_idx else pt.constant(-1, dtype=int)
+        vec_v = pt.constant([[0, 1, 2], [2, 1, 0]])
+        if not basic_idx:
+            vec_idx = pt.repeat(vec_idx[..., None], 2, axis=-1)
+            vec_v = pt.repeat(vec_v[None], repeats=2, axis=0)
+
+        vec_out = vectorize_graph(
+            core_out,
+            {core_a: vec_a, core_v: vec_v, core_idx: vec_idx},
+        )
+
+        fn, ref_fn = self.compile_fn_and_ref([vec_a], vec_out)
+        assert self.has_blockwise(ref_fn)
+        assert not self.has_blockwise(fn)
+        test_vec_a = np.arange(np.prod(vec_a.type.shape), dtype=vec_a.dtype).reshape(
+            vec_a.type.shape
+        )
+        np.testing.assert_allclose(fn(test_vec_a), ref_fn(test_vec_a))
 
 
 class TestUselessSlice:
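For context on the Step 3 note in local_blockwise_inc_subtensor and on test_non_consecutive_integer_indices: the numpy behavior being worked around can be reproduced standalone (the array shapes below are illustrative only, not part of the patch):

    import numpy as np

    a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
    rows = np.array([0, 1])
    cols = np.array([1, 2])

    # Consecutive advanced indices: the broadcast index dimensions stay in place.
    print(a[:, rows, cols].shape)  # (2, 2)

    # Advanced indices separated by a basic slice: numpy moves the broadcast
    # index dimensions to the front of the result.
    print(a[rows, :, cols].shape)  # (2, 3)

    # If every batch dimension is itself indexed with a broadcasting arange,
    # all advanced indices are consecutive again and nothing is transposed,
    # which is why the rewrite avoids mixing slice(None) with arange indices.
    batch0 = np.arange(2)[:, None]
    batch1 = np.arange(3)[None, :]
    core = np.array([1, 0, 2])
    print(a[batch0, batch1, core].shape)  # (2, 3)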