Commit 6cd21fb

initial version

1 parent 2882b3b commit 6cd21fb

7 files changed
+364 -46 lines changed

helion/_compiler/compile_environment.py

Lines changed: 131 additions & 0 deletions
@@ -72,6 +72,12 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
         self.specialized_vars: set[sympy.Symbol] = set()
         self.loop_dependency_checker = LoopDependencyChecker()
         self._symint_cache: dict[object, torch.SymInt] = {}
+
+        # Track tile.index tensors to preserve their block_id through broadcast indexing operations.
+        # When tile.index creates indices [0,1,2...] for a tiled dimension, we map the tensor to its
+        # block_id. This origin is preserved through ops like tensor[:, None] so the symbolic size is maintained.
+        self._tile_index_tensor_to_block_id_map: dict[int, int] = {}  # unique_tensor_id -> block_id
+        self._next_tensor_id = 0  # Counter for generating unique tensor IDs
 
     def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
         from .device_function import contains_only_block_size_symbols
@@ -142,6 +148,30 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles
+        # (e.g., via slicing) with sizes coming directly from tile block vars, we
+        # want them to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and avoids
+        # spurious "size mismatch" issues during fake-tensor broadcasting and
+        # arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            size_str = str(size)
+            for block_info in self.block_sizes:
+                if not block_info.reduction and str(block_info.var) == size_str:
+                    # Create reduction dimension with the same var to preserve
+                    # symbolic equality and ensure all later users see identical
+                    # symbols (rather than equal-but-distinct SymInts).
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(
+                            reduction_loop=len([b for b in self.block_sizes if b.reduction])
+                        ),
+                    )
+                    self.block_sizes[rdim_idx].var = block_info.var
+                    return self.block_sizes[rdim_idx]
+
         # Allocate a new reduction dimension
         rdim_idx = self.allocate_block_size(
             size,
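The reuse above relies on two SymInts that print the same string referring to the same block-size symbol. A self-contained sketch of that idea, using plain strings in place of SymInts and a hypothetical FakeBlockInfo stand-in for BlockSizeInfo:

from dataclasses import dataclass

@dataclass
class FakeBlockInfo:  # hypothetical stand-in for BlockSizeInfo
    var: str          # stand-in for the SymInt block-size symbol
    reduction: bool

def allocate_reduction(block_sizes: list[FakeBlockInfo], size: str) -> FakeBlockInfo:
    # Reuse an existing reduction dim of the same size (the context lines above).
    for info in block_sizes:
        if info.reduction and info.var == size:
            return info
    # New behavior: a reduction dim matching a tile dim reuses that dim's symbol,
    # so later shape comparisons see literally the same symbol.
    for info in block_sizes:
        if not info.reduction and info.var == size:
            rdim = FakeBlockInfo(var=info.var, reduction=True)
            block_sizes.append(rdim)
            return rdim
    # Otherwise allocate a fresh reduction dimension.
    rdim = FakeBlockInfo(var=size, reduction=True)
    block_sizes.append(rdim)
    return rdim

blocks = [FakeBlockInfo("block_size_0", reduction=False)]
assert allocate_reduction(blocks, "block_size_0").var == "block_size_0"
assert blocks[1].reduction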
@@ -203,6 +233,107 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Register a tensor as originating from a specific tile block.
+
+        This is called when tile.index creates a 1D tensor of indices for a
+        specific tiled dimension. The tensor represents indices [0, 1, 2, ...]
+        for ONE dimension that is being tiled.
+
+        Args:
+            tensor: A 1D tensor created by tile.index containing indices for
+                a single tiled dimension
+            block_id: The block ID representing the tiled dimension this tensor
+                corresponds to. This is NOT a multi-dimensional concept -
+                each tile.index tensor tracks exactly one dimension.
+
+        Example:
+            When tiling x.size(0) with block_id=3:
+            - tile.index creates tensor([0, 1, 2, ..., block_size-1])
+            - This tensor is registered with block_id=3
+            - Later, when this tensor is used as an indexer, we know the
+              output should have the symbolic size from block_id=3
+        """
+        # Assign a unique ID to this tensor
+        tensor_id = self._next_tensor_id
+        self._next_tensor_id += 1
+
+        # Store the mapping from this tile.index tensor to its dimension's block_id
+        self._tile_index_tensor_to_block_id_map[tensor_id] = block_id
+
+        # Tag the tensor with its unique ID
+        tensor._tile_index_tensor_id = tensor_id
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Get the block_id for a tensor if it originated from tile.index.
+
+        Returns the block_id of the single dimension this index tensor represents,
+        or None if this tensor didn't originate from tile.index.
+        """
+        # Check if tensor has our unique ID tag
+        tensor_id = getattr(tensor, '_tile_index_tensor_id', None)
+        if tensor_id is None:
+            return None
+        return self._tile_index_tensor_to_block_id_map.get(tensor_id)
+
+    def is_tile_index_tensor(self, tensor: torch.Tensor) -> bool:
+        """Check if a tensor originated from a tile.index operation."""
+        tensor_id = getattr(tensor, '_tile_index_tensor_id', None)
+        if tensor_id is None:
+            return False
+        # If tensor has an ID, it must be in the map
+        assert tensor_id in self._tile_index_tensor_to_block_id_map
+        return True
+
+    def preserve_tile_index_tensor_block_id(
+        self,
+        input_tensor: torch.Tensor,
+        output_tensor: torch.Tensor,
+        output_shape: list[int | torch.SymInt]
+    ) -> None:
+        """Preserve tile.index tensor's origin block id through broadcast-only view operations.
+
+        Note: Caller must check is_tile_index_tensor() before calling this method.
+        """
+        # Get the block_id from input tensor
+        input_tensor_id = getattr(input_tensor, '_tile_index_tensor_id')
+        src_block_id = self._tile_index_tensor_to_block_id_map[input_tensor_id]
+
+        # Only preserve for broadcast-only views (at most one non-1 dimension)
+        non_broadcast_dims = [i for i, s in enumerate(output_shape) if self.size_hint(s) != 1]
+        if len(non_broadcast_dims) <= 1:
+            # Register the output tensor with the same block_id
+            self.register_tile_index_tensor_block_id(output_tensor, src_block_id)
+            # Ensure the non-broadcast dimension uses the correct symbol
+            if non_broadcast_dims and src_block_id < len(self.block_sizes):
+                output_shape[non_broadcast_dims[0]] = self.block_sizes[src_block_id].var
+
+    def get_indexer_output_size(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None
+    ) -> int | torch.SymInt | list:
+        """Get the output size for a tensor indexer, preserving tile.index tensor's origin block id."""
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from different sources in order
+        if block_id := self.get_tile_index_tensor_block_id(indexer_tensor):
+            return self.block_sizes[block_id].var
+
+        if base_dim_size and (block_id := self.get_block_id(base_dim_size)):
+            return self.block_sizes[block_id].var
+
+        if non_broadcast_dims and (block_id := self.get_block_id(non_broadcast_dims[0])):
+            return self.block_sizes[block_id].var
+
+        return non_broadcast_dims[0] if non_broadcast_dims else 1
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
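A self-contained sketch of the tagging scheme these helpers implement (module-level globals here only for illustration; in the diff the state lives on CompileEnvironment): each tile.index tensor gets a unique id attribute, and a side table maps that id to the originating block_id.

import torch

_tile_index_map: dict[int, int] = {}  # unique tensor id -> block_id
_next_id = 0

def register(tensor: torch.Tensor, block_id: int) -> None:
    global _next_id
    tensor._tile_index_tensor_id = _next_id  # tag the tensor object itself
    _tile_index_map[_next_id] = block_id     # remember which tiled dim it indexes
    _next_id += 1

def origin_block_id(tensor: torch.Tensor) -> int | None:
    tid = getattr(tensor, "_tile_index_tensor_id", None)
    return None if tid is None else _tile_index_map.get(tid)

idx = torch.arange(8)                  # stand-in for what tile.index produces
register(idx, 3)                       # the dimension tiled with block_id=3
assert origin_block_id(idx) == 3
view = idx[:, None]                    # a broadcast view is a new tensor object...
assert origin_block_id(view) is None   # ...so the tag is lost unless re-registered,
register(view, 3)                      # which is what preserve_tile_index_tensor_block_id does
assert origin_block_id(view) == 3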

helion/_compiler/indexing_strategy.py

Lines changed: 129 additions & 15 deletions
@@ -490,11 +490,20 @@ def compute_shape(
                     output_size.append(rdim.var)
                 else:
                     output_size.append(1)
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
-                input_size.popleft()
-                output_size.extend(k.size())
+            elif isinstance(k, torch.Tensor):
+                # Advanced tensor indexer: consume one base dim and splice indexer shape.
+                base_dim = input_size.popleft()
+                dims = list(k.size())
+                non_broadcast_dims = [d for d in dims if env.size_hint(d) != 1]
+
+                # Multi-d indexer contributes its own shape
+                if len(non_broadcast_dims) > 1:
+                    output_size.extend(dims)
+                    continue
+
+                # Single or broadcast-only indexer - use origin tracking helper
+                size = env.get_indexer_output_size(k, base_dim)
+                output_size.append(size)
             else:
                 raise exc.InvalidIndexingType(k)
         assert len(input_size) == 0, "invalid subscript"
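Independent of Helion types, the shape rule this branch implements can be stated as a small pure function (a simplified sketch, using concrete ints where the real code carries SymInts and origin tracking):

def indexer_contribution(indexer_shape: list[int]) -> list[int]:
    # One base dimension is always consumed; what replaces it depends on the indexer.
    non_broadcast = [d for d in indexer_shape if d != 1]
    if len(non_broadcast) > 1:
        return list(indexer_shape)  # true multi-d indexer splices its whole shape
    return [non_broadcast[0] if non_broadcast else 1]  # broadcast-only: one size

assert indexer_contribution([8]) == [8]        # plain 1D indexer
assert indexer_contribution([8, 1]) == [8]     # broadcast-only view such as tile.index[:, None]
assert indexer_contribution([4, 4]) == [4, 4]  # genuine 2D gather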
@@ -514,6 +523,7 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
@@ -573,16 +583,16 @@
                 else:
                     index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
-                index_var = state.codegen.lift(ast_index[n], prefix="index").id
-                index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
+            elif isinstance(k, torch.Tensor) and (
+                k.ndim == 1
+                or sum(CompileEnvironment.current().size_hint(d) != 1 for d in k.size())
+                <= 1
+            ):
+                # Broadcast-only 1D indexer
+                SubscriptIndexing._handle_broadcast_indexer(
+                    k, n, output_size, output_idx, index,
+                    state, tile_strategy, index_values, mask_values, env
+                )
                 output_idx += 1
             elif (
                 isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
@@ -601,6 +611,24 @@
                     mask_values.setdefault(
                         f"({mask}){tile_strategy.expand_str(output_size, n)}"
                     )
+            elif isinstance(k, torch.Tensor) and k.ndim > 1 and len(index) > 1:
+                # Multi-dimensional tensor indexer combined with other indices
+                non_broadcast_dims = [dim for dim in k.size() if env.size_hint(dim) != 1]
+
+                if len(non_broadcast_dims) <= 1:
+                    # Broadcast-only multi-dim indexer: treat as single dimension
+                    SubscriptIndexing._handle_broadcast_indexer(
+                        k, n, output_size, output_idx, index,
+                        state, tile_strategy, index_values, mask_values, env
+                    )
+                    output_idx += 1
+                else:
+                    # True multi-dim indexer: handle all dims at once
+                    SubscriptIndexing._handle_multidim_indexer(
+                        k, n, output_size, output_idx, index,
+                        state, tile_strategy, index_values, mask_values, env
+                    )
+                    output_idx += k.ndim
             else:
                 raise exc.InvalidIndexingType(type(k))
         assert len(output_size) == output_idx
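For reference, the expand strings these branches pass around are brackets like "[None, :, None]". A hypothetical, simplified stand-in for how tile_strategy.expand_str is consumed here (the real implementation may differ in details):

def expand_str(output_rank: int, axis: int) -> str:
    # Place one value axis at `axis` and broadcast (None) everywhere else.
    tokens = [":" if i == axis else "None" for i in range(output_rank)]
    return "" if tokens == [":"] else f"[{', '.join(tokens)}]"

assert expand_str(3, 1) == "[None, :, None]"
assert expand_str(1, 0) == ""  # rank-1 output needs no broadcasting bracket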
@@ -618,10 +646,96 @@
         if extra_mask is not None:
             mask_values.setdefault("{_extra_mask}")
             kwargs["_extra_mask"] = extra_mask
+
         return SubscriptIndexing(
             expr_from_string("+".join(index_expr)),
             expr_from_string("&".join(mask_values) or "None", **kwargs),
         )
+
+    @staticmethod
+    def _handle_broadcast_indexer(
+        k: torch.Tensor, n: int, output_size: list, output_idx: int, index: list,
+        state: CodegenState, tile_strategy: Any, index_values: list,
+        mask_values: dict, env: CompileEnvironment
+    ) -> None:
+        """Handle broadcast-only tensor indexer (all dims but one are size 1)."""
+        expand = tile_strategy.expand_str(output_size, output_idx)
+
+        # Try to get tile.index tensor's origin block_id
+        tile_origin_block_id = env.get_tile_index_tensor_block_id(k)
+
+        if tile_origin_block_id is not None:
+            # Use the tile_index tensor's block id directly
+            index_var = state.codegen.index_var(tile_origin_block_id)
+            index_values.append(f"({index_var}){expand}")
+            if (mask := state.codegen.mask_var(tile_origin_block_id)) is not None:
+                mask_values.setdefault(f"({mask}){expand}")
+        else:
+            # Lift AST to preserve expressions like tile.index + 1
+            ast_index = state.ast_args[1]
+            assert isinstance(ast_index, (list, tuple))
+            assert len(ast_index) == len(index)
+            lifted = state.codegen.lift(ast_index[n], prefix="index").id
+            index_values.append(f"({lifted}){expand}")
+            # Even if we lift, we still know the block-id for this axis from output_size
+            output_block_id = env.get_block_id(output_size[output_idx])
+            if output_block_id is not None and (mask := state.codegen.mask_var(output_block_id)) is not None:
+                mask_values.setdefault(f"({mask}){expand}")
+
+    @staticmethod
+    def _handle_multidim_indexer(
+        k: torch.Tensor, n: int, output_size: list, output_idx: int, index: list,
+        state: CodegenState, tile_strategy: Any, index_values: list,
+        mask_values: dict, env: CompileEnvironment
+    ) -> None:
+        """Handle multi-dimensional tensor indexer."""
+        # Lift the indexer once
+        ast_index = state.ast_args[1]
+        assert isinstance(ast_index, (list, tuple))
+        assert len(ast_index) == len(index)
+        index_var = state.codegen.lift(ast_index[n], prefix="index").id
+
+        # Build merged broadcast bracket for all dims
+        # Start with first dimension
+        base = tile_strategy.expand_str(output_size, output_idx)
+        if base == "":
+            tokens = []
+        else:
+            assert base.startswith("[") and base.endswith("]"), base
+            tokens = base[1:-1].split(", ") if len(base) > 2 else []
+
+        # Merge with other dimensions
+        for d in range(1, k.ndim):
+            s = tile_strategy.expand_str(output_size, output_idx + d)
+            if s == "":
+                s_tokens = [":"]
+            else:
+                assert s.startswith("[") and s.endswith("]"), s
+                s_tokens = s[1:-1].split(", ") if len(s) > 2 else []
+
+            # Merge tokens: use ':' if either has ':', else 'None'
+            if not tokens:
+                tokens = s_tokens
+            elif s_tokens:
+                tokens = [
+                    ":" if (a == ":" or b == ":") else "None"
+                    for a, b in zip(tokens, s_tokens, strict=True)
+                ]
+
+        if tokens == [":"] or not tokens:
+            bracket = ""
+        else:
+            bracket = f"[{', '.join(tokens)}]"
+
+        index_values.append(f"({index_var}){bracket}")
+
+        # Add mask contributions for each output dim
+        for d in range(k.ndim):
+            if (block_idx := env.get_block_id(output_size[output_idx + d])) is not None:
+                if mask := state.codegen.mask_var(block_idx):
+                    mask_values.setdefault(
+                        f"({mask}){tile_strategy.expand_str(output_size, output_idx + d)}"
+                    )
 
 
 @dataclasses.dataclass
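A self-contained sketch of the bracket-merging rule used by _handle_multidim_indexer: per output position, ':' wins over 'None', so the per-dimension brackets of a multi-dim indexer collapse into one subscript (simplified relative to the diff, which also special-cases the first empty bracket):

def merge_brackets(brackets: list[str]) -> str:
    tokens: list[str] = []
    for b in brackets:
        parts = b[1:-1].split(", ") if b else [":"]
        if not tokens:
            tokens = parts
        elif parts:
            # ':' in either operand keeps the axis; otherwise it stays broadcast.
            tokens = [":" if (a == ":" or p == ":") else "None"
                      for a, p in zip(tokens, parts, strict=True)]
    return "" if tokens in ([], [":"]) else f"[{', '.join(tokens)}]"

assert merge_brackets(["[:, None]", "[None, :]"]) == "[:, :]"
assert merge_brackets(["[None, :]"]) == "[None, :]"
assert merge_brackets([""]) == ""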
