@@ -490,11 +490,20 @@ def compute_shape(
                     output_size.append(rdim.var)
                 else:
                     output_size.append(1)
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
-                input_size.popleft()
-                output_size.extend(k.size())
+            elif isinstance(k, torch.Tensor):
+                # Advanced tensor indexer: consume one base dim and splice in the indexer's shape.
+                base_dim = input_size.popleft()
+                dims = list(k.size())
+                non_broadcast_dims = [d for d in dims if env.size_hint(d) != 1]
+
+                # A multi-dim indexer contributes its own shape.
+                if len(non_broadcast_dims) > 1:
+                    output_size.extend(dims)
+                    continue
+
+                # A 1D or broadcast-only indexer: use the origin-tracking helper.
+                size = env.get_indexer_output_size(k, base_dim)
+                output_size.append(size)
             else:
                 raise exc.InvalidIndexingType(k)
         assert len(input_size) == 0, "invalid subscript"
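
To make the new branch concrete, here is a minimal standalone sketch of its shape logic. The helper name, the identity `size_hint`, and the final `append` (which stands in for `env.get_indexer_output_size`) are illustrative assumptions, not part of the patch:

```python
from collections import deque

def sketch_tensor_indexer_shape(indexer_shape, size_hint, input_size, output_size):
    # Consume exactly one base dimension, as the patch does.
    input_size.popleft()
    non_broadcast = [d for d in indexer_shape if size_hint(d) != 1]
    if len(non_broadcast) > 1:
        # True multi-dim indexer: splice in the indexer's full shape.
        output_size.extend(indexer_shape)
    else:
        # 1D or broadcast-only: one output dim. The real code delegates to
        # env.get_indexer_output_size(k, base_dim); we approximate with the
        # lone non-broadcast extent (or 1).
        output_size.append(non_broadcast[0] if non_broadcast else 1)

out: list = []
sketch_tensor_indexer_shape([32, 16], lambda d: d, deque([64]), out)
assert out == [32, 16]  # a [32, 16] indexer yields two output dims

out = []
sketch_tensor_indexer_shape([32, 1], lambda d: d, deque([64]), out)
assert out == [32]  # a broadcast-only [32, 1] indexer collapses to one dim
```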
@@ -514,6 +523,7 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -573,16 +583,16 @@ def create(
                 else:
                     index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
-                index_var = state.codegen.lift(ast_index[n], prefix="index").id
-                index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
+            elif isinstance(k, torch.Tensor) and (
+                k.ndim == 1
+                or sum(CompileEnvironment.current().size_hint(d) != 1 for d in k.size())
+                <= 1
+            ):
+                # 1D or broadcast-only indexer
+                SubscriptIndexing._handle_broadcast_indexer(
+                    k, n, output_size, output_idx, index,
+                    state, tile_strategy, index_values, mask_values, env
+                )
                 output_idx += 1
             elif (
                 isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
@@ -601,6 +611,24 @@ def create(
                         mask_values.setdefault(
                             f"({mask}){tile_strategy.expand_str(output_size, n)}"
                         )
+            elif isinstance(k, torch.Tensor) and k.ndim > 1 and len(index) > 1:
+                # Multi-dimensional tensor indexer combined with other indices.
+                non_broadcast_dims = [dim for dim in k.size() if env.size_hint(dim) != 1]
+
+                if len(non_broadcast_dims) <= 1:
+                    # Broadcast-only multi-dim indexer: treat it as a single dimension.
+                    SubscriptIndexing._handle_broadcast_indexer(
+                        k, n, output_size, output_idx, index,
+                        state, tile_strategy, index_values, mask_values, env
+                    )
+                    output_idx += 1
+                else:
+                    # True multi-dim indexer: handle all dims at once.
+                    SubscriptIndexing._handle_multidim_indexer(
+                        k, n, output_size, output_idx, index,
+                        state, tile_strategy, index_values, mask_values, env
+                    )
+                    output_idx += k.ndim
             else:
                 raise exc.InvalidIndexingType(type(k))
         assert len(output_size) == output_idx
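
The dispatch above hinges on counting non-broadcast extents via `env.size_hint`. A small sketch of that classification rule (the function name and the concrete `hint` are made up for illustration):

```python
def classify_indexer(shape, size_hint):
    # At most one extent with size_hint != 1 means the indexer is 1D-like
    # and routed to _handle_broadcast_indexer; otherwise it is a true
    # multi-dim indexer handled by _handle_multidim_indexer.
    non_broadcast = [d for d in shape if size_hint(d) != 1]
    return "broadcast" if len(non_broadcast) <= 1 else "multidim"

hint = lambda d: d  # assume sizes are already concrete
assert classify_indexer([32, 1], hint) == "broadcast"
assert classify_indexer([1, 1], hint) == "broadcast"
assert classify_indexer([32, 16], hint) == "multidim"
```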
@@ -618,10 +646,96 @@ def create(
         if extra_mask is not None:
             mask_values.setdefault("{_extra_mask}")
             kwargs["_extra_mask"] = extra_mask
+
         return SubscriptIndexing(
             expr_from_string("+".join(index_expr)),
             expr_from_string("&".join(mask_values) or "None", **kwargs),
         )
+
+    @staticmethod
+    def _handle_broadcast_indexer(
+        k: torch.Tensor, n: int, output_size: list, output_idx: int, index: list,
+        state: CodegenState, tile_strategy: Any, index_values: list,
+        mask_values: dict, env: CompileEnvironment
+    ) -> None:
+        """Handle a broadcast-only tensor indexer (all dims but one are size 1)."""
+        expand = tile_strategy.expand_str(output_size, output_idx)
+
+        # Try to get the tile.index tensor's origin block_id.
+        tile_origin_block_id = env.get_tile_index_tensor_block_id(k)
+
+        if tile_origin_block_id is not None:
+            # Use the tile_index tensor's block id directly.
+            index_var = state.codegen.index_var(tile_origin_block_id)
+            index_values.append(f"({index_var}){expand}")
+            if (mask := state.codegen.mask_var(tile_origin_block_id)) is not None:
+                mask_values.setdefault(f"({mask}){expand}")
+        else:
+            # Lift the AST to preserve expressions like tile.index + 1.
+            ast_index = state.ast_args[1]
+            assert isinstance(ast_index, (list, tuple))
+            assert len(ast_index) == len(index)
+            lifted = state.codegen.lift(ast_index[n], prefix="index").id
+            index_values.append(f"({lifted}){expand}")
+            # Even if we lift, we still know the block id for this axis from output_size.
+            output_block_id = env.get_block_id(output_size[output_idx])
+            if output_block_id is not None and (mask := state.codegen.mask_var(output_block_id)) is not None:
+                mask_values.setdefault(f"({mask}){expand}")
+
+    @staticmethod
+    def _handle_multidim_indexer(
+        k: torch.Tensor, n: int, output_size: list, output_idx: int, index: list,
+        state: CodegenState, tile_strategy: Any, index_values: list,
+        mask_values: dict, env: CompileEnvironment
+    ) -> None:
+        """Handle a multi-dimensional tensor indexer."""
+        # Lift the indexer once.
+        ast_index = state.ast_args[1]
+        assert isinstance(ast_index, (list, tuple))
+        assert len(ast_index) == len(index)
+        index_var = state.codegen.lift(ast_index[n], prefix="index").id
+
+        # Build a merged broadcast bracket covering all dims,
+        # starting with the first dimension.
+        base = tile_strategy.expand_str(output_size, output_idx)
+        if base == "":
+            tokens = []
+        else:
+            assert base.startswith("[") and base.endswith("]"), base
+            tokens = base[1:-1].split(", ") if len(base) > 2 else []
+
+        # Merge with the remaining dimensions.
+        for d in range(1, k.ndim):
+            s = tile_strategy.expand_str(output_size, output_idx + d)
+            if s == "":
+                s_tokens = [":"]
+            else:
+                assert s.startswith("[") and s.endswith("]"), s
+                s_tokens = s[1:-1].split(", ") if len(s) > 2 else []
+
+            # Merge tokens: use ':' if either side has ':', else 'None'.
+            if not tokens:
+                tokens = s_tokens
+            elif s_tokens:
+                tokens = [
+                    ":" if (a == ":" or b == ":") else "None"
+                    for a, b in zip(tokens, s_tokens, strict=True)
+                ]
+
+        if tokens == [":"] or not tokens:
+            bracket = ""
+        else:
+            bracket = f"[{', '.join(tokens)}]"
+
+        index_values.append(f"({index_var}){bracket}")
+
+        # Add mask contributions for each output dim.
+        for d in range(k.ndim):
+            if (block_idx := env.get_block_id(output_size[output_idx + d])) is not None:
+                if mask := state.codegen.mask_var(block_idx):
+                    mask_values.setdefault(
+                        f"({mask}){tile_strategy.expand_str(output_size, output_idx + d)}"
+                    )
 
 
 @dataclasses.dataclass
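
The trickiest piece above is the bracket merge in `_handle_multidim_indexer`. The following standalone re-implementation (the name `merge_expand_brackets` is ours; it assumes, as the asserts in the patch do, that `expand_str` returns either `""` or a bracket like `"[:, None]"`) shows the rule: an output axis stays `:` if any of the indexer's dims is live there, else `None`:

```python
def merge_expand_brackets(brackets: list[str]) -> str:
    tokens: list[str] = []
    for i, b in enumerate(brackets):
        if b == "":
            cur = [":"] if i > 0 else []  # mirrors the patch's base/loop asymmetry
        else:
            assert b.startswith("[") and b.endswith("]"), b
            cur = b[1:-1].split(", ") if len(b) > 2 else []
        if not tokens:
            tokens = cur
        elif cur:
            tokens = [
                ":" if (a == ":" or c == ":") else "None"
                for a, c in zip(tokens, cur, strict=True)
            ]
    return "" if tokens == [":"] or not tokens else f"[{', '.join(tokens)}]"

# Dims of a 2D indexer expanding as '[:, None]' and '[None, :]' merge into
# a bracket that keeps both output axes live:
assert merge_expand_brackets(["[:, None]", "[None, :]"]) == "[:, :]"
# When no expansion is needed anywhere, the bracket disappears entirely:
assert merge_expand_brackets(["", ""]) == ""
```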