62 changes: 34 additions & 28 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -22,11 +22,13 @@
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
get_current_accelerator_device,
is_sm_at_least_100,
torch_version_at_least,
)

torch.manual_seed(2)
_DEVICE = get_current_accelerator_device()

if not torch_version_at_least("2.8.0"):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -42,12 +44,12 @@
(torch.bfloat16, (1, 32, 64), False),
],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
x = torch.randn(shape, dtype=dtype, device="cuda")
x = torch.randn(shape, dtype=dtype, device=_DEVICE)
if use_per_tensor_scale:
tensor_amax = torch.max(torch.abs(x))
scale = per_tensor_amax_to_scale(tensor_amax)
@@ -113,14 +115,14 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
"""
Test that NVFP4Tensor can be constructed with swizzled scales and
that the _is_swizzled_scales flag is set correctly.
"""

data = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)
data = torch.randn(*shape, device=_DEVICE, dtype=torch.bfloat16)

tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
assert tensor._is_swizzled_scales == is_swizzled_scales
@@ -146,7 +148,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@@ -164,7 +166,7 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
# For column slicing, need multiples of 64 columns for alignment
M, K = 128, 4096

data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)

tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
assert tensor._is_swizzled_scales == True
@@ -240,7 +242,7 @@ def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
),
],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@@ -250,7 +252,7 @@ def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_er
"""

M, K = 256, 4096
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)

with pytest.raises(RuntimeError, match=expected_error):
@@ -260,7 +262,7 @@
_ = tensor[:, slice_spec]


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@@ -270,7 +272,7 @@ def test_nvfp4_swizzled_scales_view_semantics():
"""

M, K = 256, 4096
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)

# Test row slicing (should maintain views)
@@ -286,7 +288,7 @@ def test_nvfp4_swizzled_scales_view_semantics():
assert full_width_slice.qdata.data_ptr() == tensor.qdata.data_ptr()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@@ -296,7 +298,7 @@ def test_nvfp4_swizzled_scales_serialization():
"""

M, K = 32, 64
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)

# Create tensor with swizzled scales
original_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
@@ -327,7 +329,7 @@ def test_nvfp4_swizzled_scales_serialization():
torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@@ -337,7 +339,7 @@ def test_nvfp4_swizzled_scales_get_scales_method():
"""

M, K = 32, 64
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
data = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)

# Create tensors with both storage methods
regular_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=False)
@@ -371,7 +373,7 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
"""Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""

torch.manual_seed(42)
x = torch.randn(M, N, dtype=dtype, device="cuda")
x = torch.randn(M, N, dtype=dtype, device=_DEVICE)

per_tensor_scale = None
if use_per_tensor_scale:
@@ -413,7 +415,7 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
)
@@ -454,7 +456,11 @@ def test_nvfp4_matmul_with_amax(
shapes: tuple,
):
# DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
if quant_type == "dynamic" and not is_sm_at_least_100():
if (
quant_type == "dynamic"
and torch.cuda.is_available()
and not is_sm_at_least_100()
):
pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")

if bias and inpt_dtype == torch.float32:
@@ -467,13 +473,13 @@

# Create activation tensor
if use_gelu:
x = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
x = torch.randn(m, k, dtype=inpt_dtype, device=_DEVICE)
A = torch.nn.functional.gelu(x)
else:
A = torch.randn(m, k, dtype=inpt_dtype, device="cuda")
A = torch.randn(m, k, dtype=inpt_dtype, device=_DEVICE)

B = torch.randn(n, k, dtype=inpt_dtype, device="cuda")
bias_tensor = torch.randn(n, dtype=inpt_dtype, device="cuda") if bias else None
B = torch.randn(n, k, dtype=inpt_dtype, device=_DEVICE)
bias_tensor = torch.randn(n, dtype=inpt_dtype, device=_DEVICE) if bias else None

# Compute reference
C_ref = F.linear(A, B, bias_tensor)
@@ -511,12 +517,12 @@ def test_nvfp4_matmul_with_amax(
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
def test_nvfp4_to_copy():
x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).to(_DEVICE)
y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
assert torch.equal(x.qdata, y.qdata)
assert torch.equal(x.scale, y.scale)
@@ -531,7 +537,7 @@ def test_nvfp4_to_copy():
assert y.dtype == torch.bfloat16


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@@ -551,14 +557,14 @@ def test_nvfp4_to_copy():
def test_scale_shape_matches_qdata(
transpose, use_triton_kernel, is_swizzled_scales, shape
):
if use_triton_kernel and not is_sm_at_least_100():
if use_triton_kernel and torch.cuda.is_available() and not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
if use_triton_kernel and not is_swizzled_scales:
pytest.skip("triton kernel requires swizzled scales")

block_size = 16

x_hp = torch.randn(*shape, device="cuda")
x_hp = torch.randn(*shape, device=_DEVICE)
x = NVFP4Tensor.to_nvfp4(
x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
)
@@ -599,14 +605,14 @@ def test_scale_shape_matches_qdata(
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@pytest.mark.parametrize("dims", ((1, 2), (2, 1), (-1, -2), (-2, -1)))
@pytest.mark.parametrize("is_swizzled_scales", [True, False])
def test_3d_transpose(dims, is_swizzled_scales):
x_hp = torch.randn(2, 128, 256, device="cuda")
x_hp = torch.randn(2, 128, 256, device=_DEVICE)
x_nvfp4 = NVFP4Tensor.to_nvfp4(x_hp, is_swizzled_scales=is_swizzled_scales)
x_hp_t = x_hp.transpose(dims[0], dims[1])
x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
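The change in this file is the same everywhere: hard-coded `device="cuda"` tensors and `torch.cuda.is_available()` skips are replaced with a module-level `_DEVICE` from `get_current_accelerator_device()` and `torch.accelerator.is_available()`, so the tests can run on non-CUDA accelerators such as XPU, while SM-version guards (which are CUDA-specific) gain an explicit `torch.cuda.is_available()` check. Below is a minimal sketch of the pattern; the helper body is an illustrative assumption, not the actual `torchao.utils` implementation.

```python
# Sketch only: illustrates the device-agnostic pattern used in this diff.
# The real torchao.utils.get_current_accelerator_device may differ.
import pytest
import torch


def get_current_accelerator_device() -> torch.device:
    # torch.accelerator (PyTorch 2.6+) abstracts over CUDA, XPU, MPS, ...
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")


_DEVICE = get_current_accelerator_device()


@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
def test_runs_on_any_accelerator():
    # Capability checks like is_sm_at_least_100() remain CUDA-only, hence the
    # extra torch.cuda.is_available() guard before them in the tests above.
    x = torch.randn(32, 64, dtype=torch.bfloat16, device=_DEVICE)
    assert x.device.type == _DEVICE.type
```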
6 changes: 5 additions & 1 deletion test/prototype/test_spinquant.py
@@ -16,7 +16,11 @@ def _init_model(name="7B", device="cpu", precision=torch.bfloat16):
return model.eval()


_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_AVAILABLE_DEVICES = (
["cpu"]
+ (["cuda"] if torch.cuda.is_available() else [])
+ (["xpu"] if torch.xpu.is_available() else [])
)


@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
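The spinquant test takes a slightly different route: instead of a single `_DEVICE`, it builds a list of available device strings and parametrizes over it. A self-contained sketch of that pattern follows; the test body is hypothetical and only shows how the list feeds `pytest.mark.parametrize`.

```python
# Illustrative sketch of device-list parametrization; the test body is not from the PR.
import pytest
import torch

_AVAILABLE_DEVICES = (
    ["cpu"]
    + (["cuda"] if torch.cuda.is_available() else [])
    + (["xpu"] if torch.xpu.is_available() else [])
)


@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
def test_runs_on_each_available_device(device):
    # Each collected device string becomes one test case.
    x = torch.ones(4, device=device)
    assert x.sum().item() == 4.0
```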