2828@helion .kernel (
2929 # static_shapes=True gives a performance boost for matmuls
3030 static_shapes = True ,
31+ config = helion .Config (
32+ block_sizes = [64 , 64 , 64 ],
33+ loop_orders = [[0 , 1 ]],
34+ l2_groupings = [4 ],
35+ range_unroll_factors = [0 , 1 ],
36+ range_num_stages = [0 , 3 ],
37+ range_multi_buffers = [None , False ],
38+ range_flattens = [None , None ],
39+ num_warps = 8 ,
40+ num_stages = 6 ,
41+ indexing = 'tensor_descriptor' ,
42+ pid_type = 'flat'
43+ )
3144)
3245def matmul (
3346 x : Tensor ,
@@ -44,6 +57,7 @@ def matmul(
4457 Returns:
4558 Tensor: Resulting matrix of shape [m, n].
4659 """
60+
4761 m , k = x .size ()
4862 k2 , n = y .size ()
4963 assert k == k2 , f"size mismatch { k } != { k2 } "
@@ -298,97 +312,97 @@ def check(m: int, k: int, n: int) -> None:
298312 # Test without bias
299313 run_example (matmul , torch .matmul , (x , y ))
300314
301- # Test for addmm with scalar bias
302- def addmm (bias : Tensor , mat1 : Tensor , mat2 : Tensor ) -> Tensor :
303- m , k = mat1 .size ()
304- k2 , n = mat2 .size ()
305- bias = torch .broadcast_to (bias , [m , n ])
306- return matmul (mat1 , mat2 , lambda acc , tile : acc + bias [tile [0 ], tile [1 ]])
307-
308- run_example (addmm , torch .addmm , (bias_scalar , x , y ))
309-
310- # Test with bias
311- def helion_linear (x : Tensor , y : Tensor , bias : Tensor ) -> Tensor :
312- return matmul (x , y , lambda acc , tile : acc + bias [tile [1 ]])
313-
314- def baseline_linear (x : Tensor , y : Tensor , bias : Tensor ) -> Tensor :
315- return torch .nn .functional .linear (x , y .T , bias )
316-
317- run_example (helion_linear , baseline_linear , (x , y , bias ))
318-
319- # Test more complex epilogue
320- def epilogue (acc : Tensor , tile : tuple [Tensor , ...]) -> Tensor :
321- # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
322- return torch .relu (acc + bias [tile [1 ]])
323-
324- def kernel_wrapper (x : Tensor , y : Tensor ) -> Tensor :
325- return matmul (x , y , epilogue )
326-
327- def baseline_wrapper (x : Tensor , y : Tensor ) -> Tensor :
328- return torch .relu (x @ y + bias )
329-
330- run_example (
331- kernel_wrapper ,
332- baseline_wrapper ,
333- (x , y ),
334- )
335-
336- # Test matmul forward + backward pass
337- print ("\n \n === MatMul Forward + Backward Pass Test ===" )
338- x_grad = torch .randn ([m , k ], device = DEVICE , dtype = torch .float16 , requires_grad = True )
339- y_grad = torch .randn ([k , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True )
340-
341- run_example (
342- matmul_autograd ,
343- torch .matmul ,
344- (x_grad , y_grad ),
345- kernel_name = "helion_matmul_autograd" ,
346- baseline_name = "torch" ,
347- rtol = 1e-2 ,
348- atol = 1e-2 ,
349- bwd = True ,
350- )
351-
352- # Test addmm forward + backward pass
353- print ("\n \n === AddMM Forward + Backward Pass Test ===" )
354- input_grad = torch .randn (
355- [m , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True
356- )
357- mat1_grad = torch .randn (
358- [m , k ], device = DEVICE , dtype = torch .float16 , requires_grad = True
359- )
360- mat2_grad = torch .randn (
361- [k , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True
362- )
363-
364- # Use lambda to handle the keyword argument format for torch.addmm
365- run_example (
366- addmm_autograd ,
367- lambda bias , mat1 , mat2 , alpha , beta : torch .addmm (
368- bias , mat1 , mat2 , alpha = alpha , beta = beta
369- ),
370- (input_grad , mat1_grad , mat2_grad , 1.0 , 1.0 ),
371- kernel_name = "helion_addmm_autograd" ,
372- baseline_name = "torch" ,
373- rtol = 1e-2 ,
374- atol = 1e-2 ,
375- bwd = True ,
376- )
377-
378- # Test addmm forward + backward with different alpha/beta values
379- print ("\n \n === AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===" )
380- run_example (
381- addmm_autograd ,
382- lambda bias , mat1 , mat2 , alpha , beta : torch .addmm (
383- bias , mat1 , mat2 , alpha = alpha , beta = beta
384- ),
385- (input_grad , mat1_grad , mat2_grad , 2.0 , 0.5 ),
386- kernel_name = "helion_addmm_autograd_scaled" ,
387- baseline_name = "torch" ,
388- rtol = 1e-2 ,
389- atol = 1e-2 ,
390- bwd = True ,
391- )
315+ # # Test for addmm with scalar bias
316+ # def addmm(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
317+ # m, k = mat1.size()
318+ # k2, n = mat2.size()
319+ # bias = torch.broadcast_to(bias, [m, n])
320+ # return matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
321+
322+ # run_example(addmm, torch.addmm, (bias_scalar, x, y))
323+
324+ # # Test with bias
325+ # def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
326+ # return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
327+
328+ # def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
329+ # return torch.nn.functional.linear(x, y.T, bias)
330+
331+ # run_example(helion_linear, baseline_linear, (x, y, bias))
332+
333+ # # Test more complex epilogue
334+ # def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
335+ # # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
336+ # return torch.relu(acc + bias[tile[1]])
337+
338+ # def kernel_wrapper(x: Tensor, y: Tensor) -> Tensor:
339+ # return matmul(x, y, epilogue)
340+
341+ # def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
342+ # return torch.relu(x @ y + bias)
343+
344+ # run_example(
345+ # kernel_wrapper,
346+ # baseline_wrapper,
347+ # (x, y),
348+ # )
349+
350+ # # Test matmul forward + backward pass
351+ # print("\n\n=== MatMul Forward + Backward Pass Test ===")
352+ # x_grad = torch.randn([m, k], device=DEVICE, dtype=torch.float16, requires_grad=True)
353+ # y_grad = torch.randn([k, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
354+
355+ # run_example(
356+ # matmul_autograd,
357+ # torch.matmul,
358+ # (x_grad, y_grad),
359+ # kernel_name="helion_matmul_autograd",
360+ # baseline_name="torch",
361+ # rtol=1e-2,
362+ # atol=1e-2,
363+ # bwd=True,
364+ # )
365+
366+ # # Test addmm forward + backward pass
367+ # print("\n\n=== AddMM Forward + Backward Pass Test ===")
368+ # input_grad = torch.randn(
369+ # [m, n], device=DEVICE, dtype=torch.float16, requires_grad=True
370+ # )
371+ # mat1_grad = torch.randn(
372+ # [m, k], device=DEVICE, dtype=torch.float16, requires_grad=True
373+ # )
374+ # mat2_grad = torch.randn(
375+ # [k, n], device=DEVICE, dtype=torch.float16, requires_grad=True
376+ # )
377+
378+ # # Use lambda to handle the keyword argument format for torch.addmm
379+ # run_example(
380+ # addmm_autograd,
381+ # lambda bias, mat1, mat2, alpha, beta: torch.addmm(
382+ # bias, mat1, mat2, alpha=alpha, beta=beta
383+ # ),
384+ # (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
385+ # kernel_name="helion_addmm_autograd",
386+ # baseline_name="torch",
387+ # rtol=1e-2,
388+ # atol=1e-2,
389+ # bwd=True,
390+ # )
391+
392+ # # Test addmm forward + backward with different alpha/beta values
393+ # print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
394+ # run_example(
395+ # addmm_autograd,
396+ # lambda bias, mat1, mat2, alpha, beta: torch.addmm(
397+ # bias, mat1, mat2, alpha=alpha, beta=beta
398+ # ),
399+ # (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
400+ # kernel_name="helion_addmm_autograd_scaled",
401+ # baseline_name="torch",
402+ # rtol=1e-2,
403+ # atol=1e-2,
404+ # bwd=True,
405+ # )
392406
393407
394408# %%
0 commit comments