diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 254b749767..d9c00a4647 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -43,6 +43,7 @@
 from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
+    get_current_accelerator_device,
     is_cuda_version_at_least,
     is_sm_at_least_89,
     is_sm_at_least_100,
@@ -50,6 +51,7 @@
 )
 
 torch.manual_seed(0)
+_DEVICE = get_current_accelerator_device()
 
 if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -396,9 +398,9 @@ def test_fp6_values(dtype_name):
     [
         "cpu",
         pytest.param(
-            "cuda",
+            _DEVICE,
             marks=pytest.mark.skipif(
-                not torch.cuda.is_available(), reason="CUDA not available"
+                not torch.accelerator.is_available(), reason="GPU not available"
             ),
         ),
     ],
@@ -440,13 +442,13 @@ def triton_to_mxfp8_dim0_reference(
 
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 @pytest.mark.parametrize("M", (128, 256))
 @pytest.mark.parametrize("K", (128, 256))
 def test_triton_mxfp8_dim1_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -455,13 +457,13 @@ def test_triton_mxfp8_dim1_randn(M, K):
 
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (128, 256))
 @pytest.mark.parametrize("K", (128, 256))
 def test_triton_mxfp8_dim0_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -470,11 +472,11 @@ def test_triton_mxfp8_dim0_randn(M, K):
 
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 def test_triton_mxfp8_dim0_zeros():
-    x = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda")
+    x = torch.zeros(128, 256, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
@@ -491,7 +493,7 @@ def test_triton_mxfp8_dim0_zeros():
 @pytest.mark.parametrize("K", (128, 256))
 @pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
 def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
-    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
+    x = torch.zeros(M, K, dtype=orig_dtype, device=_DEVICE)
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
@@ -505,7 +507,7 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize(
     "shape",
     [
@@ -520,7 +522,7 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     ],
 )
 def test_rearrange(shape):
-    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+    scales = torch.randint(256, size=shape, device=_DEVICE, dtype=torch.uint8)
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)
@@ -550,7 +552,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
 
     # Use disinct incrementing values from 0 to M*K-1 to make debugging easier.
     x = (
-        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        torch.arange(0, M * K, dtype=input_dtype, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
@@ -592,7 +594,7 @@ def test_cuda_mx_dim0_not_supported():
     M, K = 64, 64
     block_size = 32
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
@@ -619,7 +621,7 @@ def test_cuda_mx_dim1_invalid_block_size():
 
     M, K = 64, 64
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 2b8c72ff91..922f44848d 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -27,12 +27,14 @@
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    get_current_accelerator_device,
     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
 )
 
 torch.manual_seed(2)
+_DEVICE = get_current_accelerator_device()
 
 if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -81,42 +83,42 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     assert data_mx.scale.shape == (*prev_dims, K // block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
-    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode])
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_realistic_numerics(elem_dtype, scale_calculation_mode):
-    data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(128, 128, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 32
     _test_mx(data, elem_dtype, block_size, scale_calculation_mode)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
-    data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
+    data = torch.zeros(4, 4, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_some_zeros(elem_dtype):
-    data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(4, 4, device=_DEVICE, dtype=torch.bfloat16)
     data[0, :] = 0.0
     data[:, 2] = 0.0
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_to_mx_rceil():
     # nan
     # fmt: off
@@ -325,7 +327,7 @@ def test_to_mx_rceil():
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_exponent_nan_in(elem_dtype):
     """
@@ -333,7 +335,7 @@ def test_exponent_nan_in(elem_dtype):
     value is set to is NaN
     """
     tensor_hp = torch.tensor(
-        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
+        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device=_DEVICE, dtype=torch.bfloat16
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
@@ -341,29 +343,29 @@ def test_exponent_nan_in(elem_dtype):
     assert not torch.any(torch.isnan(tensor_mx.scale[1:]))
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_exponent_nan_out(elem_dtype):
     """
     If block exponent value is NaN, the MX tensor block value is NaN
     """
     scale_e8m0 = torch.tensor(
-        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda"
+        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device=_DEVICE
     )
 
     block_size = 4
 
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device=_DEVICE
         )  # noqa: E501
     elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE
         )  # noqa: E501
     elif elem_dtype == torch.float4_e2m1fn_x2:
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE
        )  # noqa: E501
         data_bits = pack_uint4(data_bits)
     else:
