45 changes: 37 additions & 8 deletions helion/_compiler/tile_strategy.py
@@ -515,20 +515,49 @@ def update_allow_flattened(cls, shape: Sequence[sympy.Expr]) -> None:
                 break
 
     def compact_shape(self, shapes: list[CompactedShape]) -> list[CompactedShape]:
+        env = CompileEnvironment.current()
+        # Filter out unit-sized blocks that don't need compacting
+        compact_block_ids = [
+            block_id
+            for block_id in self.block_ids
+            if not (
+                isinstance(env.block_sizes[block_id].size, int)
+                and env.block_sizes[block_id].size == 1
+            )
+        ]
+        if not compact_block_ids:
+            return shapes
+
         output = []
         shape_queue = collections.deque(shapes)
         while shape_queue:
             shape = shape_queue.popleft()
-            if len(shape.block_ids) != 1 or shape.block_ids[0] not in self.block_ids:
+            # Check if this starts our flattened sequence
+            if len(shape.block_ids) != 1 or shape.block_ids[0] != compact_block_ids[0]:
                 output.append(shape)
                 continue
-            assert shape.block_ids[0] == self.block_ids[0]
-            for expected in self.block_ids[1:]:
-                new_shape = shape_queue.popleft()
-                assert len(new_shape.block_ids) == 1
-                assert new_shape.block_ids[0] == expected
-                shape = shape.combine(new_shape)
-            output.append(shape)
+
+            # Try to collect the full sequence
+            group_shapes = [shape]
+            found_complete_sequence = True
+            for expected in compact_block_ids[1:]:
+                if (
+                    shape_queue
+                    and len(shape_queue[0].block_ids) == 1
+                    and shape_queue[0].block_ids[0] == expected
+                ):
+                    group_shapes.append(shape_queue.popleft())
+                else:
+                    # Partial match - don't combine
+                    found_complete_sequence = False
+                    output.extend(group_shapes)
+                    break
+
+            if found_complete_sequence:
+                # Full match - combine into one
+                for s in group_shapes[1:]:
+                    shape = shape.combine(s)
+                output.append(shape)
         return output


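For reviewers who want the new grouping behavior in isolation: the rewrite excludes unit-sized block ids up front and no longer asserts that a flattened sequence is complete; a partial run is emitted unmerged. Below is a minimal standalone sketch of that logic. The Shape dataclass, compact function, and toy data are hypothetical stand-ins for Helion's CompactedShape machinery, not the library's API:

import collections
from dataclasses import dataclass

@dataclass
class Shape:
    # Hypothetical stand-in for Helion's CompactedShape.
    block_ids: list[int]

    def combine(self, other: "Shape") -> "Shape":
        return Shape(self.block_ids + other.block_ids)

def compact(shapes: list[Shape], compact_block_ids: list[int]) -> list[Shape]:
    # Merge a run of single-block shapes matching compact_block_ids;
    # leave a partial run untouched instead of asserting.
    output: list[Shape] = []
    queue = collections.deque(shapes)
    while queue:
        shape = queue.popleft()
        if len(shape.block_ids) != 1 or shape.block_ids[0] != compact_block_ids[0]:
            output.append(shape)
            continue
        group = [shape]
        complete = True
        for expected in compact_block_ids[1:]:
            if queue and len(queue[0].block_ids) == 1 and queue[0].block_ids[0] == expected:
                group.append(queue.popleft())
            else:
                complete = False  # partial match: emit the collected shapes unmerged
                output.extend(group)
                break
        if complete:
            for s in group[1:]:
                shape = shape.combine(s)
            output.append(shape)
    return output

print(compact([Shape([0]), Shape([1])], [0, 1]))  # full run -> [Shape(block_ids=[0, 1])]
print(compact([Shape([0]), Shape([2])], [0, 1]))  # partial run -> both shapes kept separate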
32 changes: 32 additions & 0 deletions test/test_loops.expected
@@ -549,6 +549,38 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_loops.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestLoops.test_flattened_tile_with_unit_axis)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_silu_kernel(x, out, _BLOCK_SIZE_0_1: tl.constexpr):
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32)
+    indices_1 = offsets_0_1
+    mask_0_1 = offsets_0_1 < 100
+    # src[test_loops.py:N]: out[tile] = x[tile] * torch.sigmoid(x[tile])
+    load = tl.load(x + indices_1[None, :] * 1, mask_0_1[None, :], other=0, eviction_policy='evict_first')
+    load_1 = tl.load(x + indices_1[None, :] * 1, mask_0_1[None, :], other=0)
+    v_0 = tl.cast(tl.sigmoid(tl.cast(load_1, tl.float32)), tl.float16)
+    v_1 = load * v_0
+    tl.store(out + indices_1[None, :] * 1, v_1, mask_0_1[None, :])
+
+def silu_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_loops.py:N]: out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+    out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    _BLOCK_SIZE_0_1 = 16
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    # src[test_loops.py:N]: out[tile] = x[tile] * torch.sigmoid(x[tile])
+    _launcher(_helion_silu_kernel, (triton.cdiv(100, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=1, num_stages=1)
+    # src[test_loops.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestLoops.test_full_with_dynamic_fill_value)
 from __future__ import annotations
 
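A note on the numbers in this expected journal: with static shape (1, 100), block_sizes=[1, 16], and flatten_loops=[True], the unit axis is folded away, so the generated kernel iterates a single flattened space of 100 elements in blocks of 16. A plain-Python sanity check of the resulting grid size and mask (not part of the PR, just the arithmetic):

import math

numel = 1 * 100                   # flattened iteration space for a (1, 100) tensor
block = 16                        # _BLOCK_SIZE_0_1 chosen by the config
grid = math.ceil(numel / block)   # what triton.cdiv(100, _BLOCK_SIZE_0_1) computes
assert grid == 7
# The last program covers offsets 96..111, so the kernel masks with offsets < 100.
assert (grid - 1) * block < numel <= grid * block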
34 changes: 34 additions & 0 deletions test/test_loops.py
@@ -114,6 +114,40 @@ def test_3d_device_loop3(self):
         torch.testing.assert_close(result, torch.sin(args[0]))
         self.assertExpectedJournal(code)
 
+    @skipIfRefEager(
+        "Test is block size dependent which is not supported in ref eager mode"
+    )
+    def test_flattened_tile_with_unit_axis(self):
+        @helion.kernel(
+            config=helion.Config(
+                block_sizes=[1, 16],
+                flatten_loops=[True],
+                indexing=["block_ptr", "pointer", "pointer"],
+                l2_groupings=[1],
+                load_eviction_policies=["first"],
+                loop_orders=[[1, 0]],
+                num_stages=1,
+                num_warps=1,
+                pid_type="flat",
+                range_flattens=[None],
+                range_multi_buffers=[None],
+                range_num_stages=[0],
+                range_unroll_factors=[0],
+                range_warp_specializes=[],
+            ),
+            static_shapes=True,
+        )
+        def silu_kernel(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+            for tile in hl.tile(out.size()):
+                out[tile] = x[tile] * torch.sigmoid(x[tile])
+            return out
+
+        x = torch.randn((1, 100), dtype=torch.float16, device=DEVICE)
+        code, result = code_and_output(silu_kernel, (x,))
+        torch.testing.assert_close(result, torch.sigmoid(x) * x, rtol=1e-3, atol=1e-3)
+        self.assertExpectedJournal(code)
+
     @patch.object(_compat, "_supports_tensor_descriptor", lambda: False)
     def test_loop_fixed_block(self):
         @helion.kernel(config={"block_sizes": [], "indexing": "block_ptr"})
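To reproduce the new journal entry locally, running just the added test should suffice (assuming the repository's usual pytest setup; the exact invocation may differ in CI):

python -m pytest test/test_loops.py -k test_flattened_tile_with_unit_axis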