@@ -517,7 +517,6 @@ def compute_shape(
517517 size = input_size .popleft ()
518518 # Handle slices with steps
519519 slice_size = compute_slice_size (k , size )
520-
521520 if slice_size != 1 :
522521 rdim = env .allocate_reduction_dimension (slice_size )
523522 output_size .append (rdim .var )
@@ -633,16 +632,24 @@ def create(
633632 else :
634633 index_values .append (f"{ start } { expand } " )
635634 else :
636- # Full slice or slice without step
637- if size != 1 :
638- rdim = env .allocate_reduction_dimension (size )
635+ # Handle slices with start/stop but no step
636+ start = k .start if k .start is not None else 0
637+ stop = k .stop if k .stop is not None else size
638+ slice_size = stop - start
639+
640+ if slice_size != 1 :
641+ rdim = env .allocate_reduction_dimension (slice_size )
639642 block_idx = rdim .block_id
640643 index_var = state .codegen .index_var (block_idx )
641- index_values .append (f"({ index_var } ){ expand } " )
644+ # Generate index: start + index_var
645+ if start != 0 :
646+ index_values .append (f"({ start } + ({ index_var } )){ expand } " )
647+ else :
648+ index_values .append (f"({ index_var } ){ expand } " )
642649 if mask := state .codegen .mask_var (block_idx ):
643650 mask_values .setdefault (f"({ mask } ){ expand } " )
644651 else :
645- index_values .append (f"tl.zeros([1], { dtype } ) { expand } " )
652+ index_values .append (f"{ start } { expand } " )
646653 output_idx += 1
647654 elif isinstance (k , torch .Tensor ) and k .ndim == 1 :
648655 expand = tile_strategy .expand_str (output_size , output_idx )
@@ -941,8 +948,24 @@ def create(
941948 res .offsets .append (state .codegen .offset_var (rdim .block_id ))
942949 res .block_shape .append (rdim .var )
943950 else :
944- res .offsets .append ("0" )
945- res .block_shape .append (1 )
951+ # Handle slices with start/stop but no step
952+ start = k .start if k .start is not None else 0
953+ stop = k .stop if k .stop is not None else size
954+ slice_size = stop - start
955+
956+ if slice_size != 1 :
957+ env = CompileEnvironment .current ()
958+ rdim = env .allocate_reduction_dimension (slice_size )
959+ offset = state .codegen .offset_var (rdim .block_id )
960+ # Add start offset if needed
961+ if start != 0 :
962+ res .offsets .append (f"({ start } + { offset } )" )
963+ else :
964+ res .offsets .append (offset )
965+ res .block_shape .append (rdim .var )
966+ else :
967+ res .offsets .append (str (start ))
968+ res .block_shape .append (1 )
946969 else :
947970 raise exc .InvalidIndexingType (k )
948971 res .validate ()
0 commit comments