|
18 | 18 | from .host_function import HostFunction |
19 | 19 | from .tile_strategy import DeviceLoopState |
20 | 20 | from .utils import compute_slice_size |
| 21 | +from .utils import get_slice_start |
21 | 22 | from .variable_origin import BlockSizeOrigin |
22 | 23 |
|
23 | 24 | if TYPE_CHECKING: |
@@ -126,6 +127,30 @@ def _handle_remaining_index_dimensions( |
126 | 127 | return output_idx |
127 | 128 |
|
128 | 129 |
|
| 130 | +def _generate_slice_index( |
| 131 | + start: int | torch.SymInt, |
| 132 | + index_var: str, |
| 133 | + expand: str, |
| 134 | + step: int | None = None, |
| 135 | +) -> str: |
| 136 | + """Generate slice index expression with optional step.""" |
| 137 | + if step is not None: |
| 138 | + # Strided index: start + index * step |
| 139 | + return f"({start} + ({index_var}) * {step}){expand}" |
| 140 | + if start != 0: |
| 141 | + # Index with offset: start + index |
| 142 | + return f"({start} + ({index_var})){expand}" |
| 143 | + # Simple index |
| 144 | + return f"({index_var}){expand}" |
| 145 | + |
| 146 | + |
| 147 | +def _generate_offset_expr(start: int | torch.SymInt, offset: str) -> str: |
| 148 | + """Generate offset expression with optional start.""" |
| 149 | + if start != 0: |
| 150 | + return f"({start} + {offset})" |
| 151 | + return offset |
| 152 | + |
| 153 | + |
129 | 154 | class IndexingStrategy: |
130 | 155 | def codegen_load( |
131 | 156 | self, |
@@ -628,7 +653,6 @@ def compute_shape( |
628 | 653 | size = input_size.popleft() |
629 | 654 | # Handle slices with steps |
630 | 655 | slice_size = compute_slice_size(k, size) |
631 | | - |
632 | 656 | if slice_size != 1: |
633 | 657 | rdim = env.allocate_reduction_dimension(slice_size) |
634 | 658 | output_size.append(rdim.var) |
@@ -721,25 +745,29 @@ def create( |
721 | 745 | rdim = env.allocate_reduction_dimension(slice_size) |
722 | 746 | block_idx = rdim.block_id |
723 | 747 | index_var = state.codegen.index_var(block_idx) |
724 | | - # Generate strided index: start + index * step |
725 | 748 | index_values.append( |
726 | | - f"({start} + ({index_var}) * {step}){expand}" |
| 749 | + _generate_slice_index(start, index_var, expand, step) |
727 | 750 | ) |
728 | 751 | if mask := state.codegen.mask_var(block_idx): |
729 | 752 | mask_values.setdefault(f"({mask}){expand}") |
730 | 753 | else: |
731 | 754 | index_values.append(f"{start}{expand}") |
732 | 755 | else: |
733 | | - # Full slice or slice without step |
734 | | - if size != 1: |
735 | | - rdim = env.allocate_reduction_dimension(size) |
| 756 | + # Handle slices with start/stop but no step |
| 757 | + start = get_slice_start(k) |
| 758 | + slice_size = compute_slice_size(k, size) |
| 759 | + |
| 760 | + if slice_size != 1: |
| 761 | + rdim = env.allocate_reduction_dimension(slice_size) |
736 | 762 | block_idx = rdim.block_id |
737 | 763 | index_var = state.codegen.index_var(block_idx) |
738 | | - index_values.append(f"({index_var}){expand}") |
| 764 | + index_values.append( |
| 765 | + _generate_slice_index(start, index_var, expand) |
| 766 | + ) |
739 | 767 | if mask := state.codegen.mask_var(block_idx): |
740 | 768 | mask_values.setdefault(f"({mask}){expand}") |
741 | 769 | else: |
742 | | - index_values.append(f"tl.zeros([1], {dtype}){expand}") |
| 770 | + index_values.append(f"{start}{expand}") |
743 | 771 | output_idx += 1 |
744 | 772 | elif isinstance(k, torch.Tensor) and k.ndim == 1: |
745 | 773 | expand = tile_strategy.expand_str(output_size, output_idx) |
@@ -1029,8 +1057,19 @@ def create( |
1029 | 1057 | res.offsets.append(state.codegen.offset_var(rdim.block_id)) |
1030 | 1058 | res.block_shape.append(rdim.var) |
1031 | 1059 | else: |
1032 | | - res.offsets.append("0") |
1033 | | - res.block_shape.append(1) |
| 1060 | + # Handle slices with start/stop but no step |
| 1061 | + start = get_slice_start(k) |
| 1062 | + slice_size = compute_slice_size(k, size) |
| 1063 | + |
| 1064 | + if slice_size != 1: |
| 1065 | + env = CompileEnvironment.current() |
| 1066 | + rdim = env.allocate_reduction_dimension(slice_size) |
| 1067 | + offset = state.codegen.offset_var(rdim.block_id) |
| 1068 | + res.offsets.append(_generate_offset_expr(start, offset)) |
| 1069 | + res.block_shape.append(rdim.var) |
| 1070 | + else: |
| 1071 | + res.offsets.append(str(start)) |
| 1072 | + res.block_shape.append(1) |
1034 | 1073 | else: |
1035 | 1074 | raise exc.InvalidIndexingType(k) |
1036 | 1075 | res.validate() |
|
0 commit comments