2828@helion .kernel ( 
2929    # static_shapes=True gives a performance boost for matmuls  
3030    static_shapes = True , 
31+     config = helion .Config ( 
32+         block_sizes = [64 , 64 , 64 ], 
33+         loop_orders = [[0 , 1 ]], 
34+         l2_groupings = [4 ], 
35+         range_unroll_factors = [0 , 1 ], 
36+         range_num_stages = [0 , 3 ], 
37+         range_multi_buffers = [None , False ], 
38+         range_flattens = [None , None ], 
39+         num_warps = 8 , 
40+         num_stages = 6 , 
41+         indexing = 'tensor_descriptor' , 
42+         pid_type = 'flat'  
43+     ) 
3144) 
3245def  matmul (
3346    x : Tensor ,
@@ -44,17 +57,22 @@ def matmul(
4457    Returns: 
4558        Tensor: Resulting matrix of shape [m, n]. 
4659    """ 
60+ 
4761    m , k  =  x .size ()
4862    k2 , n  =  y .size ()
4963    assert  k  ==  k2 , f"size mismatch { k } { k2 }  
5064    out  =  torch .empty (
5165        [m , n ], dtype = torch .promote_types (x .dtype , y .dtype ), device = x .device 
5266    )
53-     for  tile_m , tile_n  in  hl .tile ([m , n ]):
67+     block_m  =  hl .register_block_size (m )
68+     block_n  =  hl .register_block_size (n )
69+     for  tile_m , tile_n  in  hl .tile ([m , n ], block_size = [block_m , block_n ]):
5470        acc  =  hl .zeros ([tile_m , tile_n ], dtype = torch .float32 )
5571        for  tile_k  in  hl .tile (k ):
5672            acc  =  torch .addmm (acc , x [tile_m , tile_k ], y [tile_k , tile_n ])
57-         out [tile_m , tile_n ] =  epilogue (acc , (tile_m , tile_n ))
73+ 
74+         acc  =  epilogue (acc , (tile_m , tile_n ))
75+         out [tile_m , tile_n ] =  acc 
5876    return  out 
5977
6078
@@ -298,97 +316,97 @@ def check(m: int, k: int, n: int) -> None:
298316    # Test without bias 
299317    run_example (matmul , torch .matmul , (x , y ))
300318
301-     # Test for addmm with scalar bias 
302-     def  addmm (bias : Tensor , mat1 : Tensor , mat2 : Tensor ) ->  Tensor :
303-         m , k  =  mat1 .size ()
304-         k2 , n  =  mat2 .size ()
305-         bias  =  torch .broadcast_to (bias , [m , n ])
306-         return  matmul (mat1 , mat2 , lambda  acc , tile : acc  +  bias [tile [0 ], tile [1 ]])
307- 
308-     run_example (addmm , torch .addmm , (bias_scalar , x , y ))
309- 
310-     # Test with bias 
311-     def  helion_linear (x : Tensor , y : Tensor , bias : Tensor ) ->  Tensor :
312-         return  matmul (x , y , lambda  acc , tile : acc  +  bias [tile [1 ]])
313- 
314-     def  baseline_linear (x : Tensor , y : Tensor , bias : Tensor ) ->  Tensor :
315-         return  torch .nn .functional .linear (x , y .T , bias )
316- 
317-     run_example (helion_linear , baseline_linear , (x , y , bias ))
318- 
319-     # Test more complex epilogue 
320-     def  epilogue (acc : Tensor , tile : tuple [Tensor , ...]) ->  Tensor :
321-         # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg 
322-         return  torch .relu (acc  +  bias [tile [1 ]])
323- 
324-     def  kernel_wrapper (x : Tensor , y : Tensor ) ->  Tensor :
325-         return  matmul (x , y , epilogue )
326- 
327-     def  baseline_wrapper (x : Tensor , y : Tensor ) ->  Tensor :
328-         return  torch .relu (x  @ y  +  bias )
329- 
330-     run_example (
331-         kernel_wrapper ,
332-         baseline_wrapper ,
333-         (x , y ),
334-     )
335- 
336-     # Test matmul forward + backward pass 
337-     print ("\n \n === MatMul Forward + Backward Pass Test ===" )
338-     x_grad  =  torch .randn ([m , k ], device = DEVICE , dtype = torch .float16 , requires_grad = True )
339-     y_grad  =  torch .randn ([k , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True )
340- 
341-     run_example (
342-         matmul_autograd ,
343-         torch .matmul ,
344-         (x_grad , y_grad ),
345-         kernel_name = "helion_matmul_autograd" ,
346-         baseline_name = "torch" ,
347-         rtol = 1e-2 ,
348-         atol = 1e-2 ,
349-         bwd = True ,
350-     )
351- 
352-     # Test addmm forward + backward pass 
353-     print ("\n \n === AddMM Forward + Backward Pass Test ===" )
354-     input_grad  =  torch .randn (
355-         [m , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True 
356-     )
357-     mat1_grad  =  torch .randn (
358-         [m , k ], device = DEVICE , dtype = torch .float16 , requires_grad = True 
359-     )
360-     mat2_grad  =  torch .randn (
361-         [k , n ], device = DEVICE , dtype = torch .float16 , requires_grad = True 
362-     )
363- 
364-     # Use lambda to handle the keyword argument format for torch.addmm 
365-     run_example (
366-         addmm_autograd ,
367-         lambda  bias , mat1 , mat2 , alpha , beta : torch .addmm (
368-             bias , mat1 , mat2 , alpha = alpha , beta = beta 
369-         ),
370-         (input_grad , mat1_grad , mat2_grad , 1.0 , 1.0 ),
371-         kernel_name = "helion_addmm_autograd" ,
372-         baseline_name = "torch" ,
373-         rtol = 1e-2 ,
374-         atol = 1e-2 ,
375-         bwd = True ,
376-     )
377- 
378-     # Test addmm forward + backward with different alpha/beta values 
379-     print ("\n \n === AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===" )
380-     run_example (
381-         addmm_autograd ,
382-         lambda  bias , mat1 , mat2 , alpha , beta : torch .addmm (
383-             bias , mat1 , mat2 , alpha = alpha , beta = beta 
384-         ),
385-         (input_grad , mat1_grad , mat2_grad , 2.0 , 0.5 ),
386-         kernel_name = "helion_addmm_autograd_scaled" ,
387-         baseline_name = "torch" ,
388-         rtol = 1e-2 ,
389-         atol = 1e-2 ,
390-         bwd = True ,
391-     )
319+     # #  Test for addmm with scalar bias 
320+     #  def addmm(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Tensor:
321+     #      m, k = mat1.size()
322+     #      k2, n = mat2.size()
323+     #      bias = torch.broadcast_to(bias, [m, n])
324+     #      return matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
325+ 
326+     #  run_example(addmm, torch.addmm, (bias_scalar, x, y))
327+ 
328+     # #  Test with bias 
329+     #  def helion_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
330+     #      return matmul(x, y, lambda acc, tile: acc + bias[tile[1]])
331+ 
332+     #  def baseline_linear(x: Tensor, y: Tensor, bias: Tensor) -> Tensor:
333+     #      return torch.nn.functional.linear(x, y.T, bias)
334+ 
335+     #  run_example(helion_linear, baseline_linear, (x, y, bias))
336+ 
337+     # #  Test more complex epilogue 
338+     #  def epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
339+     #      # The epilogue can use the captured bias tensor that is implicitly lifted to a kernel arg
340+     #      return torch.relu(acc + bias[tile[1]])
341+ 
342+     #  def kernel_wrapper(x: Tensor, y: Tensor) -> Tensor:
343+     #      return matmul(x, y, epilogue)
344+ 
345+     #  def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
346+     #      return torch.relu(x @ y + bias)
347+ 
348+     #  run_example(
349+     #      kernel_wrapper,
350+     #      baseline_wrapper,
351+     #      (x, y),
352+     #  )
353+ 
354+     # #  Test matmul forward + backward pass 
355+     #  print("\n\n=== MatMul Forward + Backward Pass Test ===")
356+     #  x_grad = torch.randn([m, k], device=DEVICE, dtype=torch.float16, requires_grad=True)
357+     #  y_grad = torch.randn([k, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
358+ 
359+     #  run_example(
360+     #      matmul_autograd,
361+     #      torch.matmul,
362+     #      (x_grad, y_grad),
363+     #      kernel_name="helion_matmul_autograd",
364+     #      baseline_name="torch",
365+     #      rtol=1e-2,
366+     #      atol=1e-2,
367+     #      bwd=True,
368+     #  )
369+ 
370+     # #  Test addmm forward + backward pass 
371+     #  print("\n\n=== AddMM Forward + Backward Pass Test ===")
372+     #  input_grad = torch.randn(
373+     #      [m, n], device=DEVICE, dtype=torch.float16, requires_grad=True
374+     #  )
375+     #  mat1_grad = torch.randn(
376+     #      [m, k], device=DEVICE, dtype=torch.float16, requires_grad=True
377+     #  )
378+     #  mat2_grad = torch.randn(
379+     #      [k, n], device=DEVICE, dtype=torch.float16, requires_grad=True
380+     #  )
381+ 
382+     # #  Use lambda to handle the keyword argument format for torch.addmm 
383+     #  run_example(
384+     #      addmm_autograd,
385+     #      lambda bias, mat1, mat2, alpha, beta: torch.addmm(
386+     #          bias, mat1, mat2, alpha=alpha, beta=beta
387+     #      ),
388+     #      (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
389+     #      kernel_name="helion_addmm_autograd",
390+     #      baseline_name="torch",
391+     #      rtol=1e-2,
392+     #      atol=1e-2,
393+     #      bwd=True,
394+     #  )
395+ 
396+     # #  Test addmm forward + backward with different alpha/beta values 
397+     #  print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
398+     #  run_example(
399+     #      addmm_autograd,
400+     #      lambda bias, mat1, mat2, alpha, beta: torch.addmm(
401+     #          bias, mat1, mat2, alpha=alpha, beta=beta
402+     #      ),
403+     #      (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
404+     #      kernel_name="helion_addmm_autograd_scaled",
405+     #      baseline_name="torch",
406+     #      rtol=1e-2,
407+     #      atol=1e-2,
408+     #      bwd=True,
409+     #  )
392410
393411
394412# %% 
0 commit comments