@@ -96,14 +96,15 @@ def triton_fp8_gemm_1x128_128x128_kernel(
     tl.store(c_ptrs, c, mask=c_mask)


+@triton_op("torchao::triton_fp8_gemm_1x128_128x128", mutates_args={})
 def triton_fp8_gemm_1x128_128x128(
     a: torch.Tensor,  # (M, K)
     b: torch.Tensor,  # (K, N)
     a_s: torch.Tensor,  # (M, K // block_size)
     b_s: torch.Tensor,  # (K // block_size, N // block_size)
     block_size: int = 128,
     out_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     # 'a' must be in row-major layout, 'b' must be in column-major layout
     assert _is_row_major(a), "a must be row-major"
     assert _is_column_major(b), "b must be column-major"
@@ -214,14 +215,15 @@ def triton_fp8_gemm_1x128_128x1_kernel(
     tl.store(c_ptrs, c, mask=c_mask)


+@triton_op("torchao::triton_fp8_gemm_1x128_128x1", mutates_args={})
 def triton_fp8_gemm_1x128_128x1(
     a: torch.Tensor,  # (M, K)
     b: torch.Tensor,  # (K, N)
     a_s: torch.Tensor,  # (M, K // block_size) reciprocals of scales
     b_s: torch.Tensor,  # (K // block_size, N) reciprocals of scales
     block_size: int = 128,
     out_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     # 'a' must be in row-major layout, 'b' must be in column-major layout
     assert _is_row_major(a), "a must be row-major"
     assert _is_column_major(b), "b must be column-major"
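
With the `@triton_op` registration, each wrapper becomes a PyTorch custom op exposed under `torch.ops.torchao` (the namespace comes from the `"torchao::..."` name in the decorator), so it can also be called through the dispatcher and traced by `torch.compile`. Below is a minimal sketch of calling the registered 1x128/128x128 op; it is not part of this PR, and the shapes, CUDA device, fp8 dtype, and dummy unit scales are illustrative assumptions only.

```python
import torch

M, K, N, block_size = 256, 512, 1024, 128

# FP8 inputs: 'a' row-major (M, K), 'b' column-major (K, N)
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# Per-block scales; unit scales used here purely for illustration
a_s = torch.ones(M, K // block_size, device="cuda", dtype=torch.float32)
b_s = torch.ones(K // block_size, N // block_size, device="cuda", dtype=torch.float32)

# Call through the registered custom op; block_size and out_dtype fall back to their defaults
c = torch.ops.torchao.triton_fp8_gemm_1x128_128x128(a, b, a_s, b_s)
print(c.shape)  # (M, N)
```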