@@ -142,17 +142,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
             if rdim.reduction and rdim.size == size:
                 return rdim

+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]

+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         with self.shape_env.ignore_fresh_unbacked_symbols():
             sym = self.shape_env.create_unbacked_symint()
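The effect of the hunk above, as a rough usage sketch (illustrative only: `env` stands for a populated CompileEnvironment instance and `size` for a SymInt tied to an existing non-reduction tile dimension; only the method and attribute names come from this diff):

    # Hypothetical sketch; assumes `env` is a CompileEnvironment with the
    # methods added above and `size` originates from a tile dimension.
    block_idx = env.get_block_id(size)
    rdim = env.allocate_reduction_dimension(size)
    if block_idx is not None:
        # Instead of minting a fresh unbacked SymInt, the new reduction dim
        # reuses the tile's block variable, so the shape environment still
        # sees the two sizes as equal during fake-tensor broadcasting.
        assert rdim.var is env.block_sizes[block_idx].var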
@@ -203,6 +233,91 @@ def cached_create_unbacked_symint(
             self._symint_cache[key] = result
         return result

+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return the full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        block_id = self.get_tile_index_tensor_block_id(indexer_tensor)
+        if block_id is None and base_dim_size is not None:
+            block_id = self.get_block_id(base_dim_size)
+        if block_id is None and non_broadcast_dims:
+            block_id = self.get_block_id(non_broadcast_dims[0])
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        if non_broadcast_dims:
+            return [non_broadcast_dims[0]]
+        return [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` with the single non-broadcast
+        dimension replaced by the canonical block symbol, together with the
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer, or if it introduces more than one non-broadcast dimension,
+        the original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
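A self-contained toy of the provenance mechanism these methods implement (plain PyTorch, not Helion's CompileEnvironment; apart from the `_tile_index_block_id` attribute and the general shape of `new_index_result`, every name and value below is made up for illustration):

    import torch

    BLOCK_VARS = {3: 64}  # pretend block id 3 has block size 64

    def register(t: torch.Tensor, block_id: int) -> torch.Tensor:
        t._tile_index_block_id = block_id  # same attribute the diff uses
        return t

    def new_index_result(t: torch.Tensor, output_shape: list[int]) -> torch.Tensor:
        # Mirror of resolve_tile_index_shape: rewrite the single non-broadcast
        # dim to the canonical block size and carry the provenance tag forward.
        block_id = getattr(t, "_tile_index_block_id", None)
        resolved = list(output_shape)
        non_broadcast = [i for i, s in enumerate(resolved) if s != 1]
        if block_id is not None and len(non_broadcast) == 1:
            resolved[non_broadcast[0]] = BLOCK_VARS[block_id]
        out = t.new_empty(resolved)
        if block_id is not None:
            out._tile_index_block_id = block_id
        return out

    idx = register(torch.arange(64), block_id=3)
    out = new_index_result(idx, [1, 64])
    assert out.shape == (1, 64) and out._tile_index_block_id == 3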
@@ -283,6 +398,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -535,3 +654,35 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:

 def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment",
+) -> list[int | torch.SymInt]:
+    """
+    Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    Args:
+        shapes: List of shapes from each tensor indexer
+        env: CompileEnvironment for size_hint and known_equal checks
+
+    Returns:
+        Broadcast shape as list of dimensions
+    """
+    if not shapes:
+        return []
+
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+    broadcast_shape: list[int | torch.SymInt] = []
+
+    for dims_at_pos in zip(*padded, strict=True):
+        chosen: int | torch.SymInt | None = None
+        for d in dims_at_pos:
+            if env.size_hint(d) != 1:
+                if chosen is None or env.known_equal(chosen, d):
+                    chosen = d
+        broadcast_shape.append(chosen if chosen is not None else 1)
+
+    return broadcast_shape
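For concreteness, a worked example of the right-aligned broadcast rule (self-contained apart from assuming `compute_broadcast_shape_for_tensor_indexers` from this diff is in scope; the stub's `size_hint`/`known_equal` are stand-ins for the real CompileEnvironment methods and handle plain ints only):

    class _StubEnv:
        def size_hint(self, d: int) -> int:
            return d

        def known_equal(self, a: int, b: int) -> bool:
            return a == b

    # Shapes are right-aligned and padded with 1s; at each position, size-1
    # dims defer to the first non-broadcast size seen there:
    #   [64]     -> [ 1, 64]
    #   [32, 1]  -> [32,  1]
    #   result:     [32, 64]
    assert compute_broadcast_shape_for_tensor_indexers([[64], [32, 1]], _StubEnv()) == [32, 64]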