From a80d78ce9436cec7a97fc30f534e0e774391a157 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Mon, 1 Dec 2025 15:37:23 +0800 Subject: [PATCH 1/6] test/prototype/mx_formats/test_mx_tensor.py --- test/prototype/mx_formats/test_mx_tensor.py | 98 +++++++++++---------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 2b8c72ff91..716f8c0fc1 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -30,9 +30,11 @@ is_sm_at_least_89, is_sm_at_least_90, torch_version_at_least, + get_current_accelerator_device, ) torch.manual_seed(2) +_DEVICE = get_current_accelerator_device() if not torch_version_at_least("2.8.0"): pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -81,42 +83,42 @@ def assert_sqnr_gt_threshold(orig, new, threshold): assert data_mx.scale.shape == (*prev_dims, K // block_size) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_hello_world(elem_dtype): - data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16) + data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16) block_size = 4 _test_mx(data, elem_dtype, block_size) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode]) @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_realistic_numerics(elem_dtype, scale_calculation_mode): - data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + data = torch.randn(128, 128, device=_DEVICE, dtype=torch.bfloat16) block_size = 32 _test_mx(data, elem_dtype, block_size, scale_calculation_mode) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_all_zeros(elem_dtype): - data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16) + data = torch.zeros(4, 4, device=_DEVICE, dtype=torch.bfloat16) block_size = 4 _test_mx(data, elem_dtype, block_size) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_some_zeros(elem_dtype): - data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16) + data = torch.randn(4, 4, device=_DEVICE, dtype=torch.bfloat16) data[0, :] = 0.0 data[:, 2] = 0.0 block_size = 4 _test_mx(data, elem_dtype, block_size) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") def test_to_mx_rceil(): # nan # fmt: off @@ -325,7 +327,7 @@ def test_to_mx_rceil(): torch.testing.assert_close(data_mx.qdata, ground_truth_fp8) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_exponent_nan_in(elem_dtype): """ @@ -333,7 +335,7 
@@ def test_exponent_nan_in(elem_dtype): value is set to is NaN """ tensor_hp = torch.tensor( - [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16 + [float("nan"), 1, 2, 3, 4, 5, 6, 7], device=_DEVICE, dtype=torch.bfloat16 ) block_size = 4 tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) @@ -341,29 +343,29 @@ def test_exponent_nan_in(elem_dtype): assert not torch.any(torch.isnan(tensor_mx.scale[1:])) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_exponent_nan_out(elem_dtype): """ If block exponent value is NaN, the MX tensor block value is NaN """ scale_e8m0 = torch.tensor( - [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda" + [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device=_DEVICE ) block_size = 4 if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): data_bits = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda" + [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device=_DEVICE ) # noqa: E501 elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2): data_bits = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" + [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE ) # noqa: E501 elif elem_dtype == torch.float4_e2m1fn_x2: data_bits = torch.tensor( - [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" + [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE ) # noqa: E501 data_bits = pack_uint4(data_bits) else: @@ -384,7 +386,7 @@ def test_exponent_nan_out(elem_dtype): assert not torch.any(torch.isnan(tensor_hp.flatten()[4:])) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_ranks(elem_dtype): """ @@ -393,11 +395,11 @@ def test_ranks(elem_dtype): B = 4 shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4)) for s in shapes: - tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16) + tensor_hp = torch.randn(*s, device=_DEVICE, dtype=torch.bfloat16) _test_mx(tensor_hp, elem_dtype, B) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("B", [1, 4, 32]) def test_block_sizes(elem_dtype, B): @@ -408,11 +410,11 @@ def test_block_sizes(elem_dtype, B): pytest.skip("unsupported configuration") elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: pytest.skip("unsupported configuration") - tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16) + tensor_hp = torch.randn(B, device=_DEVICE, dtype=torch.bfloat16) _test_mx(tensor_hp, elem_dtype, B) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_transpose(elem_dtype): """ @@ -420,7 +422,7 @@ def test_transpose(elem_dtype): """ M, K = 128, 256 block_size = 32 - tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor_hp = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16) tensor_mx = MXTensor.to_mx( tensor_hp, 
elem_dtype, @@ -435,18 +437,18 @@ def test_transpose(elem_dtype): torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_view(elem_dtype): - x = torch.randn(1, 2, 4, device="cuda") + x = torch.randn(1, 2, 4, device=_DEVICE) block_size = 4 x_mx = MXTensor.to_mx(x, elem_dtype, block_size) x_mx_2 = x_mx.view(2, 4) # noqa: F841 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") def test_clone(): - data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16) + data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16) block_size = 4 data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size) data_mx_c = data_mx.clone() @@ -458,7 +460,7 @@ def test_clone(): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("all_zeros", [False, True]) @@ -467,15 +469,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): Verifies that compile does not change numerics of MX casts """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): - if not is_sm_at_least_89(): + if torch.cuda.is_available() and not is_sm_at_least_89(): # separate ifs because flake8 is outsmarting me pytest.skip("CUDA capability >= 8.9 required for float8 in triton") shape = 4, 8 if not all_zeros: - x = torch.randn(*shape, dtype=hp_dtype, device="cuda") + x = torch.randn(*shape, dtype=hp_dtype, device=_DEVICE) else: - x = torch.zeros(*shape, dtype=hp_dtype, device="cuda") + x = torch.zeros(*shape, dtype=hp_dtype, device=_DEVICE) block_size = 4 to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True) @@ -508,9 +510,9 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.skipif( - not is_sm_at_least_89(), + torch.cuda.is_available() and not is_sm_at_least_89(), reason="float8 in triton requires CUDA capability 8.9 or greater", ) def test_to_mx_inductor_single_kernel(): @@ -520,15 +522,15 @@ def test_to_mx_inductor_single_kernel(): """ # TODO(future PR): add fp4 and fp6 here # TODO(#1773): add swizzled scale format here - x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda") + x = torch.randn(2048, 2048, dtype=torch.bfloat16, device=_DEVICE) block_size = 32 to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True) out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size) FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") +@pytest.mark.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+") def 
test_index_select(): """ test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is @@ -538,7 +540,7 @@ def test_index_select(): """ E, K, N = 128, 256, 512 - x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16) + x = torch.randn(E, N, K, device=_DEVICE, dtype=torch.bfloat16) x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32) x_mx_1 = x_mx[1] @@ -547,9 +549,9 @@ def test_index_select(): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.skipif( - not is_sm_at_least_89(), + torch.cuda.is_available() and not is_sm_at_least_89(), reason="float8 in triton requires CUDA capability 8.9 or greater", ) def test_cast_to_float8_e4m3fn_saturation_behavior(): @@ -564,7 +566,7 @@ def test_cast_to_float8_e4m3fn_saturation_behavior(): -1 * max_val, ], dtype=torch.bfloat16, - device="cuda", + device=_DEVICE, ) # create example data outside the representable range @@ -574,7 +576,7 @@ def test_cast_to_float8_e4m3fn_saturation_behavior(): -1 * (max_val * 2), ], dtype=torch.bfloat16, - device="cuda", + device=_DEVICE, ) # verify that in eager mode PyTorch casting to float8 is unsaturated @@ -611,14 +613,14 @@ def to_f8(x): ], ) @pytest.mark.parametrize( - "use_triton_kernel", [False, True] if torch.cuda.is_available() else [False] + "use_triton_kernel", [False, True] if torch.accelerator.is_available() else [False] ) @pytest.mark.skipif( not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+" ) def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): rows, cols = shape - device = "cuda" if torch.cuda.is_available() else "cpu" + device = _DEVICE if torch.accelerator.is_available() else "cpu" original = torch.randint(0, 255, (rows, cols), device=device, dtype=torch.uint8) @@ -634,8 +636,8 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") +@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") @pytest.mark.parametrize("transpose", [False, True]) @pytest.mark.parametrize( "shape", @@ -650,7 +652,7 @@ def test_scale_shape_matches_qdata(transpose, shape): block_size = 32 - x_hp = torch.randn(*shape, device="cuda") + x_hp = torch.randn(*shape, device=_DEVICE) x = MXTensor.to_mx( x_hp, torch.float8_e4m3fn, @@ -688,8 +690,8 @@ def test_scale_shape_matches_qdata(transpose, shape): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") +@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") @pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2)) @pytest.mark.parametrize("transpose", [False, True]) @pytest.mark.parametrize( @@ -705,7 +707,7 @@ def test_swizzle(elem_dtype, transpose, shape): block_size = 32 - x_hp = torch.randn(*shape, device="cuda") + x_hp = torch.randn(*shape, device=_DEVICE) x = MXTensor.to_mx( x_hp, elem_dtype, From b775d33c1f2c6c1c3910c038259ec20e72ac4373 Mon Sep 17 
00:00:00 2001 From: "Zeng, Xiangdong" Date: Mon, 1 Dec 2025 16:12:44 +0800 Subject: [PATCH 2/6] test/prototype/mx_formats/test_mx_tensor.py --- torchao/prototype/mx_formats/kernels.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index b4cd192244..0f9985f7f6 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -24,7 +24,6 @@ logger = logging.getLogger(__name__) - def get_bits(x: torch.Tensor) -> str: bits_per_byte = 8 # Numpy has a nice function to get the string representation of binary. From 5c80bd163f63fbf06aca755de04ac8eae098a762 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Mon, 1 Dec 2025 16:23:00 +0800 Subject: [PATCH 3/6] test/prototype/mx_formats/test_kernels.py --- test/prototype/mx_formats/test_kernels.py | 38 ++++++++++++----------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index 4b6586b385..d77a8f22ff 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -46,9 +46,11 @@ is_sm_at_least_89, is_sm_at_least_100, torch_version_at_least, + get_current_accelerator_device, ) torch.manual_seed(0) +_DEVICE = get_current_accelerator_device() if not torch_version_at_least("2.8.0"): pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -395,9 +397,9 @@ def test_fp6_values(dtype_name): [ "cpu", pytest.param( - "cuda", + _DEVICE, marks=pytest.mark.skipif( - not torch.cuda.is_available(), reason="CUDA not available" + not torch.accelerator.is_available(), reason="GPU not available" ), ), ], @@ -439,13 +441,13 @@ def triton_to_mxfp8_dim0_reference( @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif( - not is_sm_at_least_89(), + torch.cuda.is_available() and not is_sm_at_least_89(), reason="float8 in triton requires CUDA capability 8.9 or greater", ) @pytest.mark.parametrize("M", (128, 256)) @pytest.mark.parametrize("K", (128, 256)) def test_triton_mxfp8_dim1_randn(M, K): - x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE) x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32) x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32) torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) @@ -454,13 +456,13 @@ def test_triton_mxfp8_dim1_randn(M, K): @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif( - not is_sm_at_least_100(), + torch.cuda.is_available() and not is_sm_at_least_100(), reason="mxfp8 requires CUDA capability 10.0 or greater", ) @pytest.mark.parametrize("M", (128, 256)) @pytest.mark.parametrize("K", (128, 256)) def test_triton_mxfp8_dim0_randn(M, K): - x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE) x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32) x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32) torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) @@ -469,11 +471,11 @@ def test_triton_mxfp8_dim0_randn(M, K): @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif( - not is_sm_at_least_100(), + torch.cuda.is_available() and not is_sm_at_least_100(), reason="mxfp8 requires CUDA capability 10.0 or greater", ) def test_triton_mxfp8_dim0_zeros(): 
- x = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda") + x = torch.zeros(128, 256, dtype=torch.bfloat16, device=_DEVICE) x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32) x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32) assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs" @@ -483,14 +485,14 @@ def test_triton_mxfp8_dim0_zeros(): @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif( - not is_sm_at_least_100(), + torch.cuda.is_available() and not is_sm_at_least_100(), reason="mxfp8 requires CUDA capability 10.0 or greater", ) @pytest.mark.parametrize("M", (128, 256)) @pytest.mark.parametrize("K", (128, 256)) @pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16)) def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype): - x = torch.zeros(M, K, dtype=orig_dtype, device="cuda") + x = torch.zeros(M, K, dtype=orig_dtype, device=_DEVICE) block_size = 32 x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32) hp_ref = to_dtype( @@ -504,7 +506,7 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype): torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") @pytest.mark.parametrize( "shape", [ @@ -519,14 +521,14 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype): ], ) def test_rearrange(shape): - scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8) + scales = torch.randint(256, size=shape, device=_DEVICE, dtype=torch.uint8) eager = to_blocked(scales, False) triton = to_blocked(scales, True) torch.testing.assert_close(eager, triton, atol=0, rtol=0) @pytest.mark.skipif( - not is_sm_at_least_100(), + torch.cuda.is_available() and not is_sm_at_least_100(), reason="MXFP8 requires CUDA capability 10.0 or greater", ) @pytest.mark.parametrize("M", (32, 256)) @@ -545,7 +547,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode): # Use disinct incrementing values from 0 to M*K-1 to make debugging easier. 
x = ( - torch.arange(0, M * K, dtype=input_dtype, device="cuda") + torch.arange(0, M * K, dtype=input_dtype, device=_DEVICE) .reshape(M, K) .contiguous() ) @@ -574,7 +576,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode): @pytest.mark.skipif( - not is_sm_at_least_100(), + torch.cuda.is_available() and not is_sm_at_least_100(), reason="MXFP8 requires CUDA capability 10.0 or greater", ) def test_cuda_mx_dim0_not_supported(): @@ -583,7 +585,7 @@ def test_cuda_mx_dim0_not_supported(): M, K = 64, 64 block_size = 32 x = ( - torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda") + torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE) .reshape(M, K) .contiguous() ) @@ -598,7 +600,7 @@ def test_cuda_mx_dim0_not_supported(): @pytest.mark.skipif( - not is_sm_at_least_100(), + torch.cuda.is_available() and not is_sm_at_least_100(), reason="MXFP8 requires CUDA capability 10.0 or greater", ) def test_cuda_mx_dim1_invalid_block_size(): @@ -606,7 +608,7 @@ def test_cuda_mx_dim1_invalid_block_size(): M, K = 64, 64 x = ( - torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda") + torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE) .reshape(M, K) .contiguous() ) From 06f9fed6dd68eeb69a82ce4c16b82223a75a51a7 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Mon, 1 Dec 2025 16:42:55 +0800 Subject: [PATCH 4/6] test/prototype/mx_formats/test_kernels.py --- test/prototype/mx_formats/test_kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index d77a8f22ff..f4185abae4 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -485,7 +485,7 @@ def test_triton_mxfp8_dim0_zeros(): @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif( - torch.cuda.is_available() and not is_sm_at_least_100(), + not is_sm_at_least_100(), reason="mxfp8 requires CUDA capability 10.0 or greater", ) @pytest.mark.parametrize("M", (128, 256)) From 360945ca5a47822c43d42729a446945aae798854 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Mon, 1 Dec 2025 16:57:58 +0800 Subject: [PATCH 5/6] test/prototype/mx_formats/test_kernels.py --- test/prototype/mx_formats/test_kernels.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index f4185abae4..8e84608e82 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -528,7 +528,7 @@ def test_rearrange(shape): @pytest.mark.skipif( - torch.cuda.is_available() and not is_sm_at_least_100(), + not is_sm_at_least_100(), reason="MXFP8 requires CUDA capability 10.0 or greater", ) @pytest.mark.parametrize("M", (32, 256)) @@ -576,7 +576,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode): @pytest.mark.skipif( - torch.cuda.is_available() and not is_sm_at_least_100(), + not is_sm_at_least_100(), reason="MXFP8 requires CUDA capability 10.0 or greater", ) def test_cuda_mx_dim0_not_supported(): @@ -600,7 +600,7 @@ def test_cuda_mx_dim0_not_supported(): @pytest.mark.skipif( - torch.cuda.is_available() and not is_sm_at_least_100(), + not is_sm_at_least_100(), reason="MXFP8 requires CUDA capability 10.0 or greater", ) def test_cuda_mx_dim1_invalid_block_size(): From dd6ae61ed2bf2b61cba93f93406b7f677a89443e Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Mon, 1 Dec 2025 17:07:48 +0800 Subject: 
[PATCH 6/6] fix format issue --- test/prototype/mx_formats/test_kernels.py | 2 +- test/prototype/mx_formats/test_mx_tensor.py | 12 +++++++++--- torchao/prototype/mx_formats/kernels.py | 1 + 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index 8e84608e82..4e2c03eb70 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -43,10 +43,10 @@ from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx from torchao.prototype.mx_formats.utils import to_blocked from torchao.utils import ( + get_current_accelerator_device, is_sm_at_least_89, is_sm_at_least_100, torch_version_at_least, - get_current_accelerator_device, ) torch.manual_seed(0) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 716f8c0fc1..922f44848d 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -27,10 +27,10 @@ from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.utils import compute_error from torchao.utils import ( + get_current_accelerator_device, is_sm_at_least_89, is_sm_at_least_90, torch_version_at_least, - get_current_accelerator_device, ) torch.manual_seed(2) @@ -637,7 +637,10 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") -@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") +@pytest.mark.skipif( + torch.cuda.is_available() and not torch_version_at_least("2.8.0"), + reason="requires PyTorch 2.8+", +) @pytest.mark.parametrize("transpose", [False, True]) @pytest.mark.parametrize( "shape", @@ -691,7 +694,10 @@ def test_scale_shape_matches_qdata(transpose, shape): @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available") -@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+") +@pytest.mark.skipif( + torch.cuda.is_available() and not torch_version_at_least("2.8.0"), + reason="requires PyTorch 2.8+", +) @pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2)) @pytest.mark.parametrize("transpose", [False, True]) @pytest.mark.parametrize( diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 0f9985f7f6..b4cd192244 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -24,6 +24,7 @@ logger = logging.getLogger(__name__) + def get_bits(x: torch.Tensor) -> str: bits_per_byte = 8 # Numpy has a nice function to get the string representation of binary.
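
For reference, a minimal sketch of the device-agnostic pattern this series applies to the tests. The body of get_current_accelerator_device() below is an assumption about what the torchao.utils helper does (its implementation is not shown in these patches); the sketch only assumes the torch.accelerator API available in PyTorch 2.6+ (the tests themselves already require 2.8+).

    # Hedged sketch of the device-agnostic test pattern; the helper shown here is an
    # assumed stand-in for torchao.utils.get_current_accelerator_device(), whose real
    # implementation is not part of this patch series.
    import pytest
    import torch


    def get_current_accelerator_device() -> torch.device:
        # Fall back to CPU when no accelerator backend (CUDA, XPU, ...) is present.
        if torch.accelerator.is_available():
            return torch.accelerator.current_accelerator()
        return torch.device("cpu")


    _DEVICE = get_current_accelerator_device()


    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
    def test_device_agnostic_randn():
        # The same test body now runs on CUDA, XPU, or any other accelerator backend,
        # instead of being tied to device="cuda".
        data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16)
        assert data.device.type != "cpu"

Under this pattern, CUDA-only constraints (such as the is_sm_at_least_89/100 compute-capability checks) are additionally guarded with torch.cuda.is_available() so they do not skip the tests on non-CUDA accelerators.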