
Commit 11daf4a

use the float4 dtype in mxfp4 and nvfp4 tensors

Summary:

Uses the `torch.float4_e2m1fn_x2` dtype in mxfp4 and nvfp4 torchao tensors. Requires pytorch/pytorch#169595, so we need to wait for the next PyTorch branch cut.

Note: nvfp4 models in vllm currently hit an error (https://gist.github.com/vkuzo/e1407ee68c9ebb8d0f67478aedd81b96) with or without this PR, so there is some more debugging to do before landing this.

Test Plan:

```
CUDA_VISIBLE_DEVICES=6 time pytest test/prototype/mx_formats/ -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ee5d3a0
ghstack-comment-id: 3614417747
Pull-Request: #3440
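For context, a minimal sketch (not part of this commit) of the byte-level reinterpretation the change relies on; it assumes a PyTorch build that already ships `torch.float4_e2m1fn_x2` (i.e. one that includes pytorch/pytorch#169595):

```python
# Hypothetical illustration: the packed qdata bytes keep their layout, only the
# dtype metadata changes. Assumes torch.float4_e2m1fn_x2 exists in this build.
import torch

packed = torch.randint(0, 256, (4, 8), dtype=torch.uint8)  # two fp4 values per byte

fp4 = packed.view(torch.float4_e2m1fn_x2)  # reinterpret, no copy
assert fp4.shape == packed.shape
assert fp4.dtype == torch.float4_e2m1fn_x2

# Helpers that still operate on raw nibbles (e.g. unpack_uint4 in this diff)
# view back to uint8 first.
roundtrip = fp4.view(torch.uint8)
assert torch.equal(roundtrip, packed)
```

This is why the diff only swaps `.view(torch.uint8)` / `.view(torch.float4_e2m1fn_x2)` at the boundaries rather than changing any packing logic.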
1 parent 534bea5 commit 11daf4a

File tree: 5 files changed (+26, -9 lines changed)


test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -76,6 +76,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     prev_dims, K = data_hp.shape[:-1], data_hp.shape[-1]
     if elem_dtype is torch.float4_e2m1fn_x2:
         assert data_mx.qdata.shape == (*prev_dims, K // 2)
+        assert data_mx.qdata.dtype == torch.float4_e2m1fn_x2
     else:
         assert data_mx.qdata.shape == (*prev_dims, K)
     assert data_mx.scale.shape == (*prev_dims, K // block_size)
```

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -392,8 +392,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     )

     torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
-    pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
-    triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
+    pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
+    triton_unpacked = unpack_uint4(nvfp4_triton.qdata.view(torch.uint8))
     torch.testing.assert_close(
         pt_unpacked,
         triton_unpacked,
@@ -611,3 +611,17 @@ def test_3d_transpose(dims, is_swizzled_scales):
     x_hp_t = x_hp.transpose(dims[0], dims[1])
     x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
     assert x_hp_t.shape == x_nvfp4_t.shape
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+def test_uses_fp4_qdata(use_triton_kernel):
+    x_hp = torch.randn(2, 128, 256, device="cuda")
+    # TODO also test triton kernel
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(
+        x_hp, use_triton_kernel=use_triton_kernel, is_swizzled_scales=True
+    )
+    assert x_nvfp4.qdata.dtype == torch.float4_e2m1fn_x2
```

torchao/prototype/mx_formats/kernels.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1028,8 +1028,9 @@ def triton_quantize_nvfp4(
     # reshape back to original shape
     scales = scales.view(*orig_leading_dims, -1, padded_cols)
     xq = xq.view(*orig_leading_dims, -1, N // 2)
+    xq = xq.view(torch.float4_e2m1fn_x2)

-    return scales, xq.view(torch.uint8)
+    return scales, xq

 @triton_quantize_nvfp4.register_fake
 def _(x, per_tensor_scale=None):
@@ -1043,7 +1044,7 @@ def _(x, per_tensor_scale=None):
     scales = torch.empty(
         padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn
     )
-    xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8)
+    xq = torch.empty(M, N // 2, device=x.device, dtype=torch.float4_e2m1fn_x2)
     return scales, xq

 @triton_mx_block_rearrange.register_fake
```
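Side note on the `register_fake` change above: the fake ("meta") implementation is what `torch.compile` and fake-tensor tracing see for the custom op, so its output metadata has to agree with the real Triton kernel, which now returns `float4_e2m1fn_x2`. A hedged sketch of the metadata the fake now reports (illustrative only, not code from this commit; sizes are made up):

```python
# Illustrative only: a fake/meta output carries shape and dtype but no data,
# so its dtype must match what the real kernel returns after this change.
import torch

M, N = 128, 64  # example sizes, not taken from the kernel
fake_xq = torch.empty(M, N // 2, device="meta", dtype=torch.float4_e2m1fn_x2)
print(fake_xq.shape, fake_xq.dtype)  # torch.Size([128, 32]) torch.float4_e2m1fn_x2
```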

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -321,6 +321,7 @@ def to_mx(
         data_lp = data_lp.reshape(orig_shape)
         data_lp = f32_to_f4_unpacked(data_lp)
         data_lp = pack_uint4(data_lp)
+        data_lp = data_lp.view(torch.float4_e2m1fn_x2)
     else:
         raise AssertionError("unsupported")

@@ -382,7 +383,7 @@ def to_dtype(
         data_hp = data_hp.to(target_dtype).reshape(orig_shape)
     elif elem_dtype == torch.float4_e2m1fn_x2:
         # fp4
-        f4_unpacked = unpack_uint4(data_lp)
+        f4_unpacked = unpack_uint4(data_lp.view(torch.uint8))
         # for now we only have a cast to f32
         # TODO(future PR): add cast directly to bf16
         f32 = f4_unpacked_to_f32(f4_unpacked)
@@ -483,6 +484,7 @@ def __new__(
             torch.float8_e4m3fn,
             torch.float8_e5m2,
             torch.uint8,
+            torch.float4_e2m1fn_x2,
         ), "unsupported"
         self.qdata = qdata
         self.scale = scale_e8m0_bits
```
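To make the `pack_uint4(...)` → `.view(torch.float4_e2m1fn_x2)` and `unpack_uint4(data_lp.view(torch.uint8))` pairing above concrete, here is a simplified nibble pack/unpack in plain torch ops. It only illustrates the layout; it is not torchao's actual `pack_uint4`/`unpack_uint4` kernels, and the real nibble ordering may differ:

```python
import torch

def pack_nibbles(u4: torch.Tensor) -> torch.Tensor:
    # u4 holds one 4-bit value per uint8 element; last dim must be even
    lo, hi = u4[..., ::2], u4[..., 1::2]
    return ((hi << 4) | lo).to(torch.uint8)

def unpack_nibbles(packed_u8: torch.Tensor) -> torch.Tensor:
    lo = packed_u8 & 0x0F
    hi = (packed_u8 >> 4) & 0x0F
    return torch.stack([lo, hi], dim=-1).flatten(-2)

vals = torch.randint(0, 16, (2, 8), dtype=torch.uint8)   # unpacked fp4 bit patterns
packed = pack_nibbles(vals)                               # (2, 4) uint8
fp4 = packed.view(torch.float4_e2m1fn_x2)                 # what to_mx now stores as qdata
restored = unpack_nibbles(fp4.view(torch.uint8))          # what to_dtype does before casting
assert torch.equal(restored, vals)
```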

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -478,8 +478,8 @@ def _addmm_nvfp4_dispatch(
     # should_add_bias_separately = bias is not None

     result = torch._scaled_mm(
-        a.qdata.view(torch.float4_e2m1fn_x2),
-        b.qdata.view(torch.float4_e2m1fn_x2),
+        a.qdata,
+        b.qdata,
         a_scale_blocked.view(torch.float8_e4m3fn),
         b_scale_blocked.view(torch.float8_e4m3fn),
         bias=None if should_add_bias_separately else bias,
@@ -685,7 +685,6 @@ def nvfp4_quantize(
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
     data_lp = f32_to_f4_unpacked(data_scaled)
-    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
     data_lp = pack_uint4(data_lp)
+    data_lp = data_lp.view(torch.float4_e2m1fn_x2)
     return out_scales, data_lp
```