@@ -384,7 +386,7 @@ def test_exponent_nan_out(elem_dtype):
     assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_ranks(elem_dtype):
     """
@@ -393,11 +395,11 @@ def test_ranks(elem_dtype):
     B = 4
     shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
     for s in shapes:
-        tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
+        tensor_hp = torch.randn(*s, device=_DEVICE, dtype=torch.bfloat16)
         _test_mx(tensor_hp, elem_dtype, B)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("B", [1, 4, 32])
 def test_block_sizes(elem_dtype, B):
@@ -408,11 +410,11 @@ def test_block_sizes(elem_dtype, B):
         pytest.skip("unsupported configuration")
     elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
         pytest.skip("unsupported configuration")
-    tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
+    tensor_hp = torch.randn(B, device=_DEVICE, dtype=torch.bfloat16)
     _test_mx(tensor_hp, elem_dtype, B)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_transpose(elem_dtype):
     """
@@ -420,7 +422,7 @@ def test_transpose(elem_dtype):
     """
     M, K = 128, 256
     block_size = 32
-    tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    tensor_hp = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
     tensor_mx = MXTensor.to_mx(
         tensor_hp,
         elem_dtype,
@@ -435,18 +437,18 @@ def test_transpose(elem_dtype):
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
-    x = torch.randn(1, 2, 4, device="cuda")
+    x = torch.randn(1, 2, 4, device=_DEVICE)
     block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_clone():
-    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
     data_mx_c = data_mx.clone()
@@ -458,7 +460,7 @@ def test_clone():
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("all_zeros", [False, True])
@@ -467,15 +469,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     Verifies that compile does not change numerics of MX casts
     """
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not is_sm_at_least_89():
+        if torch.cuda.is_available() and not is_sm_at_least_89():
             # separate ifs because flake8 is outsmarting me
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
     shape = 4, 8
     if not all_zeros:
-        x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
+        x = torch.randn(*shape, dtype=hp_dtype, device=_DEVICE)
     else:
-        x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
+        x = torch.zeros(*shape, dtype=hp_dtype, device=_DEVICE)
 
     block_size = 4
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
@@ -508,9 +510,9 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 def test_to_mx_inductor_single_kernel():
@@ -520,15 +522,15 @@ def test_to_mx_inductor_single_kernel():
     """
     # TODO(future PR): add fp4 and fp6 here
    # TODO(#1773): add swizzled scale format here
-    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device=_DEVICE)
     block_size = 32
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
     out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
     FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+@pytest.mark.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+")
 def test_index_select():
     """
     test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is
@@ -538,7 +540,7 @@ def test_index_select():
     """
 
     E, K, N = 128, 256, 512
-    x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(E, N, K, device=_DEVICE, dtype=torch.bfloat16)
     x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)
 
     x_mx_1 = x_mx[1]
@@ -547,9 +549,9 @@ def test_index_select():
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 def test_cast_to_float8_e4m3fn_saturation_behavior():
@@ -564,7 +566,7 @@ def test_cast_to_float8_e4m3fn_saturation_behavior():
             -1 * max_val,
         ],
         dtype=torch.bfloat16,
-        device="cuda",
+        device=_DEVICE,
     )
 
     # create example data outside the representable range
@@ -574,7 +576,7 @@ def test_cast_to_float8_e4m3fn_saturation_behavior():
             -1 * (max_val * 2),
         ],
         dtype=torch.bfloat16,
-        device="cuda",
+        device=_DEVICE,
     )
 
     # verify that in eager mode PyTorch casting to float8 is unsaturated
@@ -611,14 +613,14 @@ def to_f8(x):
     ],
 )
 @pytest.mark.parametrize(
-    "use_triton_kernel", [False, True] if torch.cuda.is_available() else [False]
+    "use_triton_kernel", [False, True] if torch.accelerator.is_available() else [False]
 )
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     rows, cols = shape
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = _DEVICE if torch.accelerator.is_available() else "cpu"
 
     original = torch.randint(0, 255, (rows, cols), device=device, dtype=torch.uint8)
 
@@ -634,8 +636,11 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+@pytest.mark.skipif(
+    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
+    reason="requires PyTorch 2.8+",
+)
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
     "shape",
@@ -650,7 +655,7 @@ def test_scale_shape_matches_qdata(transpose, shape):
 
     block_size = 32
 
-    x_hp = torch.randn(*shape, device="cuda")
+    x_hp = torch.randn(*shape, device=_DEVICE)
     x = MXTensor.to_mx(
         x_hp,
         torch.float8_e4m3fn,
@@ -688,8 +693,11 @@ def test_scale_shape_matches_qdata(transpose, shape):
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+@pytest.mark.skipif(
+    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
+    reason="requires PyTorch 2.8+",
+)
 @pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
@@ -705,7 +713,7 @@ def test_swizzle(elem_dtype, transpose, shape):
 
     block_size = 32
 
-    x_hp = torch.randn(*shape, device="cuda")
+    x_hp = torch.randn(*shape, device=_DEVICE)
     x = MXTensor.to_mx(
         x_hp,
         elem_dtype,