@@ -142,6 +142,30 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles
+        # (e.g., via slicing) with sizes coming directly from tile block vars, we
+        # want them to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and avoids
+        # spurious "size mismatch" issues during fake-tensor broadcasting and
+        # arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            size_str = str(size)
+            for block_info in self.block_sizes:
+                if not block_info.reduction and str(block_info.var) == size_str:
+                    # Create the reduction dimension with the same var to preserve
+                    # symbolic equality and ensure all later users see identical
+                    # symbols (rather than equal-but-distinct SymInts).
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(
+                            reduction_loop=len([b for b in self.block_sizes if b.reduction])
+                        ),
+                    )
+                    self.block_sizes[rdim_idx].var = block_info.var
+                    return self.block_sizes[rdim_idx]
+
         # Allocate a new reduction dimension
         rdim_idx = self.allocate_block_size(
             size,
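
To make the symbol-sharing concrete, here is a minimal usage sketch (the `env` variable and the pre-existing tile dimension are assumptions for illustration, not code from the patch). Passing a size that prints the same as an existing tile block var returns a reduction dimension that reuses that exact SymInt, so the shape environment can prove the two sizes equal:

# Sketch: `env` is a CompileEnvironment with one non-reduction tile dimension.
tile_dim = env.block_sizes[0]                      # e.g. var prints as "s0"
rdim = env.allocate_reduction_dimension(tile_dim.var)

# The reduction dim reuses the tile's SymInt object, so equality holds by
# identity rather than via a shape-env guard, and broadcasting a
# [tile_dim.var] tensor against a [rdim.var] tensor cannot produce a
# spurious size mismatch during type propagation.
assert rdim.var is tile_dim.var
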
@@ -203,6 +227,91 @@ def cached_create_unbacked_symint(
             self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer: return the full shape unchanged.
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        block_id = self.get_tile_index_tensor_block_id(indexer_tensor)
+        if block_id is None and base_dim_size is not None:
+            block_id = self.get_block_id(base_dim_size)
+        if block_id is None and non_broadcast_dims:
+            block_id = self.get_block_id(non_broadcast_dims[0])
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        if non_broadcast_dims:
+            return [non_broadcast_dims[0]]
+        return [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers, or ``None`` if unneeded."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` in which the single non-broadcast
+        dimension is replaced by the canonical block symbol, together with
+        the block_id to register on the new tensor. If the tensor is not a
+        tile indexer, or introduces more than one non-broadcast dimension,
+        the original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
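
Taken together, these helpers thread tile-index provenance through indexing and view ops. The following sketch shows the intended round trip (the `env` instance, the concrete sizes, and block id `0` are illustrative assumptions, not code from this patch):

# A tensor produced by tile.index gets tagged with its block id.
idx = torch.empty(64, dtype=torch.int32)
env.register_tile_index_tensor_block_id(idx, 0)

# resolve_tile_index_shape swaps the single non-broadcast dim (64)
# for the canonical block SymInt and reports the block id to carry over.
shape, block_id = env.resolve_tile_index_shape(idx, [1, 64])
# shape == [1, env.block_sizes[0].var], block_id == 0

# new_index_result bundles both steps: it allocates the result with the
# symbolic shape and re-tags it, so provenance survives chained ops.
out = env.new_index_result(idx, [1, 64])
assert env.get_tile_index_tensor_block_id(out) == 0
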
@@ -283,6 +392,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
                 self.fake_mode, tensor, shape_env=self.shape_env, source=source
             )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
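
This hunk is needed because fake-ification allocates a fresh tensor object, so a `_tile_index_block_id` set on the eager tensor would otherwise be dropped at that boundary. A sketch of the property it preserves (`make_tile_index_tensor` and `origin` are hypothetical placeholders, not names from the patch):

idx = make_tile_index_tensor()                 # hypothetical helper
env.register_tile_index_tensor_block_id(idx, 0)

fake = env.to_fake(idx, origin)                # creates a new FakeTensor
# Without the copy in the hunk above, provenance would be lost here:
assert env.get_tile_index_tensor_block_id(fake) == 0
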
@@ -357,9 +470,9 @@ def current() -> CompileEnvironment:
     @staticmethod
     def has_current() -> bool:
         try:
             CompileEnvironment.current()
             return True
         except NoCurrentEnvironment:
             return False
 
     def get_block_id(self, size: int | torch.SymInt | sympy.Expr) -> int | None:
@@ -535,3 +648,35 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
 
 def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment",
+) -> list[int | torch.SymInt]:
+    """
+    Compute the broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    Args:
+        shapes: List of shapes, one per tensor indexer
+        env: CompileEnvironment used for size_hint and known_equal checks
+
+    Returns:
+        The broadcast shape as a list of dimensions
+    """
+    if not shapes:
+        return []
+
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+    broadcast_shape: list[int | torch.SymInt] = []
+
+    for dims_at_pos in zip(*padded, strict=True):
+        chosen: int | torch.SymInt | None = None
+        for d in dims_at_pos:
+            if env.size_hint(d) != 1:
+                if chosen is None or env.known_equal(chosen, d):
+                    chosen = d
+        broadcast_shape.append(chosen if chosen is not None else 1)
+
+    return broadcast_shape
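
A worked example of the right-aligned broadcasting (the sizes are concrete for illustration; `env` is an assumed CompileEnvironment whose `size_hint` returns the integer values shown):

shapes = [[3, 1], [5, 1, 1], [4]]
# Right-aligned padding to ndim 3:
#   [1, 3, 1]
#   [5, 1, 1]
#   [1, 1, 4]
# At each position, the first dim with size_hint != 1 is chosen (1s broadcast):
compute_broadcast_shape_for_tensor_indexers(shapes, env)  # -> [5, 3, 4]

Note the deliberate asymmetry with torch.broadcast_shapes: when a later non-1 dim is not provably equal (via `known_equal`) to the one already chosen, it is skipped rather than raising, presumably leaving mismatch handling to callers that validate indexer shapes.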