Commit 77feb44
test/prototype/mx_formats/test_kernels.py
1 parent c2ecff1

1 file changed: +24 −22 lines

test/prototype/mx_formats/test_kernels.py
24 additions, 22 deletions
@@ -49,9 +49,11 @@
     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
+    get_current_accelerator_device,
 )
 
 torch.manual_seed(0)
+_DEVICE = get_current_accelerator_device()
 
 if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
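
Note: the tests now pick their device from a single helper instead of hard-coding "cuda". Below is a minimal sketch of what get_current_accelerator_device() could look like, assuming it wraps PyTorch's torch.accelerator API (the helper's actual implementation lives in torchao and may differ):

# Hypothetical sketch only; the real torchao helper may differ.
import torch

def get_current_accelerator_device() -> torch.device:
    # torch.accelerator.current_accelerator() (PyTorch 2.6+) reports the active
    # accelerator backend (cuda, xpu, mps, ...) or None on CPU-only machines.
    accel = torch.accelerator.current_accelerator()
    return accel if accel is not None else torch.device("cpu")

With this in place, device=_DEVICE resolves to the local accelerator when one exists and falls back to the CPU otherwise.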
@@ -398,9 +400,9 @@ def test_fp6_values(dtype_name):
     [
         "cpu",
         pytest.param(
-            "cuda",
+            _DEVICE,
             marks=pytest.mark.skipif(
-                not torch.cuda.is_available(), reason="CUDA not available"
+                not torch.accelerator.is_available(), reason="GPU not available"
             ),
         ),
     ],
@@ -423,11 +425,11 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
     assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 def test_fp6_e2m3_pack_unpack():
     orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
-        "cuda"
+        _DEVICE
     )
     orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
     orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
@@ -438,11 +440,11 @@ def test_fp6_e2m3_pack_unpack():
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 def test_fp6_e3m2_pack_unpack():
     orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
-        "cuda"
+        _DEVICE
     )
     orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
     orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
@@ -472,13 +474,13 @@ def triton_to_mxfp8_dim0_reference(
 
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 @pytest.mark.parametrize("M", (256, 2048))
 @pytest.mark.parametrize("K", (256, 2048))
 def test_triton_mxfp8_dim1_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
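
Note on the changed skip condition: is_sm_at_least_89() (and the is_sm_at_least_100() variants below) probe CUDA compute capability, which is meaningless on non-CUDA accelerators. Guarding the check with torch.cuda.is_available() keeps the capability skip CUDA-specific. The sketch below is an illustrative test, not part of this file, and assumes the helper is importable from torchao.utils:

# Illustrative only; import path for is_sm_at_least_89 is assumed.
import pytest
import torch
from torchao.utils import is_sm_at_least_89

@pytest.mark.skipif(
    # CUDA device with SM < 8.9  -> True  -> test skipped
    # CUDA device with SM >= 8.9 -> False -> test runs
    # non-CUDA accelerator/CPU   -> False -> test runs (capability check does not apply)
    torch.cuda.is_available() and not is_sm_at_least_89(),
    reason="float8 in triton requires CUDA capability 8.9 or greater",
)
def test_example():
    ...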
@@ -487,13 +489,13 @@ def test_triton_mxfp8_dim1_randn(M, K):
 
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
 def test_triton_mxfp8_dim0_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -502,11 +504,11 @@ def test_triton_mxfp8_dim0_randn(M, K):
 
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
    reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 def test_triton_mxfp8_dim0_zeros():
-    x = torch.zeros(8192, 5120, dtype=torch.bfloat16, device="cuda")
+    x = torch.zeros(8192, 5120, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
@@ -516,14 +518,14 @@ def test_triton_mxfp8_dim0_zeros():
 
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
 @pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
 def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
-    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
+    x = torch.zeros(M, K, dtype=orig_dtype, device=_DEVICE)
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
@@ -537,7 +539,7 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize(
     "shape",
     [
@@ -552,14 +554,14 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     ],
 )
 def test_rearrange(shape):
-    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+    scales = torch.randint(256, size=shape, device=_DEVICE, dtype=torch.uint8)
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)
 
 
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (32, 64, 2048))
@@ -578,7 +580,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
 
     # Use disinct incrementing values from 0 to M*K-1 to make debugging easier.
     x = (
-        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        torch.arange(0, M * K, dtype=input_dtype, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
@@ -607,7 +609,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
 
 
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
 def test_cuda_mx_dim0_not_supported():
@@ -616,7 +618,7 @@ def test_cuda_mx_dim0_not_supported():
     M, K = 64, 64
     block_size = 32
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
@@ -631,15 +633,15 @@ def test_cuda_mx_dim0_not_supported():
 
 
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
 def test_cuda_mx_dim1_invalid_block_size():
     from torchao.prototype import mxfp8_cuda
 
     M, K = 64, 64
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
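
Taken together, the commit makes these kernel tests accelerator-agnostic: device placement goes through _DEVICE, availability checks use torch.accelerator.is_available(), and the CUDA-capability skips only fire when CUDA is actually the active backend. A quick sanity check of the pattern (the helper's import path is assumed here):

# Illustrative snippet, not part of the commit.
import torch
from torchao.utils import get_current_accelerator_device  # assumed location

_DEVICE = get_current_accelerator_device()
print(_DEVICE)                           # e.g. cuda, xpu, or cpu
print(torch.accelerator.is_available())  # True if any accelerator backend is usable
x = torch.randn(4, 4, device=_DEVICE)    # same placement the tests now use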
