
Commit f3ebaf9

FP8 Blockwise Training: triton_op for dense model (#3402)
Added a triton_op decorator to two GEMM kernels, registering them as PyTorch custom ops (a minimal example of the pattern follows the diff below).

1 parent: 5a7588e

File tree

  • torchao/prototype/blockwise_fp8_training/kernels.py

1 file changed: 4 additions, 2 deletions

torchao/prototype/blockwise_fp8_training/kernels.py

Lines changed: 4 additions & 2 deletions

@@ -96,14 +96,15 @@ def triton_fp8_gemm_1x128_128x128_kernel(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@triton_op("torchao::triton_fp8_gemm_1x128_128x128", mutates_args={})
 def triton_fp8_gemm_1x128_128x128(
     a: torch.Tensor,  # (M, K)
     b: torch.Tensor,  # (K, N)
     a_s: torch.Tensor,  # (M, K // block_size)
     b_s: torch.Tensor,  # (K // block_size, N // block_size)
     block_size: int = 128,
     out_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     # 'a' must be in row-major layout, 'b' must be in column-major layout
     assert _is_row_major(a), "a must be row-major"
     assert _is_column_major(b), "b must be column-major"

@@ -214,14 +215,15 @@ def triton_fp8_gemm_1x128_128x1_kernel(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@triton_op("torchao::triton_fp8_gemm_1x128_128x1", mutates_args={})
 def triton_fp8_gemm_1x128_128x1(
     a: torch.Tensor,  # (M, K)
     b: torch.Tensor,  # (K, N)
     a_s: torch.Tensor,  # (M, K // block_size) reciprocals of scales
     b_s: torch.Tensor,  # (K // block_size, N) reciprocals of scales
     block_size: int = 128,
     out_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     # 'a' must be in row-major layout, 'b' must be in column-major layout
     assert _is_row_major(a), "a must be row-major"
     assert _is_column_major(b), "b must be column-major"
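For context: torch.library.triton_op registers a Python wrapper around a Triton kernel as a PyTorch custom op, which also explains the second change in this diff: the op's schema is inferred from type annotations, so the wrapper now needs an explicit -> torch.Tensor return annotation. The sketch below shows the same pattern on a toy elementwise kernel; it assumes PyTorch 2.6+ (where triton_op and wrap_triton are available in torch.library), and the names mylib::add and add_kernel are illustrative, not part of this commit.

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


# mutates_args={} declares the op functional (it mutates none of its
# inputs), the same declaration the two GEMM wrappers in this commit use.
# Full type annotations, including the return type, let PyTorch infer the
# op schema.
@triton_op("mylib::add", mutates_args={})
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # Launch through wrap_triton so torch.compile can trace into the
    # kernel instead of graph-breaking on a raw Triton launch.
    wrap_triton(add_kernel)[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out

Once registered this way, the op composes with the rest of PyTorch, e.g. torch.compile(add)(x, y) on CUDA tensors runs without a graph break at the kernel boundary, which is presumably the motivation for wrapping the blockwise FP8 GEMMs here.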
