Skip to content

Commit dd6ae61

Browse files
committed
Fix formatting issue
1 parent 360945c commit dd6ae61

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

test/prototype/mx_formats/test_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@
4343
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
4444
from torchao.prototype.mx_formats.utils import to_blocked
4545
from torchao.utils import (
46+
get_current_accelerator_device,
4647
is_sm_at_least_89,
4748
is_sm_at_least_100,
4849
torch_version_at_least,
49-
get_current_accelerator_device,
5050
)
5151

5252
torch.manual_seed(0)

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
from torchao.quantization.quantize_.common import KernelPreference
2828
from torchao.quantization.utils import compute_error
2929
from torchao.utils import (
30+
get_current_accelerator_device,
3031
is_sm_at_least_89,
3132
is_sm_at_least_90,
3233
torch_version_at_least,
33-
get_current_accelerator_device,
3434
)
3535

3636
torch.manual_seed(2)
@@ -637,7 +637,10 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
637637

638638

639639
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
640-
@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
640+
@pytest.mark.skipif(
641+
torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
642+
reason="requires PyTorch 2.8+",
643+
)
641644
@pytest.mark.parametrize("transpose", [False, True])
642645
@pytest.mark.parametrize(
643646
"shape",
@@ -691,7 +694,10 @@ def test_scale_shape_matches_qdata(transpose, shape):
691694

692695

693696
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
694-
@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
697+
@pytest.mark.skipif(
698+
torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
699+
reason="requires PyTorch 2.8+",
700+
)
695701
@pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
696702
@pytest.mark.parametrize("transpose", [False, True])
697703
@pytest.mark.parametrize(

torchao/prototype/mx_formats/kernels.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
logger = logging.getLogger(__name__)
2626

27+
2728
def get_bits(x: torch.Tensor) -> str:
2829
bits_per_byte = 8
2930
# Numpy has a nice function to get the string representation of binary.

0 commit comments

Comments (0)