
Commit 85d2bb5

Commit message: test
1 parent 7db09b7

2 files changed: +66 −0 lines changed


test/test_loops.expected

Lines changed: 32 additions & 0 deletions
@@ -549,6 +549,38 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor, *, _launcher=_de
     # src[test_loops.py:N]: return out
     return out
 
+--- assertExpectedJournal(TestLoops.test_flattened_tile_with_unit_axis)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_silu_kernel(x, out, _BLOCK_SIZE_0_1: tl.constexpr):
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    offsets_0_1 = tl.program_id(0) * _BLOCK_SIZE_0_1 + tl.arange(0, _BLOCK_SIZE_0_1).to(tl.int32)
+    indices_1 = offsets_0_1
+    mask_0_1 = offsets_0_1 < 100
+    # src[test_loops.py:N]: out[tile] = x[tile] * torch.sigmoid(x[tile])
+    load = tl.load(x + indices_1[None, :] * 1, mask_0_1[None, :], other=0, eviction_policy='evict_first')
+    load_1 = tl.load(x + indices_1[None, :] * 1, mask_0_1[None, :], other=0)
+    v_0 = tl.cast(tl.sigmoid(tl.cast(load_1, tl.float32)), tl.float16)
+    v_1 = load * v_0
+    tl.store(out + indices_1[None, :] * 1, v_1, mask_0_1[None, :])
+
+def silu_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_loops.py:N]: out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+    out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    _BLOCK_SIZE_0_1 = 128
+    # src[test_loops.py:N]: for tile in hl.tile(out.size()):
+    # src[test_loops.py:N]: out[tile] = x[tile] * torch.sigmoid(x[tile])
+    _launcher(_helion_silu_kernel, (triton.cdiv(100, _BLOCK_SIZE_0_1), 1, 1), x, out, _BLOCK_SIZE_0_1, num_warps=32, num_stages=8)
+    # src[test_loops.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestLoops.test_full_with_dynamic_fill_value)
 from __future__ import annotations
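Aside (not part of the diff): the journal entry above is Helion's flattened codegen for SiLU over a (1, 100) float16 input. Note the numerics: the generated kernel upcasts to float32 for the sigmoid (the tl.cast pair around tl.sigmoid) and casts back to float16 before the multiply. A minimal reference sketch of that computation in plain PyTorch; silu_reference is a hypothetical name introduced here, not part of the diff:

    import torch

    def silu_reference(x: torch.Tensor) -> torch.Tensor:
        # Mirror the generated kernel: sigmoid in float32,
        # product in the input dtype (v_1 = load * v_0 above).
        return x * torch.sigmoid(x.to(torch.float32)).to(x.dtype)

    # The test compares against torch.sigmoid(x) * x at rtol/atol 1e-3.
    x = torch.randn((1, 100), dtype=torch.float16, device="cuda")
    torch.testing.assert_close(silu_reference(x), x * torch.sigmoid(x), rtol=1e-3, atol=1e-3)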

test/test_loops.py

Lines changed: 34 additions & 0 deletions
@@ -114,6 +114,40 @@ def test_3d_device_loop3(self):
         torch.testing.assert_close(result, torch.sin(args[0]))
         self.assertExpectedJournal(code)
 
+    @skipIfRefEager(
+        "Test is block size dependent which is not supported in ref eager mode"
+    )
+    def test_flattened_tile_with_unit_axis(self):
+        @helion.kernel(
+            config=helion.Config(
+                block_sizes=[1, 16],
+                flatten_loops=[True],
+                indexing=["block_ptr", "pointer", "pointer"],
+                l2_groupings=[1],
+                load_eviction_policies=["first"],
+                loop_orders=[[1, 0]],
+                num_stages=1,
+                num_warps=1,
+                pid_type="flat",
+                range_flattens=[None],
+                range_multi_buffers=[None],
+                range_num_stages=[0],
+                range_unroll_factors=[0],
+                range_warp_specializes=[],
+            ),
+            static_shapes=True,
+        )
+        def silu_kernel(x: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(x, dtype=x.dtype, device=x.device)
+            for tile in hl.tile(out.size()):
+                out[tile] = x[tile] * torch.sigmoid(x[tile])
+            return out
+
+        x = torch.randn((1, 100), dtype=torch.float16, device=DEVICE)
+        code, result = code_and_output(silu_kernel, (x,))
+        torch.testing.assert_close(result, torch.sigmoid(x) * x, rtol=1e-3, atol=1e-3)
+        self.assertExpectedJournal(code)
+
     @patch.object(_compat, "_supports_tensor_descriptor", lambda: False)
     def test_loop_fixed_block(self):
         @helion.kernel(config={"block_sizes": [], "indexing": "block_ptr"})
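Aside (not part of the diff): the unit leading axis is the point of this test. With block_sizes=[1, 16] and flatten_loops=[True], Helion collapses the two tile axes of the (1, 100) iteration space into a single flattened program id, visible in the journal above as the lone tl.program_id(0) driving offsets_0_1; the remaining config knobs are pinned so the generated code is deterministic for the journal. A minimal sketch of using such a kernel outside the test harness, assuming a CUDA device and an installed helion package; without a pinned Config, Helion is left to choose the configuration itself:

    import torch
    import helion
    import helion.language as hl

    @helion.kernel(static_shapes=True)  # config omitted: left to Helion's autotuner
    def silu(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # hl.tile yields tiles covering out's full index space;
        # the body is traced and compiled to a Triton kernel.
        for tile in hl.tile(out.size()):
            out[tile] = x[tile] * torch.sigmoid(x[tile])
        return out

    x = torch.randn((1, 100), dtype=torch.float16, device="cuda")
    torch.testing.assert_close(silu(x), x * torch.sigmoid(x), rtol=1e-3, atol=1e-3)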
