@@ -24,6 +24,7 @@
     ScalarFromTensor,
     TensorFromScalar,
     alloc,
+    arange,
     cast,
     concatenate,
     expand_dims,
@@ -34,9 +35,10 @@
     switch,
 )
 from pytensor.tensor.basic import constant as tensor_constant
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, _squeeze_left
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.math import (
     add,
     and_,
@@ -58,6 +60,7 @@
 )
 from pytensor.tensor.shape import (
     shape_padleft,
+    shape_padright,
     shape_tuple,
 )
 from pytensor.tensor.sharedvar import TensorSharedVariable
@@ -1580,6 +1583,9 @@ def local_blockwise_of_subtensor(fgraph, node):
15801583 """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
15811584
15821585 Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
1586+
1587+ TODO: Handle batched indices like we do with blockwise of inc_subtensor
1588+ TODO: Extend to AdvanceSubtensor
15831589 """
15841590 if not isinstance (node .op .core_op , Subtensor ):
15851591 return
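
For orientation, a small NumPy sketch (not part of the patch) of the equivalence this rewrite exploits: slicing every batch entry of x with the same scalar bounds a/b is the same as slicing once along the core dimension with a leading full slice. The array and bounds below are illustrative only.

import numpy as np

x = np.arange(12).reshape(3, 4)  # one batch dimension, core dimension of size 4
a, b = 1, 3

looped = np.stack([x[i, a:b] for i in range(x.shape[0])])  # what Blockwise(Subtensor) computes
direct = x[:, a:b]                                         # the rewritten form
assert np.array_equal(looped, direct)
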
@@ -1600,64 +1606,151 @@ def local_blockwise_of_subtensor(fgraph, node):
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
-def local_blockwise_advanced_inc_subtensor(fgraph, node):
-    """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
-    if not isinstance(node.op.core_op, AdvancedIncSubtensor):
-        return None
+def local_blockwise_inc_subtensor(fgraph, node):
+    """Rewrite Blockwise of (Advanced)IncSubtensor.

-    x, y, *idxs = node.inputs
+    Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
+    is that we often have batch dimensions from allocs of shapes/reshapes that can be removed by rewrites,

-    # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
-    if any(
-        (
-            isinstance(idx, SliceType | NoneTypeT)
-            or (idx.type.dtype == "bool" and idx.type.ndim > 0)
-        )
-        for idx in idxs
-    ):
+    such as x[:vectorized(w.shape[0])].set(y), which will later be rewritten as x[:w.shape[1]].set(y)
+    and can then be safely rewritten without Blockwise.
+    """
+    core_op = node.op.core_op
+    if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
         return None

-    op: Blockwise = node.op  # type: ignore
-    batch_ndim = op.batch_ndim(node)
-
-    new_idxs = []
-    for idx in idxs:
-        if all(idx.type.broadcastable[:batch_ndim]):
-            new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
-        else:
-            # Rewrite does not apply
+    x, y, *idxs = node.inputs
+    [out] = node.outputs
+    if isinstance(node.op.core_op, AdvancedIncSubtensor):
+        if any(
+            (
+                # Blockwise requires all inputs to be tensors so it is not possible
+                # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
+                # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
+                # are separated by basic indices
+                isinstance(idx, SliceType | NoneTypeT)
+                # Also get out if we have boolean indices as they cross dimension boundaries
+                # / can't be safely broadcasted depending on their runtime content
+                or (idx.type.dtype == "bool")
+            )
+            for idx in idxs
+        ):
             return None

-    x_batch_bcast = x.type.broadcastable[:batch_ndim]
-    y_batch_bcast = y.type.broadcastable[:batch_ndim]
-    if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)):
-        # Need to broadcast batch x dims
-        batch_shape = tuple(
-            x_dim if (not xb or yb) else y_dim
-            for xb, x_dim, yb, y_dim in zip(
-                x_batch_bcast,
+    batch_ndim = node.op.batch_ndim(node)
+    idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
+    max_idx_core_ndim = max(idxs_core_ndim, default=0)
+
+    # Step 1. Broadcast buffer to batch_shape
+    if x.type.broadcastable != out.type.broadcastable:
+        batch_shape = [1] * batch_ndim
+        for inp in node.inputs:
+            for i, (broadcastable, batch_dim) in enumerate(
+                zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
+            ):
+                if broadcastable:
+                    # This dimension is broadcastable, it doesn't provide shape information
+                    continue
+                if batch_shape[i] != 1:
+                    # We already found a source of shape for this batch dimension
+                    continue
+                batch_shape[i] = batch_dim
+        x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
+        assert x.type.broadcastable == out.type.broadcastable
+
+    # Step 2. Massage indices so they respect blockwise semantics
+    if isinstance(core_op, IncSubtensor):
+        # For basic IncSubtensor there are two cases:
+        # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
+        # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
+        #    in case we can end up with a basic IncSubtensor again
+        core_idxs = []
+        counter = 0
+        for idx in core_op.idx_list:
+            if isinstance(idx, slice):
+                # Squeeze away dummy dimensions so we can convert to slice
+                new_entries = [None, None, None]
+                for i, entry in enumerate((idx.start, idx.stop, idx.step)):
+                    if entry is None:
+                        continue
+                    else:
+                        new_entries[i] = new_entry = idxs[counter].squeeze()
+                        counter += 1
+                        if new_entry.ndim > 0:
+                            # If the slice entry has dimensions after the squeeze we can't convert it to a slice
+                            # We could try to convert to equivalent integer indices, but nothing guarantees
+                            # that the slice is "square".
+                            return None
+                core_idxs.append(slice(*new_entries))
+            else:
+                core_idxs.append(_squeeze_left(idxs[counter]))
+                counter += 1
+    else:
+        # For AdvancedIncSubtensor we have tensor integer indices.
+        # We need to expand batch indices on the right, so they don't interact with core index dimensions.
+        # We still squeeze on the left in case that allows us to use simpler indices.
+        core_idxs = [
+            _squeeze_left(
+                shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
+                stop_at_dim=batch_ndim,
+            )
+            for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
+        ]
+
+    # Step 3. Create new indices for the new batch dimension of x
+    if not all(
+        all(idx.type.broadcastable[:batch_ndim])
+        for idx in idxs
+        if not isinstance(idx, slice)
+    ):
+        # If the indices have batch dimensions, they will interact with the new dimensions of x.
+        # We build vectorized indexing with new arange indices that do not interact with core indices or each other
+        # (i.e., they broadcast)
+
+        # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
+        # we don't want to create a mix of slice(None) and arange() indices for the new batch dimension,
+        # even if not all batch dimensions have corresponding batch indices.
+        batch_slices = [
+            shape_padright(arange(x_batch_shape, dtype="int64"), n)
+            for (x_batch_shape, n) in zip(
                 tuple(x.shape)[:batch_ndim],
-                y_batch_bcast,
-                tuple(y.shape)[:batch_ndim],
-                strict=True,
+                reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
             )
-        )
-        core_shape = tuple(x.shape)[batch_ndim:]
-        x = alloc(x, *batch_shape, *core_shape)
-
-    new_idxs = [slice(None)] * batch_ndim + new_idxs
-    x_view = x[tuple(new_idxs)]
-
-    # We need to introduce any implicit expand_dims on core dimension of y
-    y_core_ndim = y.type.ndim - batch_ndim
-    if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
-        missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
-        y = expand_dims(y, missing_axes)
-
-    symbolic_idxs = x_view.owner.inputs[1:]
-    new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
-    copy_stack_trace(node.outputs, new_out)
-    return new_out
+        ]
+    else:
+        # In the case we don't have batch indices,
+        # we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
+        batch_slices = [slice(None)] * batch_ndim
+
+    new_idxs = (*batch_slices, *core_idxs)
+    x_view = x[new_idxs]
+
+    # Step 4. Introduce any implicit expand_dims on core dimension of y
+    missing_y_core_ndim = x_view.type.ndim - y.type.ndim
+    implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
+    y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
+
+    if isinstance(core_op, IncSubtensor):
+        # Check if we can still use a basic IncSubtensor
+        if isinstance(x_view.owner.op, Subtensor):
+            new_props = core_op._props_dict()
+            new_props["idx_list"] = x_view.owner.op.idx_list
+            new_core_op = type(core_op)(**new_props)
+            symbolic_idxs = x_view.owner.inputs[1:]
+            new_out = new_core_op(x, y, *symbolic_idxs)
+        else:
+            # We need to use AdvancedSet/IncSubtensor
+            if core_op.set_instead_of_inc:
+                new_out = x[new_idxs].set(y)
+            else:
+                new_out = x[new_idxs].inc(y)
+    else:
+        # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
+        symbolic_idxs = x_view.owner.inputs[1:]
+        new_out = core_op(x, y, *symbolic_idxs)
+
+    copy_stack_trace(out, new_out)
+    return [new_out]
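
For intuition, a NumPy-level sketch (not part of the patch) of the batched-index case handled in Step 3: a per-batch set_subtensor with a batched integer index is reproduced by indexing the buffer with an arange over the batch dimension, padded on the right so it broadcasts against the core index instead of interacting with it (this also sidesteps numpy's transposition rule for non-consecutive advanced indices). All values below are illustrative only.

import numpy as np

x = np.zeros((3, 5))                      # batch_ndim = 1, core buffer of length 5
idx = np.array([[0, 2], [1, 3], [4, 0]])  # batched integer index with one core dimension
y = np.arange(1.0, 7.0).reshape(3, 2)

# Reference: loop over the batch dimension, i.e. what Blockwise(AdvancedIncSubtensor) computes
looped = x.copy()
for i in range(3):
    looped[i, idx[i]] = y[i]

# Rewritten form: arange batch index padded on the right, analogous to x[new_idxs].set(y)
batch_idx = np.arange(3)[:, None]
rewritten = x.copy()
rewritten[batch_idx, idx] = y

assert np.array_equal(looped, rewritten)
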


 @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])