Skip to content

Commit a6dbf45

Browse files
[mxfp8 moe training] fix bug introduced in #3385 (#3417)
1 parent 3c3515a commit a6dbf45

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchao/prototype/mx_formats/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _to_mxfp8_dim1_kernel_wrapper(
151151
block_size,
152152
elem_dtype,
153153
hp_dtype,
154-
gemm_kernel_choice,
154+
kernel_preference,
155155
cast_kernel_choice,
156156
scale_calculation_mode: ScaleCalculationMode,
157157
):
@@ -187,7 +187,7 @@ def _to_mxfp8_dim1_kernel_wrapper(
187187
elem_dtype,
188188
block_size,
189189
hp_dtype,
190-
gemm_kernel_choice,
190+
kernel_preference,
191191
None,
192192
is_swizzled_scales,
193193
)
@@ -206,7 +206,7 @@ def _to_mxfp8_dim1_kernel_wrapper(
206206
elem_dtype,
207207
block_size,
208208
hp_dtype,
209-
gemm_kernel_choice,
209+
kernel_preference,
210210
None,
211211
is_swizzled_scales,
212212
)

0 commit comments

Comments
 (0)