     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
+    get_current_accelerator_device,
 )

 torch.manual_seed(2)
+_DEVICE = get_current_accelerator_device()

 if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -81,42 +83,42 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     assert data_mx.scale.shape == (*prev_dims, K // block_size)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
-    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode])
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_realistic_numerics(elem_dtype, scale_calculation_mode):
-    data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(128, 128, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 32
     _test_mx(data, elem_dtype, block_size, scale_calculation_mode)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
-    data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
+    data = torch.zeros(4, 4, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     _test_mx(data, elem_dtype, block_size)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_some_zeros(elem_dtype):
-    data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(4, 4, device=_DEVICE, dtype=torch.bfloat16)
     data[0, :] = 0.0
     data[:, 2] = 0.0
     block_size = 4
     _test_mx(data, elem_dtype, block_size)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_to_mx_rceil():
     # nan
     # fmt: off
@@ -325,23 +327,23 @@ def test_to_mx_rceil():
     torch.testing.assert_close(data_mx.qdata, ground_truth_fp8)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_exponent_nan_in(elem_dtype):
     """
     If high precision block values has a NaN, the exponent block
     value is set to is NaN
     """
     tensor_hp = torch.tensor(
-        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
+        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device=_DEVICE, dtype=torch.bfloat16
     )
     block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
     assert torch.all(torch.isnan(tensor_mx.scale[0]))
     assert not torch.any(torch.isnan(tensor_mx.scale[1:]))


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("pack_fp6", [False, True])
 def test_exponent_nan_out(elem_dtype, pack_fp6):
@@ -352,25 +354,25 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
         pytest.skip("invalid configuration")

     scale_e8m0 = torch.tensor(
-        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda"
+        [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device=_DEVICE
     )

     block_size = 4

     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device=_DEVICE
         )  # noqa: E501
     elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE
         )  # noqa: E501
         if pack_fp6:
             data_bits = data_bits.reshape(-1, block_size)
             data_bits = pack_uint6(data_bits)
     elif elem_dtype == torch.float4_e2m1fn_x2:
         data_bits = torch.tensor(
-            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device=_DEVICE
         )  # noqa: E501
         data_bits = pack_uint4(data_bits)
     else:
@@ -392,7 +394,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_ranks(elem_dtype):
     """
@@ -401,11 +403,11 @@ def test_ranks(elem_dtype):
     B = 4
     shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
     for s in shapes:
-        tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
+        tensor_hp = torch.randn(*s, device=_DEVICE, dtype=torch.bfloat16)
         _test_mx(tensor_hp, elem_dtype, B)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("B", [1, 4, 32])
 def test_block_sizes(elem_dtype, B):
@@ -416,19 +418,19 @@ def test_block_sizes(elem_dtype, B):
         pytest.skip("unsupported configuration")
     elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
         pytest.skip("unsupported configuration")
-    tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
+    tensor_hp = torch.randn(B, device=_DEVICE, dtype=torch.bfloat16)
     _test_mx(tensor_hp, elem_dtype, B)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_transpose(elem_dtype):
     """
     Verify that transposing an MX tensor works
     """
     M, K = 128, 256
     block_size = 32
-    tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    tensor_hp = torch.randn(M, K, device=_DEVICE, dtype=torch.bfloat16)
     tensor_mx = MXTensor.to_mx(
         tensor_hp,
         elem_dtype,
@@ -443,18 +445,18 @@ def test_transpose(elem_dtype):
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
-    x = torch.randn(1, 2, 4, device="cuda")
+    x = torch.randn(1, 2, 4, device=_DEVICE)
     block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_clone():
-    data = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
+    data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16)
     block_size = 4
     data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, block_size)
     data_mx_c = data_mx.clone()
@@ -466,11 +468,11 @@ def test_clone():
     )


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
 @pytest.mark.parametrize("pack_fp6", [False, True])
 def test_fp6_packing(elem_dtype, pack_fp6):
-    x = torch.randn(1, 2, 4, device="cuda")
+    x = torch.randn(1, 2, 4, device=_DEVICE)
     block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size, pack_fp6=pack_fp6)
     if pack_fp6:
@@ -481,7 +483,7 @@ def test_fp6_packing(elem_dtype, pack_fp6):
     assert x_mx.qdata.shape == expected_packed_shape


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("all_zeros", [False, True])
@@ -490,15 +492,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     Verifies that compile does not change numerics of MX casts
     """
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        if not is_sm_at_least_89():
+        if torch.cuda.is_available() and not is_sm_at_least_89():
             # separate ifs because flake8 is outsmarting me
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

     shape = 4, 8
     if not all_zeros:
-        x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
+        x = torch.randn(*shape, dtype=hp_dtype, device=_DEVICE)
     else:
-        x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
+        x = torch.zeros(*shape, dtype=hp_dtype, device=_DEVICE)
     block_size = 4
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)

@@ -534,9 +536,9 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 def test_to_mx_inductor_single_kernel():
@@ -546,15 +548,15 @@ def test_to_mx_inductor_single_kernel():
     """
     # TODO(future PR): add fp4 and fp6 here
     # TODO(#1773): add swizzled scale format here
-    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device=_DEVICE)
     block_size = 32
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
     out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
     FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipIf(not is_sm_at_least_90(), "Need sm90+")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+@pytest.mark.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need sm90+")
 def test_index_select():
     """
     test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is
@@ -564,7 +566,7 @@ def test_index_select():
     """

     E, K, N = 128, 256, 512
-    x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(E, N, K, device=_DEVICE, dtype=torch.bfloat16)
     x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)

     x_mx_1 = x_mx[1]
@@ -573,9 +575,9 @@ def test_index_select():
     )


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 def test_cast_to_float8_e4m3fn_saturation_behavior():
@@ -590,7 +592,7 @@ def test_cast_to_float8_e4m3fn_saturation_behavior():
             -1 * max_val,
         ],
         dtype=torch.bfloat16,
-        device="cuda",
+        device=_DEVICE,
     )

     # create example data outside the representable range
@@ -600,7 +602,7 @@ def test_cast_to_float8_e4m3fn_saturation_behavior():
             -1 * (max_val * 2),
         ],
         dtype=torch.bfloat16,
-        device="cuda",
+        device=_DEVICE,
     )

     # verify that in eager mode PyTorch casting to float8 is unsaturated
@@ -637,14 +639,14 @@ def to_f8(x):
     ],
 )
 @pytest.mark.parametrize(
-    "use_triton_kernel", [False, True] if torch.cuda.is_available() else [False]
+    "use_triton_kernel", [False, True] if torch.accelerator.is_available() else [False]
 )
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     rows, cols = shape
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = _DEVICE if torch.accelerator.is_available() else "cpu"

     original = torch.randint(0, 255, (rows, cols), device=device, dtype=torch.uint8)

@@ -660,8 +662,8 @@ def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
     )


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
     "shape",
@@ -676,7 +678,7 @@ def test_scale_shape_matches_qdata(transpose, shape):
 
     block_size = 32

-    x_hp = torch.randn(*shape, device="cuda")
+    x_hp = torch.randn(*shape, device=_DEVICE)
     x = MXTensor.to_mx(
         x_hp,
         torch.float8_e4m3fn,
@@ -714,8 +716,8 @@ def test_scale_shape_matches_qdata(transpose, shape):
     )


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+@pytest.mark.skipif(torch.cuda.is_available() and not torch_version_at_least("2.8.0"), reason="requires PyTorch 2.8+")
 @pytest.mark.parametrize("elem_dtype", (torch.float8_e4m3fn, torch.float4_e2m1fn_x2))
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
@@ -731,7 +733,7 @@ def test_swizzle(elem_dtype, transpose, shape):
 
     block_size = 32

-    x_hp = torch.randn(*shape, device="cuda")
+    x_hp = torch.randn(*shape, device=_DEVICE)
     x = MXTensor.to_mx(
         x_hp,
         elem_dtype,
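
Note on the new helper: `get_current_accelerator_device()` is imported from torchao's utilities, and its body is not part of this diff. A minimal sketch of what such a helper could look like, assuming it simply wraps PyTorch's `torch.accelerator` API (available in recent PyTorch releases) with a CPU fallback:

```python
import torch


def get_current_accelerator_device() -> torch.device:
    # Hypothetical sketch only; the real torchao helper may differ.
    # torch.accelerator.current_accelerator() returns the accelerator device
    # PyTorch was built with (cuda, xpu, mps, ...), or None for CPU-only builds.
    acc = torch.accelerator.current_accelerator()
    return acc if acc is not None else torch.device("cpu")


# Used the same way as _DEVICE in the tests above: tensors are allocated on
# the detected accelerator when one exists, and on CPU otherwise.
_DEVICE = get_current_accelerator_device()
data = torch.randn(8, 8, device=_DEVICE, dtype=torch.bfloat16)
```

This also motivates the updated skip conditions in the diff: guards such as `torch.cuda.is_available() and not is_sm_at_least_89()` only enforce the CUDA compute-capability requirement when the active accelerator is actually CUDA, so non-CUDA backends are not skipped by CUDA-specific checks.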