Commit 5be73dd

Fix FlattenedTileStrategy to handle unit-sized block dimensions (#1048)
1 parent 79d57b7

3 files changed: 103 additions & 8 deletions

helion/_compiler/tile_strategy.py

Lines changed: 37 additions & 8 deletions

@@ -515,20 +515,49 @@ def update_allow_flattened(cls, shape: Sequence[sympy.Expr]) -> None:
             break
 
     def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
+        env = CompileEnvironment.current()
+        # Filter out unit-sized blocks that don't need compacting
+        compact_block_ids = [
+            block_id
+            for block_id in self.block_ids
+            if not (
+                isinstance(env.block_sizes[block_id].size, int)
+                and env.block_sizes[block_id].size == 1
+            )
+        ]
+        if not compact_block_ids:
+            return shapes
+
         output = []
         shape_queue = collections.deque(shapes)
         while shape_queue:
             shape = shape_queue.popleft()
-            if len(shape.block_ids) != 1 or shape.block_ids[0] not in self.block_ids:
+            # Check if this starts our flattened sequence
+            if len(shape.block_ids) != 1 or shape.block_ids[0] != compact_block_ids[0]:
                 output.append(shape)
                 continue
-            assert shape.block_ids[0] == self.block_ids[0]
-            for expected in self.block_ids[1:]:
-                new_shape = shape_queue.popleft()
-                assert len(new_shape.block_ids) == 1
-                assert new_shape.block_ids[0] == expected
-                shape = shape.combine(new_shape)
-            output.append(shape)
+
+            # Try to collect the full sequence
+            group_shapes = [shape]
+            found_complete_sequence = True
+            for expected in compact_block_ids[1:]:
+                if (
+                    shape_queue
+                    and len(shape_queue[0].block_ids) == 1
+                    and shape_queue[0].block_ids[0] == expected
+                ):
+                    group_shapes.append(shape_queue.popleft())
+                else:
+                    # Partial match - don't combine
+                    found_complete_sequence = False
+                    output.extend(group_shapes)
+                    break
+
+            if found_complete_sequence:
+                # Full match - combine into one
+                for s in group_shapes[1:]:
+                    shape = shape.combine(s)
+                output.append(shape)
         return output
 

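For intuition, here is a minimal, runnable sketch of the compaction logic above. ToyShape and the sizes dict are simplified stand-ins for helion's CompactedShape and the CompileEnvironment block-size table, not the real classes; only the control flow mirrors the new compact_shape.

from __future__ import annotations

import collections
from dataclasses import dataclass


@dataclass
class ToyShape:
    block_ids: list[int]

    def combine(self, other: ToyShape) -> ToyShape:
        return ToyShape(self.block_ids + other.block_ids)


def compact(shapes: list[ToyShape], block_ids: list[int], sizes: dict[int, int]) -> list[ToyShape]:
    # Unit-sized blocks may never appear in `shapes`, so they must not
    # participate in the expected flattened sequence.
    compact_block_ids = [b for b in block_ids if sizes.get(b) != 1]
    if not compact_block_ids:
        return shapes

    output: list[ToyShape] = []
    queue = collections.deque(shapes)
    while queue:
        shape = queue.popleft()
        if len(shape.block_ids) != 1 or shape.block_ids[0] != compact_block_ids[0]:
            output.append(shape)
            continue
        group = [shape]
        complete = True
        for expected in compact_block_ids[1:]:
            if queue and len(queue[0].block_ids) == 1 and queue[0].block_ids[0] == expected:
                group.append(queue.popleft())
            else:
                # Partial match: emit what was collected without combining.
                complete = False
                output.extend(group)
                break
        if complete:
            for s in group[1:]:
                shape = shape.combine(s)
            output.append(shape)
    return output


# Block 0 is unit-sized, so only block 1 takes part in compaction; the old
# code would have hit its strict assertions here because block 0 never
# shows up in the shape list.
print(compact([ToyShape([1])], block_ids=[0, 1], sizes={0: 1, 1: 16}))
# -> [ToyShape(block_ids=[1])]

The key change is that unit-sized block ids are filtered out up front, so a shape list that never mentions them (as happens with a (1, 100) input) no longer trips the strict sequence assertions.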
test/test_loops.expected

Lines changed: 32 additions & 0 deletions

@@ -549,6 +549,38 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor, *, _launcher=_de
     # src[test_loops.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestLoops.test_flattened_tile_with_unit_axis)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_silu_kernel(x, out, _BLOCK_SIZE_0_1: tl.constexpr):
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32)
+    indices_1 = offsets_0_1
+    mask_0_1 = offsets_0_1 < 100
+    # src[test_loops.py:N]: out[tile] = x[tile] * torch.sigmoid(x[tile])
+    load = tl.load(x + indices_1[None, :] * 1, mask_0_1[None, :], other=0, eviction_policy='evict_first')
+    load_1 = tl.load(x + indices_1[None, :] * 1, mask_0_1[None, :], other=0)
+    v_0 = tl.cast(tl.sigmoid(tl.cast(load_1, tl.float32)), tl.float16)
+    v_1 = load * v_0
+    tl.store(out + indices_1[None, :] * 1, v_1, mask_0_1[None, :])
+
+def silu_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_loops.py:N]: out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+    out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    _BLOCK_SIZE_0_1 = 16
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    # src[test_loops.py:N]: out[tile] = x[tile] * torch.sigmoid(x[tile])
+    _launcher(_helion_silu_kernel, (triton.cdiv(100, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=1, num_stages=1)
+    # src[test_loops.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestLoops.test_full_with_dynamic_fill_value)
 from __future__ import annotations

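Note how the unit axis and the 16-wide axis collapse into the single flattened dimension _BLOCK_SIZE_0_1. The launcher's grid and the kernel's mask can be sanity-checked in plain Python; the numbers below (100 elements, block size 16) are copied from the generated code above.

BLOCK = 16
num_programs = -(-100 // BLOCK)  # ceil division, same result as triton.cdiv(100, 16)
assert num_programs == 7
# The last program covers offsets 96..111; mask_0_1 = offsets < 100 keeps
# only the 4 in-bounds lanes, so the tail is handled by masking alone.
tail_offsets = [6 * BLOCK + i for i in range(BLOCK)]
assert sum(o < 100 for o in tail_offsets) == 4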
test/test_loops.py

Lines changed: 34 additions & 0 deletions

@@ -114,6 +114,40 @@ def test_3d_device_loop3(self):
         torch.testing.assert_close(result, torch.sin(args[0]))
         self.assertExpectedJournal(code)
 
+    @skipIfRefEager(
+        "Test is block size dependent which is not supported in ref eager mode"
+    )
+    def test_flattened_tile_with_unit_axis(self):
+        @helion.kernel(
+            config=helion.Config(
+                block_sizes=[1, 16],
+                flatten_loops=[True],
+                indexing=["block_ptr", "pointer", "pointer"],
+                l2_groupings=[1],
+                load_eviction_policies=["first"],
+                loop_orders=[[1, 0]],
+                num_stages=1,
+                num_warps=1,
+                pid_type="flat",
+                range_flattens=[None],
+                range_multi_buffers=[None],
+                range_num_stages=[0],
+                range_unroll_factors=[0],
+                range_warp_specializes=[],
+            ),
+            static_shapes=True,
+        )
+        def silu_kernel(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+            for tile in hl.tile(out.size()):
+                out[tile] = x[tile] * torch.sigmoid(x[tile])
+            return out
+
+        x = torch.randn((1, 100), dtype=torch.float16, device=DEVICE)
+        code, result = code_and_output(silu_kernel, (x,))
+        torch.testing.assert_close(result, torch.sigmoid(x) * x, rtol=1e-3, atol=1e-3)
+        self.assertExpectedJournal(code)
+
     @patch.object(_compat, "_supports_tensor_descriptor", lambda: False)
     def test_loop_fixed_block(self):
         @helion.kernel(config={"block_sizes": [], "indexing": "block_ptr"})

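The test's numerical assertion reduces to the silu identity, silu(x) = x * sigmoid(x). A minimal sketch of that reference check in plain PyTorch (run on CPU here for illustration, whereas the test uses DEVICE):

import torch

# silu(x) == x * sigmoid(x); the test compares the Helion kernel's output
# against this identity at fp16 tolerance (rtol=1e-3, atol=1e-3).
x = torch.randn((1, 100), dtype=torch.float16)
ref = x * torch.sigmoid(x)
torch.testing.assert_close(torch.nn.functional.silu(x), ref, rtol=1e-3, atol=1e-3)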