Commit 64fa6b2 (parent: 8d57626)

fix format issue

3 files changed: +11 −5 lines


test/prototype/mx_formats/test_kernels.py

Lines changed: 1 addition & 1 deletion
@@ -46,10 +46,10 @@
 from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
+    get_current_accelerator_device,
     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
-    get_current_accelerator_device,
 )

 torch.manual_seed(0)
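
The hunk above moves `get_current_accelerator_device` into alphabetical position inside the `torchao.utils` import block. For context, a hypothetical usage sketch of that helper, assuming it returns a torch.device for the currently active accelerator (the name comes from the diff; the return type is an assumption):

    import torch

    from torchao.utils import get_current_accelerator_device

    # Assumption: returns a torch.device such as cuda or xpu, letting
    # tests allocate tensors without hard-coding a device string.
    device = get_current_accelerator_device()
    x = torch.randn(32, 32, device=device)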

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 9 additions & 3 deletions
@@ -27,10 +27,10 @@
 from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
+    get_current_accelerator_device,
     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
-    get_current_accelerator_device,
 )

 torch.manual_seed(2)

@@ -663,7 +663,10 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):


 @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(
+    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
+    reason="requires PyTorch 2.8+",
+)
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
     "shape",

@@ -717,7 +720,10 @@ def test_scale_shape_matches_qdata(transpose, shape):


 @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(
+    torch.cuda.is_available() and not torch_version_at_least("2.8.0"),
+    reason="requires PyTorch 2.8+",
+)
 @pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(

torchao/prototype/mx_formats/kernels.py

Lines changed: 1 addition & 1 deletion
@@ -38,6 +38,7 @@

 logger = logging.getLogger(__name__)

+
 def get_bits(x: torch.Tensor) -> str:
     bits_per_byte = 8
     # Numpy has a nice function to get the string representation of binary.

@@ -591,7 +592,6 @@ def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor:
     assert x.is_cuda or x.device.type == "xpu"
     assert output.is_cuda or output.device.type == "xpu"

-
     n_mx_blocks = x.shape[0]
     grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)
     triton_f6_to_bf16_kernel[grid](
