
Commit 28424b9

Commit message: initial version
1 parent: 5c71db4

File tree: 8 files changed (+463, -76 lines)

helion/_compiler/ast_extension.py

Lines changed: 23 additions & 0 deletions
@@ -8,6 +8,8 @@
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
+import torch
+
 from .. import exc
 from .source_location import SourceLocation
 from .source_location import current_location
@@ -82,10 +84,31 @@ def __repr__(self) -> str:
 
     def update_type_info(self, type_info: TypeInfo) -> TypeInfo:
         if self._type_info is not None and type_info != self._type_info:
+            prev_rank = self._tensor_rank(self._type_info)
+            new_rank = self._tensor_rank(type_info)
+            if (
+                prev_rank is not None
+                and new_rank is not None
+                and prev_rank != new_rank
+            ):
+                self._type_info = type_info
+                return self._type_info
             type_info = self._type_info.merge(type_info)
         self._type_info = type_info
         return self._type_info
 
+    @staticmethod
+    def _tensor_rank(type_info: "TypeInfo") -> int | None:
+        fake_value = getattr(type_info, "fake_value", None)
+        if isinstance(fake_value, torch.Tensor):
+            return fake_value.dim()
+        tensor = getattr(type_info, "tensor", None)
+        if tensor is not None:
+            fake_value = getattr(tensor, "fake_value", None)
+            if isinstance(fake_value, torch.Tensor):
+                return fake_value.dim()
+        return None
+
     def debug_annotations(self) -> list[str]:
         result = []
         if self._type_info:
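
Not part of the diff: a minimal standalone sketch of the rank rule added above, using a hypothetical FakeTypeInfo stand-in for the real TypeInfo (only the fake_value attribute that _tensor_rank inspects is modeled here).

    import torch
    from dataclasses import dataclass

    @dataclass
    class FakeTypeInfo:
        fake_value: torch.Tensor  # hypothetical stand-in for TypeInfo.fake_value

    def tensor_rank(type_info) -> int | None:
        # Mirrors _tensor_rank: read the fake tensor, if any, and return its rank.
        fake_value = getattr(type_info, "fake_value", None)
        if isinstance(fake_value, torch.Tensor):
            return fake_value.dim()
        return None

    prev = FakeTypeInfo(torch.empty(4, 8))       # rank 2
    new = FakeTypeInfo(torch.empty(4, 8, 16))    # rank 3
    # Ranks differ, so update_type_info now adopts the new type outright
    # instead of calling merge() on shapes of incompatible rank.
    assert tensor_rank(prev) != tensor_rank(new)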

helion/_compiler/compile_environment.py

Lines changed: 152 additions & 1 deletion
@@ -142,17 +142,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
 
+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         with self.shape_env.ignore_fresh_unbacked_symbols():
             sym = self.shape_env.create_unbacked_symint()
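
Not part of the diff: a toy sympy illustration (no Helion API) of why _clone_block_size_as_reduction reuses the existing block's variable instead of allocating a fresh one. Equality in the shape environment only survives when both dimensions are literally the same symbol; the symbol names below are made up.

    import sympy

    block_var = sympy.Symbol("block_size_0", integer=True, positive=True)   # existing tile block size
    cloned_rdim = block_var                                                  # clone path: reuse the same symbol
    fresh_rdim = sympy.Symbol("rdim_size_0", integer=True, positive=True)   # what a fresh allocation would create

    assert sympy.simplify(cloned_rdim - block_var) == 0   # equality holds by construction
    assert sympy.simplify(fresh_rdim - block_var) != 0    # equality is lost with a new symbol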
@@ -203,6 +233,91 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        block_id = self.get_tile_index_tensor_block_id(indexer_tensor)
+        if block_id is None and base_dim_size is not None:
+            block_id = self.get_block_id(base_dim_size)
+        if block_id is None and non_broadcast_dims:
+            block_id = self.get_block_id(non_broadcast_dims[0])
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        if non_broadcast_dims:
+            return [non_broadcast_dims[0]]
+        return [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block-symbol and the associated
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
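
Not part of the diff: a standalone sketch of the provenance mechanism used above. The block id travels on the tensor as a plain attribute (the same _tile_index_block_id name as in the diff), so derived tensors lose it and helpers like new_index_result must re-register it on their results. The tag/tag_of helpers are illustrative only.

    import torch

    def tag(tensor: torch.Tensor, block_id: int) -> torch.Tensor:
        # Same attribute name as register_tile_index_tensor_block_id uses.
        tensor._tile_index_block_id = block_id
        return tensor

    def tag_of(tensor: torch.Tensor) -> int | None:
        return getattr(tensor, "_tile_index_block_id", None)

    index = tag(torch.arange(16), block_id=0)   # stands in for a tile.index tensor
    view = index.reshape(16, 1)                 # a derived tensor does not inherit the tag...
    assert tag_of(view) is None
    tag(view, tag_of(index))                    # ...so indexing/view results are re-tagged explicitly
    assert tag_of(view) == 0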
@@ -283,6 +398,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -535,3 +654,35 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
 
 def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """
+    Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    Args:
+        shapes: List of shapes from each tensor indexer
+        env: CompileEnvironment for size_hint and known_equal checks
+
+    Returns:
+        Broadcast shape as list of dimensions
+    """
+    if not shapes:
+        return []
+
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+    broadcast_shape: list[int | torch.SymInt] = []
+
+    for dims_at_pos in zip(*padded, strict=True):
+        chosen: int | torch.SymInt | None = None
+        for d in dims_at_pos:
+            if env.size_hint(d) != 1:
+                if chosen is None or env.known_equal(chosen, d):
+                    chosen = d
+        broadcast_shape.append(chosen if chosen is not None else 1)
+
+    return broadcast_shape
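
Not part of the diff: a hypothetical usage sketch of compute_broadcast_shape_for_tensor_indexers with a minimal stand-in env that only handles concrete int sizes (the real CompileEnvironment also handles SymInts via size_hint and known_equal). The import assumes this commit is applied.

    from helion._compiler.compile_environment import (
        compute_broadcast_shape_for_tensor_indexers,
    )

    class _IntEnv:
        """Hypothetical stand-in env: concrete int sizes only."""
        def size_hint(self, d):
            return int(d)
        def known_equal(self, a, b):
            return int(a) == int(b)

    # Right-aligned broadcasting: [8, 1] x [1, 4] x [4] -> [8, 4]
    shapes = [[8, 1], [1, 4], [4]]
    assert compute_broadcast_shape_for_tensor_indexers(shapes, _IntEnv()) == [8, 4]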
