@@ -227,6 +227,9 @@ def valid_block_size(
227227 for i , k in enumerate (subscript ):
228228 if k is None :
229229 continue
230+ if k is Ellipsis :
231+ # Ellipsis is not supported in tensor descriptor mode
232+ return False
230233 size , stride = size_stride .popleft ()
231234 if isinstance (k , slice ):
232235 # Slices with steps are not supported in tensor descriptor mode
@@ -447,6 +450,14 @@ def codegen_store(
447450 )
448451
449452
453+ def _calculate_ellipsis_dims (
454+ index : list [object ], current_index : int , total_dims : int
455+ ) -> int :
456+ """Calculate how many dimensions an ellipsis should expand to."""
457+ remaining_indices = len (index ) - current_index - 1
458+ return total_dims - current_index - remaining_indices
459+
460+
450461class SubscriptIndexing (NamedTuple ):
451462 index_expr : ast .AST
452463 mask_expr : ast .AST
@@ -465,9 +476,18 @@ def compute_shape(
465476 input_size = collections .deque (tensor .size ())
466477 output_size = []
467478 env = CompileEnvironment .current ()
468- for k in index :
479+ for i , k in enumerate ( index ) :
469480 if k is None :
470481 output_size .append (1 )
482+ elif k is Ellipsis :
483+ ellipsis_dims = _calculate_ellipsis_dims (index , i , len (tensor .size ()))
484+ for _ in range (ellipsis_dims ):
485+ size = input_size .popleft ()
486+ if size != 1 :
487+ rdim = env .allocate_reduction_dimension (size )
488+ output_size .append (rdim .var )
489+ else :
490+ output_size .append (1 )
471491 elif isinstance (k , int ):
472492 input_size .popleft ()
473493 elif isinstance (k , torch .SymInt ):
@@ -517,6 +537,21 @@ def create(
517537 for n , k in enumerate (index ):
518538 if k is None :
519539 output_idx += 1
540+ elif k is Ellipsis :
541+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_value .ndim )
542+ for _ in range (ellipsis_dims ):
543+ expand = tile_strategy .expand_str (output_size , output_idx )
544+ size = fake_value .size (len (index_values ))
545+ if size != 1 :
546+ rdim = env .allocate_reduction_dimension (size )
547+ block_idx = rdim .block_id
548+ index_var = state .codegen .index_var (block_idx )
549+ index_values .append (f"({ index_var } ){ expand } " )
550+ if mask := state .codegen .mask_var (block_idx ):
551+ mask_values .setdefault (f"({ mask } ){ expand } " )
552+ else :
553+ index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
554+ output_idx += 1
520555 elif isinstance (k , int ):
521556 index_values .append (repr (k ))
522557 elif isinstance (k , torch .SymInt ):
@@ -729,8 +764,16 @@ def is_supported(
729764 # TODO(jansel): support block_ptr with extra_mask
730765 return False
731766 input_sizes = collections .deque (fake_tensor .size ())
732- for k in index :
733- input_size = 1 if k is None else input_sizes .popleft ()
767+ for n , k in enumerate (index ):
768+ if k is None :
769+ input_size = 1
770+ elif k is Ellipsis :
771+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_tensor .ndim )
772+ for _ in range (ellipsis_dims ):
773+ input_sizes .popleft ()
774+ continue
775+ else :
776+ input_size = input_sizes .popleft ()
734777 if isinstance (k , torch .SymInt ):
735778 symbol = k ._sympy_ ()
736779 origin = None
@@ -780,9 +823,21 @@ def create(
780823 fake_value ,
781824 reshaped_size = SubscriptIndexing .compute_shape (fake_value , index ),
782825 )
783- for k in index :
826+ for n , k in enumerate ( index ) :
784827 if k is None :
785828 pass # handled by reshaped_size
829+ elif k is Ellipsis :
830+ ellipsis_dims = _calculate_ellipsis_dims (index , n , fake_value .ndim )
831+ env = CompileEnvironment .current ()
832+ for _ in range (ellipsis_dims ):
833+ size = fake_value .size (len (res .offsets ))
834+ if size != 1 :
835+ rdim = env .allocate_reduction_dimension (size )
836+ res .offsets .append (state .codegen .offset_var (rdim .block_id ))
837+ res .block_shape .append (rdim .var )
838+ else :
839+ res .offsets .append ("0" )
840+ res .block_shape .append (1 )
786841 elif isinstance (k , int ):
787842 res .offsets .append (repr (k ))
788843 res .block_shape .append (1 )
0 commit comments