Skip to content

Commit 6ad53ad

Browse files
authored
Print error message for torch.chunk / torch.unbind to redirect users to hl.split (#921)
1 parent c8f83fb commit 6ad53ad

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

helion/_compiler/type_propagation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,10 @@ def propagate_call( # pyright: ignore[reportIncompatibleMethodOverride]
773773
) -> TypeInfo | None:
774774
if self.value in (torch.nonzero, torch.Tensor.nonzero) and origin.is_device():
775775
raise exc.DataDependentOutputShapeNotSupported(op_desc="torch.nonzero")
776+
if self.value in (torch.chunk, torch.Tensor.chunk) and origin.is_device():
777+
raise exc.UnsupportedSplitOperation(op="torch.chunk")
778+
if self.value in (torch.unbind, torch.Tensor.unbind) and origin.is_device():
779+
raise exc.UnsupportedSplitOperation(op="torch.unbind")
776780
if is_api_func(fn := self.value):
777781
if fn._is_device_only and origin.is_host():
778782
raise exc.DeviceAPIOnHost(fn.__qualname__)

helion/exc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,14 @@ class DataDependentOutputShapeNotSupported(BaseError):
114114
)
115115

116116

117+
class UnsupportedSplitOperation(BaseError):
118+
message = (
119+
"{op} is not supported in Helion device loops. "
120+
"For splitting the last dimension with size 2, use hl.split(). "
121+
"For other splitting operations, consider reshaping or using other hl.* operations."
122+
)
123+
124+
117125
class RequiresTensorInAssignment(BaseError):
118126
message = "Expected tensor in right-hand side of assignment, got {0!s}."
119127

test/test_errors.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,74 @@ def torch_nonzero_in_device_code(x: torch.Tensor) -> torch.Tensor:
293293
torch_nonzero_in_device_code, (torch.randn(2, 2, device=DEVICE),)
294294
)
295295

296+
def test_torch_chunk_device_error(self):
297+
"""Test that torch.chunk raises error in device loops and suggests hl.split()."""
298+
299+
@helion.kernel(use_default_config=True, static_shapes=True)
300+
def kernel_with_chunk(q: torch.Tensor) -> torch.Tensor:
301+
_, _, M, D = q.shape
302+
D = hl.specialize(D)
303+
M = hl.specialize(M)
304+
q = q.reshape(-1, D)
305+
total_rows = q.shape[0]
306+
block_m = hl.register_block_size(M)
307+
result = hl.zeros([total_rows, D])
308+
for tile_m in hl.tile(total_rows, block_size=block_m):
309+
acc = hl.zeros([tile_m, D])
310+
311+
for _tile_n in hl.tile(M, block_size=block_m):
312+
acc = torch.stack(torch.chunk(acc, 2, dim=-1), dim=-2).reshape(
313+
acc.shape
314+
)
315+
acc = acc + 0
316+
317+
result[tile_m, :] = acc
318+
319+
return result
320+
321+
with self.assertRaisesRegex(
322+
helion.exc.UnsupportedSplitOperation,
323+
r"torch\.chunk is not supported in Helion device loops.*hl\.split\(\)",
324+
):
325+
code_and_output(
326+
kernel_with_chunk,
327+
(torch.randn(1, 1, 128, 128, device=DEVICE, dtype=torch.bfloat16),),
328+
)
329+
330+
def test_torch_unbind_device_error(self):
331+
"""Test that torch.unbind raises error in device loops and suggests hl.split()."""
332+
333+
@helion.kernel(use_default_config=True, static_shapes=True)
334+
def kernel_with_unbind(q: torch.Tensor) -> torch.Tensor:
335+
_, _, M, D = q.shape
336+
D = hl.specialize(D)
337+
M = hl.specialize(M)
338+
q = q.reshape(-1, D)
339+
total_rows = q.shape[0]
340+
block_m = hl.register_block_size(M)
341+
result = hl.zeros([total_rows, D])
342+
for tile_m in hl.tile(total_rows, block_size=block_m):
343+
acc = hl.zeros([tile_m, D])
344+
345+
for _tile_n in hl.tile(M, block_size=block_m):
346+
reshaped = acc.reshape(tile_m, 2, D // 2)
347+
acc0, acc1 = torch.unbind(reshaped, dim=1)
348+
acc = torch.stack((acc0, acc1), dim=1).reshape(tile_m, D)
349+
acc = acc + 0
350+
351+
result[tile_m, :] = acc
352+
353+
return result
354+
355+
with self.assertRaisesRegex(
356+
helion.exc.UnsupportedSplitOperation,
357+
r"torch\.unbind is not supported in Helion device loops.*hl\.split\(\)",
358+
):
359+
code_and_output(
360+
kernel_with_unbind,
361+
(torch.randn(1, 1, 128, 128, device=DEVICE, dtype=torch.bfloat16),),
362+
)
363+
296364
def test_closure_fn(self):
297365
@helion.kernel()
298366
def bad_fn(x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)