@@ -293,6 +293,74 @@ def torch_nonzero_in_device_code(x: torch.Tensor) -> torch.Tensor:
293293 torch_nonzero_in_device_code , (torch .randn (2 , 2 , device = DEVICE ),)
294294 )
295295
296+ def test_torch_chunk_device_error (self ):
297+ """Test that torch.chunk raises error in device loops and suggests hl.split()."""
298+
299+ @helion .kernel (use_default_config = True , static_shapes = True )
300+ def kernel_with_chunk (q : torch .Tensor ) -> torch .Tensor :
301+ _ , _ , M , D = q .shape
302+ D = hl .specialize (D )
303+ M = hl .specialize (M )
304+ q = q .reshape (- 1 , D )
305+ total_rows = q .shape [0 ]
306+ block_m = hl .register_block_size (M )
307+ result = hl .zeros ([total_rows , D ])
308+ for tile_m in hl .tile (total_rows , block_size = block_m ):
309+ acc = hl .zeros ([tile_m , D ])
310+
311+ for _tile_n in hl .tile (M , block_size = block_m ):
312+ acc = torch .stack (torch .chunk (acc , 2 , dim = - 1 ), dim = - 2 ).reshape (
313+ acc .shape
314+ )
315+ acc = acc + 0
316+
317+ result [tile_m , :] = acc
318+
319+ return result
320+
321+ with self .assertRaisesRegex (
322+ helion .exc .UnsupportedSplitOperation ,
323+ r"torch\.chunk is not supported in Helion device loops.*hl\.split\(\)" ,
324+ ):
325+ code_and_output (
326+ kernel_with_chunk ,
327+ (torch .randn (1 , 1 , 128 , 128 , device = DEVICE , dtype = torch .bfloat16 ),),
328+ )
329+
330+ def test_torch_unbind_device_error (self ):
331+ """Test that torch.unbind raises error in device loops and suggests hl.split()."""
332+
333+ @helion .kernel (use_default_config = True , static_shapes = True )
334+ def kernel_with_unbind (q : torch .Tensor ) -> torch .Tensor :
335+ _ , _ , M , D = q .shape
336+ D = hl .specialize (D )
337+ M = hl .specialize (M )
338+ q = q .reshape (- 1 , D )
339+ total_rows = q .shape [0 ]
340+ block_m = hl .register_block_size (M )
341+ result = hl .zeros ([total_rows , D ])
342+ for tile_m in hl .tile (total_rows , block_size = block_m ):
343+ acc = hl .zeros ([tile_m , D ])
344+
345+ for _tile_n in hl .tile (M , block_size = block_m ):
346+ reshaped = acc .reshape (tile_m , 2 , D // 2 )
347+ acc0 , acc1 = torch .unbind (reshaped , dim = 1 )
348+ acc = torch .stack ((acc0 , acc1 ), dim = 1 ).reshape (tile_m , D )
349+ acc = acc + 0
350+
351+ result [tile_m , :] = acc
352+
353+ return result
354+
355+ with self .assertRaisesRegex (
356+ helion .exc .UnsupportedSplitOperation ,
357+ r"torch\.unbind is not supported in Helion device loops.*hl\.split\(\)" ,
358+ ):
359+ code_and_output (
360+ kernel_with_unbind ,
361+ (torch .randn (1 , 1 , 128 , 128 , device = DEVICE , dtype = torch .bfloat16 ),),
362+ )
363+
296364 def test_closure_fn (self ):
297365 @helion .kernel ()
298366 def bad_fn (x : torch .Tensor ) -> torch .Tensor :
0 commit comments