
Commit 69ad8ff

test/prototype/mx_formats/test_mx_tensor.py
1 parent f3ebaf9 commit 69ad8ff
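
This commit makes the MX tensor tests device-agnostic: hard-coded `device="cuda"` arguments become a module-level `_DEVICE`, and `torch.cuda.is_available()` guards become `torch.accelerator.is_available()` so the tests can run on non-CUDA accelerators. The implementation of `get_current_accelerator_device` is not part of this diff; a minimal sketch of what such a helper could look like, assuming it simply wraps `torch.accelerator.current_accelerator()`:

```python
import torch


def get_current_accelerator_device() -> torch.device:
    # Hypothetical sketch; the real helper is imported from torchao's utilities
    # and is not shown in this diff.
    # torch.accelerator.current_accelerator() returns the active accelerator
    # device (cuda, xpu, mps, ...) or None on a CPU-only build.
    accel = torch.accelerator.current_accelerator()
    return accel if accel is not None else torch.device("cpu")
```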

File tree

1 file changed: +52 -50 lines changed


test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 52 additions & 50 deletions
@@ -30,9 +30,11 @@
     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
+    get_current_accelerator_device,
 )
 
 torch.manual_seed(2)
+_DEVICE = get_current_accelerator_device()
 
 if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -81,42 +83,42 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     assert data_mx.scale.shape == (*prev_dims, K // block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
-    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode])
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_realistic_numerics(elem_dtype, scale_calculation_mode):
-    data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(128, 128, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 32
     _test_mx(data, elem_dtype, block_size, scale_calculation_mode)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
-    data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
+    data = torch.zeros(4, 4, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_some_zeros(elem_dtype):
-    data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(4, 4, device=_DEVICE, dtype=torch.bfloat16)
     data[0, :] = 0.0
     data[:, 2] = 0.0
     block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_to_mx_rceil():
     # nan
     # fmt: off
@@ -325,23 +327,23 @@ def test_to_mx_rceil():
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_exponent_nan_in(elem_dtype):
     """
     If high precision block values has a NaN, the exponent block
     value is set to is NaN
     """
     tensor_hp = torch.tensor(
-        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
+        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device=_DEVICE, dtype=torch.bfloat16
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
     assert torch.all(torch.isnan(tensor_mx.scale[0]))
     assert not torch.any(torch.isnan(tensor_mx.scale[1:]))
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("pack_fp6", [False, True])
 def test_exponent_nan_out(elem_dtype, pack_fp6):
@@ -352,25 +354,25 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
         pytest.skip("invalid configuration")
 
     scale_e8m0 = torch.tensor(
-        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda"
+        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device=_DEVICE
     )
 
     block_size = 4
 
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device=_DEVICE
         ) # noqa: E501
     elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE
        ) # noqa: E501
        if pack_fp6:
            data_bits = data_bits.reshape(-1, block_size)
            data_bits = pack_uint6(data_bits)
    elif elem_dtype == torch.float4_e2m1fn_x2:
        data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE
        ) # noqa: E501
        data_bits = pack_uint4(data_bits)
    else:
@@ -392,7 +394,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_ranks(elem_dtype):
     """
@@ -401,11 +403,11 @@ def test_ranks(elem_dtype):
     B = 4
     shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
     for s in shapes:
-        tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
+        tensor_hp = torch.randn(*s, device=_DEVICE, dtype=torch.bfloat16)
         _test_mx(tensor_hp, elem_dtype, B)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("B", [1, 4, 32])
 def test_block_sizes(elem_dtype, B):
@@ -416,19 +418,19 @@ def test_block_sizes(elem_dtype, B):
         pytest.skip("unsupported configuration")
     elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
         pytest.skip("unsupported configuration")
-    tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
+    tensor_hp = torch.randn(B, device=_DEVICE, dtype=torch.bfloat16)
     _test_mx(tensor_hp, elem_dtype, B)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_transpose(elem_dtype):
     """
     Verify that transposing an MX tensor works
     """
     M, K = 128, 256
     block_size = 32
-    tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    tensor_hp = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
     tensor_mx = MXTensor.to_mx(
         tensor_hp,
         elem_dtype,
@@ -443,18 +445,18 @@ def test_transpose(elem_dtype):
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
-    x = torch.randn(1, 2, 4, device="cuda")
+    x = torch.randn(1, 2, 4, device=_DEVICE)
     block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_2 = x_mx.view(2, 4) # noqa: F841
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_clone():
-    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
     data_mx_c = data_mx.clone()
@@ -466,11 +468,11 @@ def test_clone():
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
 @pytest.mark.parametrize("pack_fp6", [False, True])
 def test_fp6_packing(elem_dtype, pack_fp6):
-    x = torch.randn(1, 2, 4, device="cuda")
+    x = torch.randn(1, 2, 4, device=_DEVICE)
     block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size, pack_fp6=pack_fp6)
     if pack_fp6:
@@ -481,7 +483,7 @@ def test_fp6_packing(elem_dtype, pack_fp6):
     assert x_mx.qdata.shape == expected_packed_shape
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("all_zeros", [False, True])
@@ -490,15 +492,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     Verifies that compile does not change numerics of MX casts
     """
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not is_sm_at_least_89():
+        if torch.cuda.is_available() and not is_sm_at_least_89():
             # separate ifs because flake8 is outsmarting me
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
 
     shape = 4, 8
     if not all_zeros:
-        x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
+        x = torch.randn(*shape, dtype=hp_dtype, device=_DEVICE)
     else:
-        x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
+        x = torch.zeros(*shape, dtype=hp_dtype, device=_DEVICE)
     block_size = 4
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
 
@@ -534,9 +536,9 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 def test_to_mx_inductor_single_kernel():
@@ -546,15 +548,15 @@ def test_to_mx_inductor_single_kernel():
     """
     # TODO(future PR): add fp4 and fp6 here
    # TODO(#1773): add swizzled scale format here
-    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device=_DEVICE)
     block_size = 32
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
     out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
     FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+@pytest.mark.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+")
 def test_index_select():
     """
     test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is
@@ -564,7 +566,7 @@ def test_index_select():
     """
 
     E, K, N = 128, 256, 512
-    x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(E, N, K, device=_DEVICE, dtype=torch.bfloat16)
     x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)
 
     x_mx_1 = x_mx[1]
@@ -573,9 +575,9 @@
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 def test_cast_to_float8_e4m3fn_saturation_behavior():
@@ -590,7 +592,7 @@ def test_cast_to_float8_e4m3fn_saturation_behavior():
             -1 * max_val,
         ],
         dtype=torch.bfloat16,
-        device="cuda",
+        device=_DEVICE,
     )
 
     # create example data outside the representable range
@@ -600,7 +602,7 @@
             -1 * (max_val * 2),
         ],
         dtype=torch.bfloat16,
-        device="cuda",
+        device=_DEVICE,
     )
 
     # verify that in eager mode PyTorch casting to float8 is unsaturated
@@ -637,14 +639,14 @@ def to_f8(x):
     ],
 )
 @pytest.mark.parametrize(
-    "use_triton_kernel", [False, True] if torch.cuda.is_available() else [False]
+    "use_triton_kernel", [False, True] if torch.accelerator.is_available() else [False]
 )
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     rows, cols = shape
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = _DEVICE if torch.accelerator.is_available() else "cpu"
 
     original = torch.randint(0, 255, (rows, cols), device=device, dtype=torch.uint8)
 
@@ -660,8 +662,8 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
     "shape",
@@ -676,7 +678,7 @@ def test_scale_shape_matches_qdata(transpose, shape):
 
     block_size = 32
 
-    x_hp = torch.randn(*shape, device="cuda")
+    x_hp = torch.randn(*shape, device=_DEVICE)
     x = MXTensor.to_mx(
         x_hp,
         torch.float8_e4m3fn,
@@ -714,8 +716,8 @@ def test_scale_shape_matches_qdata(transpose, shape):
714716
)
715717

716718

717-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
718-
@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
719+
@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
720+
@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
719721
@pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
720722
@pytest.mark.parametrize("transpose", [False, True])
721723
@pytest.mark.parametrize(
@@ -731,7 +733,7 @@ def test_swizzle(elem_dtype, transpose, shape):
 
     block_size = 32
 
-    x_hp = torch.randn(*shape, device="cuda")
+    x_hp = torch.randn(*shape, device=_DEVICE)
     x = MXTensor.to_mx(
         x_hp,
         elem_dtype,
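
The hunks above follow a two-part guard pattern: a backend-neutral `torch.accelerator.is_available()` skip, plus conditions prefixed with `torch.cuda.is_available() and ...` so SM-capability and version checks only run when the accelerator is actually CUDA. A standalone sketch of that pattern with a hypothetical test and a stand-in `is_sm_at_least_89` (the real helper is imported from torchao and not shown here):

```python
import pytest
import torch


def is_sm_at_least_89() -> bool:
    # Stand-in for the torchao helper: checks the CUDA compute capability.
    return torch.cuda.get_device_capability() >= (8, 9)


@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
@pytest.mark.skipif(
    # Short-circuit keeps the SM check from running on non-CUDA backends.
    torch.cuda.is_available() and not is_sm_at_least_89(),
    reason="float8 in triton requires CUDA capability 8.9 or greater",
)
def test_float8_cast_sketch():
    x = torch.full((16, 16), 2.0, device=torch.accelerator.current_accelerator())
    x_f8 = x.to(torch.float8_e4m3fn)
    # 2.0 is exactly representable in float8_e4m3fn, so the round trip is exact.
    torch.testing.assert_close(x_f8.to(torch.float32), x)
```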
