Skip to content

Commit fe33e3b

Browse files
authored
Suggest use of @helion.kernel(index_dtype=torch.int64) if index offset is out of bound for int32 (#850)
1 parent 7efa2b0 commit fe33e3b

File tree

3 files changed

+123
-0
lines changed

3 files changed

+123
-0
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sympy
1010
import torch
1111
from torch._inductor.utils import triton_type
12+
from torch._prims_common import compute_required_storage_length
1213

1314
from .. import exc
1415
from .._compat import get_tensor_descriptor_fn_name
@@ -519,6 +520,31 @@ def compute_shape(
519520
assert len(input_size) == 0, "invalid subscript"
520521
return output_size
521522

523+
@staticmethod
def _needs_int64(fake_value: torch.Tensor) -> bool:
    """Report whether indexing into ``fake_value`` can exceed the int32 range.

    Only statically known metadata is consulted: the storage offset and the
    storage length required by the tensor's shape and strides, counted in
    elements.

    Args:
        fake_value: Tensor whose ``storage_offset()``, ``shape`` and
            ``stride()`` are inspected.

    Returns:
        True when the storage offset, or the largest element offset
        reachable through the shape/strides, is a concrete ``int`` greater
        than the int32 maximum. False otherwise, including when the values
        are not plain ints (e.g. symbolic) or the storage-length
        computation fails.
    """
    int32_max = torch.iinfo(torch.int32).max

    storage_offset = fake_value.storage_offset()
    if not isinstance(storage_offset, int):
        # Non-concrete offset: nothing can be proven statically.
        return False

    # Checked before the storage-length computation so that a known
    # out-of-range offset is never masked by that computation failing or
    # yielding a non-int result.
    if abs(storage_offset) > int32_max:
        return True

    try:
        required = compute_required_storage_length(
            fake_value.shape,
            fake_value.stride(),
            storage_offset,
        )
    except Exception:
        # Best-effort diagnostic: an inability to compute the length must
        # not turn into a compile failure of its own.
        return False

    if not isinstance(required, int):
        return False

    # ``required`` is a length, so the largest addressed offset is one less.
    return required - 1 > int32_max
547+
522548
@staticmethod
523549
def create(
524550
state: CodegenState,
@@ -533,6 +559,8 @@ def create(
533559
output_size = SubscriptIndexing.compute_shape(fake_value, index)
534560
env = CompileEnvironment.current()
535561
dtype = env.triton_index_type()
562+
if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
563+
raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)
536564

537565
def _is_size_one(size: int | torch.SymInt) -> bool:
538566
return env.known_equal(size, 1)

helion/exc.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ class InvalidIndexingType(BaseError):
107107
message = "Expected tile/int/None/tensor/etc in tensor[...], got {0!s}."
108108

109109

110+
# Raised by subscript-indexing codegen when the kernel's Triton index type is
# tl.int32 but a tensor's statically known offsets do not fit in int32.
class IndexOffsetOutOfRangeForInt32(BaseError):
    # The single positional slot {0} is filled with the kernel's configured
    # index_dtype setting at the raise site.
    message = (
        "Kernel index_dtype is {0}, "
        "but tensor indexing offsets exceed the int32 range. "
        "Use @helion.kernel(index_dtype=torch.int64) "
        "to enable larger offsets."
    )
115+
116+
110117
class DataDependentOutputShapeNotSupported(BaseError):
111118
message = (
112119
"{op_desc} is not supported in Helion device loops because it produces "

test/test_indexing.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from helion._testing import RefEagerTestBase
1212
from helion._testing import TestCase
1313
from helion._testing import code_and_output
14+
from helion._testing import skipIfLowVRAM
1415
from helion._testing import skipIfNormalMode
1516
from helion._testing import skipIfRefEager
1617
from helion._testing import skipIfRocm
@@ -241,6 +242,93 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor:
241242
expected = torch.full_like(x, 1, dtype=torch.int32)
242243
torch.testing.assert_close(result, expected)
243244

245+
@skipIfRefEager(
    "IndexOffsetOutOfRangeForInt32 error is not raised in ref eager mode"
)
@skipIfLowVRAM("Test requires high VRAM")
def test_int32_offset_out_of_range_error(self):
    # Covers three scenarios for an elementwise bf16 add kernel:
    #   1. small tensors + int32 indexing  -> compiles, no tl.int64 in code
    #   2. large tensors + int32 indexing  -> IndexOffsetOutOfRangeForInt32
    #   3. large tensors + int64 indexing  -> compiles, tl.int64 in code
    # A fixed config is used so the generated code is deterministic.
    pinned_config = helion.Config(
        block_sizes=[32, 32],
        flatten_loops=[False],
        indexing="pointer",
        l2_groupings=[1],
        loop_orders=[[0, 1]],
        num_stages=3,
        num_warps=4,
        pid_type="flat",
        range_flattens=[None],
        range_multi_buffers=[None],
        range_num_stages=[],
        range_unroll_factors=[0],
        range_warp_specializes=[],
    )

    def make_kernel(*, index_dtype: torch.dtype):
        # Build a fresh kernel per index dtype so each case compiles on its own.
        @helion.kernel(
            config=pinned_config,
            static_shapes=True,
            index_dtype=index_dtype,
        )
        def repro_bf16_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            x, y = torch.broadcast_tensors(x, y)
            out = torch.empty(
                x.shape,
                dtype=torch.promote_types(x.dtype, y.dtype),
                device=x.device,
            )
            for tile in hl.tile(out.size()):
                out[tile] = x[tile] + y[tile]
            return out

        return repro_bf16_add

    def run_case(
        shape, *, index_dtype, expect_int64_in_code=False, expect_error=False
    ):
        kernel = make_kernel(index_dtype=index_dtype)
        lhs = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
        rhs = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
        torch.cuda.synchronize()

        if expect_error:
            # The error message must name the configured index dtype.
            with self.assertRaisesRegex(
                helion.exc.IndexOffsetOutOfRangeForInt32,
                f"index_dtype is {index_dtype}",
            ):
                code_and_output(kernel, (lhs, rhs))
            torch.cuda.synchronize()
            return

        code, out = code_and_output(kernel, (lhs, rhs))
        torch.cuda.synchronize()
        if expect_int64_in_code:
            self.assertIn("tl.int64", code)
        else:
            self.assertNotIn("tl.int64", code)
        torch.cuda.synchronize()
        ref_out = torch.add(lhs, rhs)
        torch.cuda.synchronize()
        torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2)

    small_shape = (128, 128)
    # 51200 * 51200 elements is larger than the int32 maximum.
    large_shape = (51200, 51200)

    # (shape, index dtype, expect tl.int64 in generated code, expect error)
    scenarios = (
        (small_shape, torch.int32, False, False),
        (large_shape, torch.int32, False, True),
        (large_shape, torch.int64, True, False),
    )
    for shape, index_dtype, wants_int64, wants_error in scenarios:
        run_case(
            shape,
            index_dtype=index_dtype,
            expect_int64_in_code=wants_int64,
            expect_error=wants_error,
        )
331+
244332
def test_assign_int(self):
245333
@helion.kernel
246334
def fn(x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)