@@ -27,10 +27,10 @@
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    get_current_accelerator_device,
     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
-    get_current_accelerator_device,
 )

 torch.manual_seed(2)
@@ -663,7 +663,10 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool): |


 @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(
+    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
+    reason="requires PyTorch 2.8+",
+)
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
     "shape",
@@ -717,7 +720,10 @@ def test_scale_shape_matches_qdata(transpose, shape): |


 @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(
+    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
+    reason="requires PyTorch 2.8+",
+)
 @pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
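For context, a minimal sketch of how the reformatted skip markers combine on a test. The test name and body below are hypothetical (not part of this diff), and get_current_accelerator_device is assumed, based on its name and its import from torchao.utils here, to return a torch.device for the active accelerator:

import pytest
import torch

from torchao.utils import get_current_accelerator_device, torch_version_at_least


# Skip when no accelerator is present, and on CUDA builds older than PyTorch 2.8.
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
    reason="requires PyTorch 2.8+",
)
def test_accelerator_device_sketch():
    # Assumption: get_current_accelerator_device() returns a torch.device
    # for the active accelerator (e.g. "cuda" or "xpu").
    device = get_current_accelerator_device()
    x = torch.randn(16, 16, device=device)
    assert x.device.type == device.type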