     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
+    get_current_accelerator_device,
 )

 torch.manual_seed(0)
+_DEVICE = get_current_accelerator_device()

 if not torch_version_at_least("2.8.0"):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
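For readers of the diff: `get_current_accelerator_device` is imported from torchao's utilities, which are not shown here. Below is a minimal sketch of what such a helper might do, assuming it simply resolves to the active accelerator and falls back to CPU; the underscored name is illustrative, not torchao's actual implementation:

```python
import torch

def _get_current_accelerator_device_sketch() -> torch.device:
    # Hypothetical stand-in for torchao's get_current_accelerator_device().
    # torch.accelerator.current_accelerator() reports the active backend
    # (cuda, xpu, mps, ...) or None on CPU-only builds.
    accel = torch.accelerator.current_accelerator()
    return accel if accel is not None else torch.device("cpu")
```

With a helper along these lines, `_DEVICE` points at the machine's accelerator when one exists and at `cpu` otherwise, which is what lets the tests below drop their hard-coded `"cuda"` strings.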
@@ -398,9 +400,9 @@ def test_fp6_values(dtype_name):
     [
         "cpu",
         pytest.param(
-            "cuda",
+            _DEVICE,
             marks=pytest.mark.skipif(
-                not torch.cuda.is_available(), reason="CUDA not available"
+                not torch.accelerator.is_available(), reason="GPU not available"
             ),
         ),
     ],
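Reassembled for clarity, the parametrization this hunk edits behaves roughly like the self-contained sketch below (an approximation, not the file's exact decorator or test body): the `"cpu"` entry always runs, while the `_DEVICE` entry is skipped on machines without an accelerator.

```python
import pytest
import torch

# Illustrative fallback; the real test module builds _DEVICE via torchao's helper.
_DEVICE = torch.accelerator.current_accelerator() or torch.device("cpu")

@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param(
            _DEVICE,
            marks=pytest.mark.skipif(
                not torch.accelerator.is_available(), reason="GPU not available"
            ),
        ),
    ],
)
def test_device_param_sketch(device):
    # Trivial body just to exercise the parametrization.
    assert torch.ones(2, device=device).sum().item() == 2.0
```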
@@ -423,11 +425,11 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
     assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 def test_fp6_e2m3_pack_unpack():
     orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
-        "cuda"
+        _DEVICE
     )
     orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
     orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
@@ -438,11 +440,11 @@ def test_fp6_e2m3_pack_unpack():
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 def test_fp6_e3m2_pack_unpack():
     orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
-        "cuda"
+        _DEVICE
     )
     orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
     orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
@@ -472,13 +474,13 @@ def triton_to_mxfp8_dim0_reference(

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",
 )
 @pytest.mark.parametrize("M", (256, 2048))
 @pytest.mark.parametrize("K", (256, 2048))
 def test_triton_mxfp8_dim1_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
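The reworked skip condition short-circuits, so the SM-capability check only runs when CUDA is actually present; on CPU-only or non-CUDA accelerator builds the test is no longer skipped on capability grounds. A small sketch of the intended logic, assuming `is_sm_at_least_89` boils down to a `torch.cuda.get_device_capability()` comparison (the helper name below is illustrative):

```python
import torch

def _should_skip_for_sm(min_capability=(8, 9)) -> bool:
    # Mirrors `torch.cuda.is_available() and not is_sm_at_least_89()`:
    # no CUDA device -> never skip on capability grounds (non-CUDA backends
    # are gated separately via torch.accelerator.is_available());
    # CUDA device    -> skip only when it is older than the required SM.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() < min_capability
```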
@@ -487,13 +489,13 @@ def test_triton_mxfp8_dim1_randn(M, K):

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
 def test_triton_mxfp8_dim0_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -502,11 +504,11 @@ def test_triton_mxfp8_dim0_randn(M, K):

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 def test_triton_mxfp8_dim0_zeros():
-    x = torch.zeros(8192, 5120, dtype=torch.bfloat16, device="cuda")
+    x = torch.zeros(8192, 5120, dtype=torch.bfloat16, device=_DEVICE)
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
@@ -516,14 +518,14 @@ def test_triton_mxfp8_dim0_zeros():

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
 @pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
 def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
-    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
+    x = torch.zeros(M, K, dtype=orig_dtype, device=_DEVICE)
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
@@ -537,7 +539,7 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize(
     "shape",
     [
@@ -552,14 +554,14 @@ def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
     ],
 )
 def test_rearrange(shape):
-    scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
+    scales = torch.randint(256, size=shape, device=_DEVICE, dtype=torch.uint8)
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)


 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (32, 64, 2048))
@@ -578,7 +580,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):

     # Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
     x = (
-        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        torch.arange(0, M * K, dtype=input_dtype, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
@@ -607,7 +609,7 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):


 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
 def test_cuda_mx_dim0_not_supported():
@@ -616,7 +618,7 @@ def test_cuda_mx_dim0_not_supported():
     M, K = 64, 64
     block_size = 32
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )
@@ -631,15 +633,15 @@ def test_cuda_mx_dim0_not_supported():


 @pytest.mark.skipif(
-    not is_sm_at_least_100(),
+    torch.cuda.is_available() and not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
 def test_cuda_mx_dim1_invalid_block_size():
     from torchao.prototype import mxfp8_cuda

     M, K = 64, 64
     x = (
-        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        torch.arange(0, M * K, dtype=torch.bfloat16, device=_DEVICE)
         .reshape(M, K)
         .contiguous()
     )