@@ -227,6 +227,9 @@ def valid_block_size(
227227 for i , k in enumerate (subscript ):
228228 if k is None :
229229 continue
230+ if k is Ellipsis :
231+ # Ellipsis is not supported in tensor descriptor mode
232+ return False
230233 size , stride = size_stride .popleft ()
231234 if isinstance (k , slice ):
232235 # Slices with steps are not supported in tensor descriptor mode
@@ -465,9 +468,20 @@ def compute_shape(
465468 input_size = collections .deque (tensor .size ())
466469 output_size = []
467470 env = CompileEnvironment .current ()
468- for k in index :
471+ for i , k in enumerate ( index ) :
469472 if k is None :
470473 output_size .append (1 )
474+ elif k is Ellipsis :
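475+ # Ellipsis takes every input dim not reserved for the index entries after it (NOTE(review): a later None is counted as reserving one, though None pops no input dim - confirm)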
476+ remaining_indices = len (index ) - i - 1
477+ ellipsis_dims = len (input_size ) - remaining_indices
478+ for _ in range (ellipsis_dims ):
479+ size = input_size .popleft ()
480+ if size != 1 :
481+ rdim = env .allocate_reduction_dimension (size )
482+ output_size .append (rdim .var )
483+ else :
484+ output_size .append (1 )
471485 elif isinstance (k , int ):
472486 input_size .popleft ()
473487 elif isinstance (k , torch .SymInt ):
@@ -517,6 +531,23 @@ def create(
517531 for n , k in enumerate (index ):
518532 if k is None :
519533 output_idx += 1
534+ elif k is Ellipsis :
535+ # Ellipsis covers the dims left after the index_values emitted so far, minus one per index entry that follows it
536+ remaining_indices = len (index ) - n - 1
537+ ellipsis_dims = fake_value .ndim - len (index_values ) - remaining_indices
538+ for dim_offset in range (ellipsis_dims ):
539+ expand = tile_strategy .expand_str (output_size , output_idx )
540+ size = fake_value .size (len (index_values ))
541+ if size != 1 :
542+ rdim = env .allocate_reduction_dimension (size )
543+ block_idx = rdim .block_id
544+ index_var = state .codegen .index_var (block_idx )
545+ index_values .append (f"({ index_var } ){ expand } " )
546+ if mask := state .codegen .mask_var (block_idx ):
547+ mask_values .setdefault (f"({ mask } ){ expand } " )
548+ else :
549+ index_values .append (f"tl.zeros([1], { dtype } ){ expand } " )
550+ output_idx += 1
520551 elif isinstance (k , int ):
521552 index_values .append (repr (k ))
522553 elif isinstance (k , torch .SymInt ):
@@ -729,8 +760,18 @@ def is_supported(
729760 # TODO(jansel): support block_ptr with extra_mask
730761 return False
731762 input_sizes = collections .deque (fake_tensor .size ())
732- for k in index :
733- input_size = 1 if k is None else input_sizes .popleft ()
763+ for n , k in enumerate (index ):
764+ if k is None :
765+ input_size = 1
766+ elif k is Ellipsis :
767+ # Ellipsis: drop the input sizes it spans, leaving one per index entry that follows it
768+ remaining_indices = len (index ) - n - 1
769+ ellipsis_dims = len (input_sizes ) - remaining_indices
770+ for _ in range (ellipsis_dims ):
771+ input_sizes .popleft ()
772+ continue
773+ else :
774+ input_size = input_sizes .popleft ()
734775 if isinstance (k , torch .SymInt ):
735776 symbol = k ._sympy_ ()
736777 origin = None
@@ -780,9 +821,23 @@ def create(
780821 fake_value ,
781822 reshaped_size = SubscriptIndexing .compute_shape (fake_value , index ),
782823 )
783- for k in index :
824+ for n , k in enumerate ( index ) :
784825 if k is None :
785826 pass # handled by reshaped_size
827+ elif k is Ellipsis :
828+ # Ellipsis covers the dims left after the offsets emitted so far, minus one per index entry that follows it
829+ remaining_indices = len (index ) - n - 1
830+ ellipsis_dims = fake_value .ndim - len (res .offsets ) - remaining_indices
831+ for _ in range (ellipsis_dims ):
832+ size = fake_value .size (len (res .offsets ))
833+ if size != 1 :
834+ env = CompileEnvironment .current ()
835+ rdim = env .allocate_reduction_dimension (size )
836+ res .offsets .append (state .codegen .offset_var (rdim .block_id ))
837+ res .block_shape .append (rdim .var )
838+ else :
839+ res .offsets .append ("0" )
840+ res .block_shape .append (1 )
786841 elif isinstance (k , int ):
787842 res .offsets .append (repr (k ))
788843 res .block_shape .append (1 )
0 commit comments