@@ -27,10 +27,10 @@
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    get_current_accelerator_device,
     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
-    get_current_accelerator_device,
 )
 
 torch.manual_seed(2)
@@ -637,7 +637,10 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
 
 
 @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(
+    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
+    reason="requires PyTorch 2.8+",
+)
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
     "shape",
@@ -691,7 +694,10 @@ def test_scale_shape_matches_qdata(transpose, shape):
 
 
 @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(
+    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
+    reason="requires PyTorch 2.8+",
+)
 @pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
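For reference, here is a minimal sketch of how the reformatted decorator stack reads on a test function. Only the decorators mirror the diff; the test name and body are hypothetical. Splitting the second skipif across lines is presumably a line-length fix (e.g., to satisfy an 88-character formatter limit) with no behavioral change:

import pytest
import torch

from torchao.utils import torch_version_at_least


@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
    # Both conditions are evaluated at collection time: the test runs only on
    # an accelerator machine that is either non-CUDA or on PyTorch 2.8+.
    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
    reason="requires PyTorch 2.8+",
)
@pytest.mark.parametrize("transpose", [False, True])
def test_decorator_stack_sketch(transpose):  # hypothetical test, not from the PR
    assert transpose in (False, True)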