diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 51787daf41..8ab0f21d8e 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -1,6 +1,7 @@ import operator import sys from hashlib import sha256 +from textwrap import dedent, indent import numba import numpy as np @@ -14,13 +15,13 @@ compile_numba_function_src, ) from pytensor.link.numba.dispatch.basic import ( + create_tuple_string, generate_fallback_impl, register_funcify_and_cache_key, register_funcify_default_op_cache_key, ) from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.tensor import TensorType -from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, AdvancedIncSubtensor1, @@ -29,7 +30,7 @@ IncSubtensor, Subtensor, ) -from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType +from pytensor.tensor.type_other import MakeSlice, NoneTypeT def slice_new(self, start, stop, step): @@ -243,14 +244,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): else: _x, _y, *idxs = node.inputs - basic_idxs = [ - idx - for idx in idxs - if ( - isinstance(idx.type, NoneTypeT) - or (isinstance(idx.type, SliceType) and not is_full_slice(idx)) - ) - ] adv_idxs = [ { "axis": i, @@ -262,248 +255,401 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): if isinstance(idx.type, TensorType) ] - # Special implementation for consecutive integer vector indices - if ( - not basic_idxs - and len(adv_idxs) >= 2 - # Must be integer vectors - # Todo: we could allow shape=(1,) if this is the shape of x - and all( - (adv_idx["bcast"] == (False,) and adv_idx["dtype"] != "bool") - for adv_idx in adv_idxs + must_ignore_duplicates = ( + isinstance(op, AdvancedIncSubtensor) + and not op.set_instead_of_inc + and op.ignore_duplicates + # Only vector integer indices can have "duplicates", not scalars or boolean vectors + and 
not all( + adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool" for adv_idx in adv_idxs ) - # Must be consecutive - and not op.non_consecutive_adv_indexing(node) + ) + + # Special implementation for integer indices that respects duplicates + if ( + not must_ignore_duplicates + and len(adv_idxs) >= 1 + and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs) + # Implementation does not support newaxis + and not any(isinstance(idx.type, NoneTypeT) for idx in idxs) ): - return numba_funcify_multiple_integer_vector_indexing(op, node, **kwargs) + return vector_integer_advanced_indexing(op, node, **kwargs) + + must_respect_duplicates = ( + isinstance(op, AdvancedIncSubtensor) + and not op.set_instead_of_inc + and not op.ignore_duplicates + # Only vector integer indices can have "duplicates", not scalars or boolean vectors + and not all( + adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool" for adv_idx in adv_idxs + ) + ) - # Other cases not natively supported by Numba (fallback to obj-mode) + # Cases natively supported by Numba if ( + # Numba indexing, like Numpy, ignores duplicates in update + not must_respect_duplicates # Numba does not support indexes with more than one dimension - any(idx["ndim"] > 1 for idx in adv_idxs) + and not any(idx["ndim"] > 1 for idx in adv_idxs) # Nor multiple vector indexes - or sum(idx["ndim"] > 0 for idx in adv_idxs) > 1 - # The default PyTensor implementation does not handle duplicate indices correctly - or ( - isinstance(op, AdvancedIncSubtensor) - and not op.set_instead_of_inc - and not ( - op.ignore_duplicates - # Only vector integer indices can have "duplicates", not scalars or boolean vectors - or all( - adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool" - for adv_idx in adv_idxs - ) - ) - ) + and not sum(idx["ndim"] > 0 for idx in adv_idxs) > 1 ): - return generate_fallback_impl(op, node, **kwargs), subtensor_op_cache_key( - op, func="fallback_impl" - ) + return numba_funcify_default_subtensor(op, node, **kwargs) - # What's left 
should all be supported natively by numba - return numba_funcify_default_subtensor(op, node, **kwargs) + # Otherwise fallback to obj_mode + return generate_fallback_impl(op, node, **kwargs), subtensor_op_cache_key( + op, func="fallback_impl" + ) -def _broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]): - # Check that x is not broadcasted to y based on broadcastable info - if len(x_bcast) < len(to_bcast): - return True - for x_bcast_dim, to_bcast_dim in zip(x_bcast, to_bcast, strict=True): - if x_bcast_dim and not to_bcast_dim: - return True - return False +@register_funcify_and_cache_key(AdvancedIncSubtensor1) +def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): + return vector_integer_advanced_indexing(op, node=node, **kwargs) -def numba_funcify_multiple_integer_vector_indexing( - op: AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs +def vector_integer_advanced_indexing( + op: AdvancedSubtensor1 | AdvancedSubtensor | AdvancedIncSubtensor, node, **kwargs ): - # Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor) - if isinstance(op, AdvancedSubtensor): - idxs = node.inputs[1:] - else: - idxs = node.inputs[2:] + """Implement all forms of advanced indexing (and assignment) that combine basic and vector integer indices. - first_axis = next( - i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType) - ) - try: - after_last_axis = next( - i - for i, idx in enumerate(idxs[first_axis:], start=first_axis) - if not isinstance(idx.type, TensorType) - ) - except StopIteration: - after_last_axis = len(idxs) - last_axis = after_last_axis - 1 + It does not support `newaxis` in basic indices - vector_indices = idxs[first_axis:after_last_axis] - assert all(v.type.broadcastable == (False,) for v in vector_indices) - y_is_broadcasted = False + It handles += like `np.add.at` would, accumulating add for duplicate indices. 
- if isinstance(op, AdvancedSubtensor): + Examples + -------- + + Codegen for an AdvancedSubtensor, with non-consecutive matrix indices, and a slice(1, None) basic index - @numba_basic.numba_njit - def advanced_subtensor_multiple_vector(x, *idxs): - none_slices = idxs[:first_axis] - vec_idxs = idxs[first_axis:after_last_axis] - - x_shape = x.shape - idx_shape = vec_idxs[0].shape - shape_bef = x_shape[:first_axis] - shape_aft = x_shape[after_last_axis:] - out_shape = (*shape_bef, *idx_shape, *shape_aft) - out_buffer = np.empty(out_shape, dtype=x.dtype) - for i, scalar_idxs in enumerate(zip(*vec_idxs)): - out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)] + .. code-block:: python + + # AdvancedSubtensor [id A] + # ├─ [id B] + # ├─ [[1 2] [2 1]] [id C] + # ├─ SliceConstant{1, None, None} [id D] + # └─ [[0 0] [0 0]] [id E] + + + def advanced_integer_vector_indexing(x, idx0, idx1, idx2): + # Move advanced indexed dims to the front (if needed) + x_adv_dims_front = x.transpose((0, 2, 1)) + + # Perform basic indexing once (if needed) + basic_indexed_x = x_adv_dims_front[:, :, idx1] + + # Broadcast indices + adv_idx_shape = np.broadcast_shapes(idx0.shape, idx2.shape) + (idx0, idx2) = ( + np.broadcast_to(idx0, adv_idx_shape), + np.broadcast_to(idx2, adv_idx_shape), + ) + + # Create output buffer + adv_idx_size = idx0.size + basic_idx_shape = basic_indexed_x.shape[2:] + out_buffer = np.empty((adv_idx_size, *basic_idx_shape), dtype=x.dtype) + + # Index over tuples of raveled advanced indices and write to output buffer + for i, scalar_idxs in enumerate(zip(idx0.ravel(), idx2.ravel())): + out_buffer[i] = basic_indexed_x[scalar_idxs] + + # Unravel out_buffer (if needed) + out_buffer = out_buffer.reshape((*adv_idx_shape, *basic_idx_shape)) + + # Move advanced output indexing group to its final position (if needed) and return return out_buffer - ret_func = advanced_subtensor_multiple_vector - else: - inplace = op.inplace - - # Check if y must be broadcasted - # 
Includes the last integer vector index, - x, y = node.inputs[:2] - indexed_bcast_dims = ( - *x.type.broadcastable[:first_axis], - *x.type.broadcastable[last_axis:], - ) - y_is_broadcasted = _broadcasted_to(y.type.broadcastable, indexed_bcast_dims) + Codegen for similar AdvancedSetSubtensor - if op.set_instead_of_inc: + .. code-block:: python - @numba_basic.numba_njit - def advanced_set_subtensor_multiple_vector(x, y, *idxs): - vec_idxs = idxs[first_axis:after_last_axis] - x_shape = x.shape + AdvancedSetSubtensor [id A] + ├─ x [id B] + ├─ y [id C] + ├─ [1 2] [id D] + ├─ SliceConstant{None, None, None} [id E] + └─ [3 4] [id F] - if inplace: - out = x - else: - out = x.copy() + def set_advanced_integer_vector_indexing(x, y, idx0, idx1, idx2): + # Expand dims of y explicitly (if needed) + y = y - if y_is_broadcasted: - y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) + # Copy x (if not inplace) + x = x.copy() - for outer in np.ndindex(x_shape[:first_axis]): - for i, scalar_idxs in enumerate(zip(*vec_idxs)): - out[(*outer, *scalar_idxs)] = y[(*outer, i)] - return out + # Move advanced indexed dims to the front (if needed) + # This will remain a view of x + x_adv_dims_front = x.transpose((0, 2, 1)) - ret_func = advanced_set_subtensor_multiple_vector + # Perform basic indexing once (if needed) + # This will remain a view of x + basic_indexed_x = x_adv_dims_front[:, :, idx1] - else: + # Broadcast indices + adv_idx_shape = np.broadcast_shapes(idx0.shape, idx2.shape) + (idx0, idx2) = (np.broadcast_to(idx0, adv_idx_shape), np.broadcast_to(idx2, adv_idx_shape)) - @numba_basic.numba_njit - def advanced_inc_subtensor_multiple_vector(x, y, *idxs): - vec_idxs = idxs[first_axis:after_last_axis] - x_shape = x.shape + # Move advanced indexed dims to the front (if needed) + y_adv_dims_front = y - if inplace: - out = x - else: - out = x.copy() + # Broadcast y to the shape of each assignment/update + adv_idx_shape = idx0.shape + basic_idx_shape = 
basic_indexed_x.shape[2:] + y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape)) - if y_is_broadcasted: - y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) + # Ravel the advanced dims (if needed) + # Note that numba reshape only supports C-arrays, so we ravel before reshape + y_bcast = y_bcast - for outer in np.ndindex(x_shape[:first_axis]): - for i, scalar_idxs in enumerate(zip(*vec_idxs)): - out[(*outer, *scalar_idxs)] += y[(*outer, i)] - return out + # Index over tuples of raveled advanced indices and update buffer + for i, scalar_idxs in enumerate(zip(idx0, idx2)): + basic_indexed_x[scalar_idxs] = y_bcast[i] - ret_func = advanced_inc_subtensor_multiple_vector + # Return the original x, with the entries updated + return x - cache_key = subtensor_op_cache_key( - op, - func="multiple_integer_vector_indexing", - y_is_broadcasted=y_is_broadcasted, - first_axis=first_axis, - last_axis=last_axis, - ) - return ret_func, cache_key + Codegen for an AdvancedIncSubtensor, with two contiguous advanced groups not in the leading axis -@register_funcify_and_cache_key(AdvancedIncSubtensor1) -def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): - inplace = op.inplace - set_instead_of_inc = op.set_instead_of_inc - x, vals, _idxs = node.inputs - broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0] - # TODO: Add runtime_broadcast check - - if set_instead_of_inc: - if broadcast_with_index: - - @numba_basic.numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, val, idxs): - if val.ndim == x.ndim: - core_val = val[0] - elif val.ndim == 0: - # Workaround for https://github.com/numba/numba/issues/9573 - core_val = val.item() - else: - core_val = val - - for idx in idxs: - x[idx] = core_val - return x + .. 
code-block:: python - else: + AdvancedIncSubtensor [id A] + ├─ x [id B] + ├─ y [id C] + ├─ SliceConstant{1, None, None} [id D] + ├─ [1 2] [id E] + └─ [3 4] [id F] - @numba_basic.numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, vals, idxs): - if not len(idxs) == len(vals): - raise ValueError("The number of indices and values must match.") - # no strict argument because incompatible with numba - for idx, val in zip(idxs, vals): - x[idx] = val - return x + def inc_advanced_integer_vector_indexing(x, y, idx0, idx1, idx2): + # Expand dims of y explicitly (if needed) + y = y + + # Copy x (if not inplace) + x = x.copy() + + # Move advanced indexed dims to the front (if needed) + # This will remain a view of x + x_adv_dims_front = x.transpose((1, 2, 0)) + + # Perform basic indexing once (if needed) + # This will remain a view of x + basic_indexed_x = x_adv_dims_front[:, :, idx0] + + # Broadcast indices + adv_idx_shape = np.broadcast_shapes(idx1.shape, idx2.shape) + (idx1, idx2) = (np.broadcast_to(idx1, adv_idx_shape), np.broadcast_to(idx2, adv_idx_shape)) + + # Move advanced indexed dims to the front (if needed) + y_adv_dims_front = y.transpose((1, 0)) + + # Broadcast y to the shape of each assignment/update + adv_idx_shape = idx1.shape + basic_idx_shape = basic_indexed_x.shape[2:] + y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape)) + + # Ravel the advanced dims (if needed) + # Note that numba reshape only supports C-arrays, so we ravel before reshape + y_bcast = y_bcast + + # Index over tuples of raveled advanced indices and update buffer + for i, scalar_idxs in enumerate(zip(idx1, idx2)): + basic_indexed_x[scalar_idxs] += y_bcast[i] + + # Return the original x, with the entries updated + return x + + """ + if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor): + x, *idxs = node.inputs else: - if broadcast_with_index: - - @numba_basic.numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, val, idxs): - if 
val.ndim == x.ndim: - core_val = val[0] - elif val.ndim == 0: - # Workaround for https://github.com/numba/numba/issues/9573 - core_val = val.item() - else: - core_val = val - - for idx in idxs: - x[idx] += core_val - return x + x, y, *idxs = node.inputs + [out] = node.outputs - else: + adv_indices_pos = tuple( + i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType) + ) + assert adv_indices_pos # Otherwise it's just basic indexing + basic_indices_pos = tuple( + i for i, idx in enumerate(idxs) if not isinstance(idx.type, TensorType) + ) + explicit_basic_indices_pos = (*basic_indices_pos, *range(len(idxs), x.type.ndim)) - @numba_basic.numba_njit(boundscheck=True) - def advancedincsubtensor1_inplace(x, vals, idxs): - if not len(idxs) == len(vals): - raise ValueError("The number of indices and values must match.") - # no strict argument because unsupported by numba - # TODO: this doesn't come up in tests - for idx, val in zip(idxs, vals): - x[idx] += val - return x + # Create index signature and split them among basic and advanced + idx_signature = ", ".join(f"idx{i}" for i in range(len(idxs))) + adv_indices = [f"idx{i}" for i in adv_indices_pos] + basic_indices = [f"idx{i}" for i in basic_indices_pos] - cache_key = subtensor_op_cache_key( - op, - func="numba_funcify_advancedincsubtensor1", - broadcast_with_index=broadcast_with_index, + # Define transpose axis so that advanced indexing dims are on the front + adv_axis_front_order = (*adv_indices_pos, *explicit_basic_indices_pos) + adv_axis_front_transpose_needed = adv_axis_front_order != tuple(range(x.ndim)) + adv_idx_ndim = max(idxs[i].ndim for i in adv_indices_pos) + + # Helper needed for basic indexing after moving advanced indices to the front + basic_indices_with_none_slices = ", ".join( + (*((":",) * len(adv_indices)), *basic_indices) ) - if inplace: - return advancedincsubtensor1_inplace, cache_key + # Position of the first advanced index dimension after indexing the array + if 
(np.diff(adv_indices_pos) > 1).any(): + # If not consecutive, it's always at the front + out_adv_axis_pos = 0 + else: + # Otherwise wherever the first advanced index is located + out_adv_axis_pos = adv_indices_pos[0] + + to_tuple = create_tuple_string # alias to make code more readable below + + if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor): + # Define transpose axis on the output to restore original meaning + # After (potentially) having transposed advanced indexing dims to the front unlike numpy + _final_axis_order = list(range(adv_idx_ndim, out.type.ndim)) + for i in range(adv_idx_ndim): + _final_axis_order.insert(out_adv_axis_pos + i, i) + final_axis_order = tuple(_final_axis_order) + del _final_axis_order + final_axis_transpose_needed = final_axis_order != tuple(range(out.type.ndim)) + + func_name = "advanced_integer_vector_indexing" + codegen = dedent( + f""" + def {func_name}(x, {idx_signature}): + # Move advanced indexed dims to the front (if needed) + x_adv_dims_front = {f"x.transpose({adv_axis_front_order})" if adv_axis_front_transpose_needed else "x"} + + # Perform basic indexing once (if needed) + basic_indexed_x = {f"x_adv_dims_front[{basic_indices_with_none_slices}]" if basic_indices else "x_adv_dims_front"} + """ + ) + if len(adv_indices) > 1: + codegen += indent( + dedent( + f""" + # Broadcast indices + adv_idx_shape = np.broadcast_shapes{to_tuple([f"{idx}.shape" for idx in adv_indices])} + {to_tuple(adv_indices)} = {to_tuple([f"np.broadcast_to({idx}, adv_idx_shape)" for idx in adv_indices])} + """ + ), + " " * 4, + ) + codegen += indent( + dedent( + f""" + # Create output buffer + adv_idx_size = {adv_indices[0]}.size + basic_idx_shape = basic_indexed_x.shape[{len(adv_indices)}:] + out_buffer = np.empty((adv_idx_size, *basic_idx_shape), dtype=x.dtype) + + # Index over tuples of raveled advanced indices and write to output buffer + for i, scalar_idxs in enumerate(zip{to_tuple([f"{idx}.ravel()" for idx in adv_indices] if adv_idx_ndim != 1 
else adv_indices)}): + out_buffer[i] = basic_indexed_x[scalar_idxs] + + # Unravel out_buffer (if needed) + out_buffer = {f"out_buffer.reshape((*{adv_indices[0]}.shape, *basic_idx_shape))" if adv_idx_ndim != 1 else "out_buffer"} + + # Move advanced output indexing group to its final position (if needed) and return + return {f"out_buffer.transpose({final_axis_order})" if final_axis_transpose_needed else "out_buffer"} + """ + ), + " " * 4, + ) else: + # Make implicit dims of y explicit to simplify code + # Numba doesn't support `np.expand_dims` with multiple axis, so we use indexing with newaxis + indexed_ndim = x[tuple(idxs)].type.ndim + y_expand_dims = [":"] * y.type.ndim + y_implicit_dims = range(indexed_ndim - y.type.ndim) + for axis in y_implicit_dims: + y_expand_dims.insert(axis, "None") + + # We transpose the advanced dimensions of x to the front for indexing + # We may have to do the same for y + # Note that if there are non-contiguous advanced indices, + # y must already be aligned with the indices jumping to the front + y_adv_axis_front_order = tuple( + range( + # Position of the first advanced axis after indexing + out_adv_axis_pos, + # Position of the last advanced axis after indexing + out_adv_axis_pos + adv_idx_ndim, + ) + ) + y_order = tuple(range(indexed_ndim)) + y_adv_axis_front_order = ( + *y_adv_axis_front_order, + # Basic indices, after explicit_expand_dims + *(o for o in y_order if o not in y_adv_axis_front_order), + ) + y_adv_axis_front_transpose_needed = y_adv_axis_front_order != y_order + + func_name = f"{'set' if op.set_instead_of_inc else 'inc'}_advanced_integer_vector_indexing" + codegen = dedent( + f""" + def {func_name}(x, y, {idx_signature}): + # Expand dims of y explicitly (if needed) + y = {f"y[{', '.join(y_expand_dims)},]" if y_implicit_dims else "y"} + + # Copy x (if not inplace) + x = {"x" if op.inplace else "x.copy()"} + + # Move advanced indexed dims to the front (if needed) + # This will remain a view of x + x_adv_dims_front = 
{f"x.transpose({adv_axis_front_order})" if adv_axis_front_transpose_needed else "x"} + + # Perform basic indexing once (if needed) + # This will remain a view of x + basic_indexed_x = {f"x_adv_dims_front[{basic_indices_with_none_slices}]" if basic_indices else "x_adv_dims_front"} + """ + ) + if len(adv_indices) > 1: + codegen += indent( + dedent( + f""" + # Broadcast indices + adv_idx_shape = np.broadcast_shapes{to_tuple([f"{idx}.shape" for idx in adv_indices])} + {to_tuple(adv_indices)} = {to_tuple([f"np.broadcast_to({idx}, adv_idx_shape)" for idx in adv_indices])} + """ + ), + " " * 4, + ) + codegen += indent( + dedent( + f""" + # Move advanced indexed dims to the front (if needed) + y_adv_dims_front = {f"y.transpose({y_adv_axis_front_order})" if y_adv_axis_front_transpose_needed else "y"} + + # Broadcast y to the shape of each assignment/update + adv_idx_shape = {adv_indices[0]}.shape + basic_idx_shape = basic_indexed_x.shape[{len(adv_indices)}:] + y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape)) + + # Ravel the advanced dims (if needed) + # Note that numba reshape only supports C-arrays, so we ravel before reshape + y_bcast = {"y_bcast.ravel().reshape((-1, *basic_idx_shape))" if adv_idx_ndim != 1 else "y_bcast"} + + # Index over tuples of raveled advanced indices and update buffer + for i, scalar_idxs in enumerate(zip{to_tuple([f"{idx}.ravel()" for idx in adv_indices] if adv_idx_ndim != 1 else adv_indices)}): + basic_indexed_x[scalar_idxs] {"=" if op.set_instead_of_inc else "+="} y_bcast[i] + + # Return the original x, with the entries updated + return x + """ + ), + " " * 4, + ) - @numba_basic.numba_njit - def advancedincsubtensor1(x, vals, idxs): - x = x.copy() - return advancedincsubtensor1_inplace(x, vals, idxs) + cache_key = subtensor_op_cache_key( + op, + codegen=codegen, + ) - return advancedincsubtensor1, cache_key + ret_func = numba_basic.numba_njit( + compile_numba_function_src( + codegen, + function_name=func_name, + 
global_env=globals(), + cache_key=cache_key, + ) + ) + return ret_func, cache_key diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index fbe97b9a68..7ad8d087b3 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -83,7 +83,7 @@ inc_subtensor, indices_from_subtensor, ) -from pytensor.tensor.type import TensorType, integer_dtypes +from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -1744,205 +1744,45 @@ def local_blockwise_inc_subtensor(fgraph, node): @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor]) -def ravel_multidimensional_bool_idx(fgraph, node): - """Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba +def bool_idx_to_nonzero(fgraph, node): + """Convert boolean indexing into equivalent integer vector indexing (via `nonzero`), supported by our dispatch - x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()] - x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape) + x[eye(3, dtype=bool)] -> x[eye(3, dtype=bool).nonzero()] """ if isinstance(node.op, AdvancedSubtensor): x, *idxs = node.inputs else: x, y, *idxs = node.inputs - if any( - ( - (isinstance(idx.type, TensorType) and idx.type.dtype in integer_dtypes) - or isinstance(idx.type, NoneTypeT) - ) - for idx in idxs - ): - # Get out if there are any other advanced indexes or np.newaxis - return None - - bool_idxs = [ - (i, idx) + bool_pos = { + i for i, idx in enumerate(idxs) if (isinstance(idx.type, TensorType) and idx.dtype == "bool") - ] - - if len(bool_idxs) != 1: - # Get out if there are no or multiple boolean idxs - return None + } - [(bool_idx_pos, bool_idx)] = bool_idxs - bool_idx_ndim = bool_idx.type.ndim - if bool_idx.type.ndim < 2: - # No need to do anything if it's a vector or scalar, as it's already supported by Numba + if not 
bool_pos: return None - x_shape = x.shape - raveled_x = x.reshape( - (*x_shape[:bool_idx_pos], -1, *x_shape[bool_idx_pos + bool_idx_ndim :]) - ) - - raveled_bool_idx = bool_idx.ravel() - new_idxs = list(idxs) - new_idxs[bool_idx_pos] = raveled_bool_idx + new_idxs = [] + for i, idx in enumerate(idxs): + if i in bool_pos: + new_idxs.extend(idx.nonzero()) + else: + new_idxs.append(idx) if isinstance(node.op, AdvancedSubtensor): - new_out = node.op(raveled_x, *new_idxs) + new_out = node.op(x, *new_idxs) else: - # The dimensions of y that correspond to the boolean indices - # must already be raveled in the original graph, so we don't need to do anything to it - new_out = node.op(raveled_x, y, *new_idxs) - # But we must reshape the output to math the original shape - new_out = new_out.reshape(x_shape) + new_out = node.op(x, y, *new_idxs) return [copy_stack_trace(node.outputs[0], new_out)] -@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor]) -def ravel_multidimensional_int_idx(fgraph, node): - """Convert multidimensional integer indexing into equivalent consecutive vector integer index, - supported by Numba or by our specialized dispatchers - - x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3)) - - NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices - - x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes - - It also handles multiple integer indices, but only if they don't broadcast - - x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... 
are the remaining output shapes - - Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast - - x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:])) - - """ - op = node.op - non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node) - is_inc_subtensor = isinstance(op, AdvancedIncSubtensor) - - if is_inc_subtensor: - x, y, *idxs = node.inputs - # Inc/SetSubtensor is harder to reason about due to y - # We get out if it's broadcasting or if the advanced indices are non-consecutive - if non_consecutive_adv_indexing or ( - y.type.broadcastable != x[tuple(idxs)].type.broadcastable - ): - return None - - else: - x, *idxs = node.inputs - - if any( - ( - (isinstance(idx.type, TensorType) and idx.type.dtype == "bool") - or isinstance(idx.type, NoneTypeT) - ) - for idx in idxs - ): - # Get out if there are any other advanced indices or np.newaxis - return None - - int_idxs_and_pos = [ - (i, idx) - for i, idx in enumerate(idxs) - if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes) - ] - - if not int_idxs_and_pos: - return None - - int_idxs_pos, int_idxs = zip( - *int_idxs_and_pos, strict=False - ) # strict=False because by definition it's true - - first_int_idx_pos = int_idxs_pos[0] - first_int_idx = int_idxs[0] - first_int_idx_bcast = first_int_idx.type.broadcastable - - if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs): - # We don't have a view-only broadcasting operation - # Explicitly broadcasting the indices can incur a memory / copy overhead - return None - - int_idxs_ndim = len(first_int_idx_bcast) - if ( - int_idxs_ndim == 0 - ): # This should be a basic indexing operation, rewrite elsewhere - return None - - int_idxs_need_raveling = int_idxs_ndim > 1 - if not (int_idxs_need_raveling or non_consecutive_adv_indexing): - # Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done - return 
None - - # Reorder non-consecutive indices - if non_consecutive_adv_indexing: - assert not is_inc_subtensor # Sanity check that we got out if this was the case - # This case works as if all the advanced indices were on the front - transposition = list(int_idxs_pos) + [ - i for i in range(len(idxs)) if i not in int_idxs_pos - ] - idxs = tuple(idxs[a] for a in transposition) - x = x.transpose(transposition) - first_int_idx_pos = 0 - del int_idxs_pos # Make sure they are not wrongly used - - # Ravel multidimensional indices - if int_idxs_need_raveling: - idxs = list(idxs) - for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos): - idxs[idx_pos] = int_idx.ravel() - - # Index with reordered and/or raveled indices - new_subtensor = x[tuple(idxs)] - - if is_inc_subtensor: - y_shape = tuple(y.shape) - y_raveled_shape = ( - *y_shape[:first_int_idx_pos], - -1, - *y_shape[first_int_idx_pos + int_idxs_ndim :], - ) - y_raveled = y.reshape(y_raveled_shape) - - new_out = inc_subtensor( - new_subtensor, - y_raveled, - set_instead_of_inc=op.set_instead_of_inc, - ignore_duplicates=op.ignore_duplicates, - inplace=op.inplace, - ) - - else: - # Unravel advanced indexing dimensions - raveled_shape = tuple(new_subtensor.shape) - unraveled_shape = ( - *raveled_shape[:first_int_idx_pos], - *first_int_idx.shape, - *raveled_shape[first_int_idx_pos + 1 :], - ) - new_out = new_subtensor.reshape(unraveled_shape) - - return [copy_stack_trace(node.outputs[0], new_out)] - - -optdb["specialize"].register( - ravel_multidimensional_bool_idx.__name__, - ravel_multidimensional_bool_idx, - "numba", - use_db_name_as_tag=False, # Not included if only "specialize" is requested -) - optdb["specialize"].register( - ravel_multidimensional_int_idx.__name__, - ravel_multidimensional_int_idx, + bool_idx_to_nonzero.__name__, + bool_idx_to_nonzero, "numba", + "shape_unsafe", # It can mask invalid mask sizes use_db_name_as_tag=False, # Not included if only "specialize" is requested ) diff --git 
a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index d7fc1bedbc..1e21e67726 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2629,9 +2629,13 @@ def make_node(self, x, *indices): advanced_indices = [] adv_group_axis = None last_adv_group_axis = None - expanded_x_shape = tuple( - np.insert(np.array(x.type.shape, dtype=object), 1, new_axes) - ) + if new_axes: + expanded_x_shape_list = list(x.type.shape) + for new_axis in new_axes: + expanded_x_shape_list.insert(new_axis, 1) + expanded_x_shape = tuple(expanded_x_shape_list) + else: + expanded_x_shape = x.type.shape for i, (idx, dim_length) in enumerate( zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst) ): diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index 592f8af2fb..b700172779 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -109,117 +109,95 @@ def test_AdvancedSubtensor1_out_of_bounds(): @pytest.mark.parametrize( - "x, indices, objmode_needed", + "x, indices", [ - # Single vector indexing (supported natively by Numba) + # Single vector indexing ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (0, [1, 2, 2, 3]), - False, ), ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (np.array([True, False, False])), - False, ), - # Single multidimensional indexing (supported after specialization rewrites) + # Single multidimensional indexing ( as_tensor(np.arange(3 * 3).reshape((3, 3))), (np.eye(3).astype(int)), - False, ), ( as_tensor(np.arange(3 * 3).reshape((3, 3))), (np.eye(3).astype(bool)), - False, ), ( as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))), (np.eye(3).astype(int)), - False, ), ( as_tensor(np.arange(3 * 3 * 2).reshape((3, 3, 2))), (np.eye(3).astype(bool)), - False, ), ( as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))), (slice(2, None), np.eye(3).astype(int)), - False, ), ( as_tensor(np.arange(2 * 3 * 3).reshape((2, 3, 3))), (slice(2, None), 
np.eye(3).astype(bool)), - False, ), - # Multiple vector indexing (supported by our dispatcher) + # Multiple vector indexing ( pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3]), - False, ), ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (slice(None), [1, 2], [3, 4]), - False, ), ( as_tensor(np.arange(3 * 5 * 7).reshape((3, 5, 7))), ([1, 2], [3, 4], [5, 6]), - False, ), - # Non-consecutive vector indexing, supported by our dispatcher after rewriting + # Non-consecutive vector indexing ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], slice(None), [3, 4]), - False, ), - # Multiple multidimensional integer indexing (supported by our dispatcher) + # Multiple multidimensional integer indexing ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([[1, 2], [2, 1]], [[0, 0], [0, 0]]), - False, ), ( as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))), (slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]), - False, ), - # Multiple multidimensional indexing with broadcasting, only supported in obj mode + # Multiple multidimensional indexing with broadcasting ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([[1, 2], [2, 1]], [0, 0]), - True, ), - # multiple multidimensional integer indexing mixed with basic indexing, only supported in obj mode + # multiple multidimensional integer indexing mixed with basic indexing ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]), - True, ), ], ) @pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed -def test_AdvancedSubtensor(x, indices, objmode_needed): +def test_AdvancedSubtensor(x, indices): """Test NumPy's advanced indexing in more than one dimension.""" x_pt = x.type() out_pt = x_pt[indices] assert isinstance(out_pt.owner.op, AdvancedSubtensor) - with ( - pytest.warns( - UserWarning, - match="Numba will use object mode to run AdvancedSubtensor's perform method", - ) - if objmode_needed 
- else contextlib.nullcontext() - ): - compare_numba_and_py( - [x_pt], - [out_pt], - [x.data], - numba_mode=numba_mode.including("specialize"), - ) + compare_numba_and_py( + [x_pt], + [out_pt], + [x.data], + # Specialize allows running boolean indexing without falling back to object mode + # Thanks to bool_idx_to_nonzero rewrite + numba_mode=numba_mode.including("specialize"), + ) @pytest.mark.parametrize( @@ -323,7 +301,7 @@ def test_AdvancedIncSubtensor1(x, y, indices): @pytest.mark.parametrize( - "x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode", + "x, y, indices, duplicate_indices, duplicate_indices_require_obj_mode", [ ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -331,7 +309,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): (slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector index False, False, - False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -343,7 +320,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ), # Mixed basic and broadcasted vector idx False, False, - False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -351,7 +327,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): (slice(None, None, 2), [1, 2, 3]), # Mixed basic and vector idx False, False, - False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -359,7 +334,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): (0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values True, False, - True, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -367,21 +341,11 @@ def test_AdvancedIncSubtensor1(x, y, indices): (0, [1, 2, 2, 3]), # Broadcasted vector index with repeated values True, False, - True, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), -np.arange(1 * 4 * 5).reshape(1, 4, 5), (np.array([True, False, False])), # Broadcasted boolean index - False, # It shouldn't matter what we set this to, boolean indices cannot be duplicate - False, - False, - ), - ( - np.arange(3 * 4 * 5).reshape((3, 4, 5)), - -np.arange(1 * 4 * 5).reshape(1, 4, 5), - 
(np.array([True, False, False])), # Broadcasted boolean index - True, # It shouldn't matter what we set this to, boolean indices cannot be duplicate False, False, ), @@ -391,7 +355,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): (np.eye(3).astype(bool)), # Boolean index False, False, - False, ), ( np.arange(3 * 3 * 5).reshape((3, 3, 5)), @@ -402,7 +365,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ), # Boolean index, mixed with basic index False, False, - False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -410,7 +372,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ([1, 2], [2, 3]), # 2 vector indices False, False, - False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -418,7 +379,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): (slice(None), [1, 2], [2, 3]), # 2 vector indices False, False, - False, ), ( np.arange(3 * 4 * 6).reshape((3, 4, 6)), @@ -426,7 +386,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ([1, 2], [2, 3], [4, 5]), # 3 vector indices False, False, - False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -434,15 +393,13 @@ def test_AdvancedIncSubtensor1(x, y, indices): ([1, 2], [2, 3]), # 2 vector indices False, False, - False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), rng.poisson(size=(2, 4)), ([1, 2], slice(None), [3, 4]), # Non-consecutive vector indices False, - True, - True, + False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -453,8 +410,7 @@ def test_AdvancedIncSubtensor1(x, y, indices): [3, 4], ), # Mixed double vector index and basic index False, - True, - True, + False, ), ( np.arange(5), @@ -462,7 +418,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ([[1, 2], [2, 3]]), # matrix index False, False, - False, ), ( np.arange(3 * 5).reshape((3, 5)), @@ -470,23 +425,20 @@ def test_AdvancedIncSubtensor1(x, y, indices): (slice(1, 3), [[1, 2], [2, 3]]), # matrix index, mixed with basic index False, False, - False, ), ( np.arange(3 * 5).reshape((3, 5)), rng.poisson(size=(1, 2, 2)), # Same as before, but Y 
broadcasts (slice(1, 3), [[1, 2], [2, 3]]), False, - True, - True, + False, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), rng.poisson(size=(2, 5)), ([1, 1], [2, 2]), # Repeated indices True, - False, - False, + True, ), ( np.arange(3 * 4 * 5).reshape((3, 4, 5)), @@ -494,7 +446,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): (slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]), # 2 matrix indices False, False, - False, ), ], ) @@ -505,8 +456,7 @@ def test_AdvancedIncSubtensor( y, indices, duplicate_indices, - set_requires_objmode, - inc_requires_objmode, + duplicate_indices_require_obj_mode, inplace, ): # Need rewrite to support certain forms of advanced indexing without object mode @@ -518,17 +468,9 @@ def test_AdvancedIncSubtensor( out_pt = set_subtensor(x_pt[indices], y_pt, inplace=inplace) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) - with ( - pytest.warns( - UserWarning, - match="Numba will use object mode to run AdvancedSetSubtensor's perform method", - ) - if set_requires_objmode - else contextlib.nullcontext() - ): - fn, _ = compare_numba_and_py( - [x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace - ) + fn, _ = compare_numba_and_py( + [x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace + ) if inplace: # Test updates inplace @@ -536,23 +478,58 @@ def test_AdvancedIncSubtensor( fn(x, y + 1) assert not np.all(x == x_orig) - out_pt = inc_subtensor( - x_pt[indices], y_pt, ignore_duplicates=not duplicate_indices, inplace=inplace - ) + out_pt = inc_subtensor(x_pt[indices], y_pt, inplace=inplace) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) - with ( - pytest.warns( - UserWarning, - match="Numba will use object mode to run AdvancedIncSubtensor's perform method", - ) - if inc_requires_objmode - else contextlib.nullcontext() - ): - fn, _ = compare_numba_and_py( - [x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace - ) + + fn, _ = compare_numba_and_py( + [x_pt, y_pt], out_pt, [x, y], numba_mode=mode, 
inplace=inplace + ) if inplace: # Test updates inplace x_orig = x.copy() fn(x, y) assert not np.all(x == x_orig) + + if duplicate_indices: + # If inc_subtensor is called with `ignore_duplicates=True`, and it's not one of the cases supported by Numba + # We have to fall back to obj_mode + out_pt = inc_subtensor( + x_pt[indices], y_pt, inplace=inplace, ignore_duplicates=True + ) + assert isinstance(out_pt.owner.op, AdvancedIncSubtensor) + + with ( + pytest.warns( + UserWarning, + match="Numba will use object mode to run AdvancedIncSubtensor's perform method", + ) + if duplicate_indices_require_obj_mode + else contextlib.nullcontext() + ): + fn, _ = compare_numba_and_py( + [x_pt, y_pt], out_pt, [x, y], numba_mode=mode, inplace=inplace + ) + if inplace: + # Test updates inplace + x_orig = x.copy() + fn(x, y) + assert not np.all(x == x_orig) + + +def test_advanced_indexing_with_newaxis_fallback_obj_mode(): + # This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564 + # After which we can add these parametrizations to the relevant tests above + x = pt.matrix("x") + out = x[None, [0, 1, 2], [0, 1, 2]] + with pytest.warns( + UserWarning, + match=r"Numba will use object mode to run AdvancedSubtensor's perform method", + ): + compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))]) + + out = x[None, [0, 1, 2], [0, 1, 2]].inc(5) + with pytest.warns( + UserWarning, + match=r"Numba will use object mode to run AdvancedIncSubtensor's perform method", + ): + compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))]) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index d8dadf0009..c6918238b4 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -1856,6 +1856,7 @@ def test_static_shape(self): assert x[idx1].type.shape == (10, None) assert x[:, idx1].type.shape == (None, 10) + assert x[None, :, idx1].type.shape == (1, None, 10) assert x[idx2, :5].type.shape == (3, None, 
None) assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5) assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)