 from torchao.utils import (
     check_cpu_version,
     check_xpu_version,
+    get_current_accelerator_device,
     is_fbcode,
     is_ROCM,
     is_sm_at_least_89,
 is_cusparselt_available = (
     hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available()
 )
+_DEVICE = get_current_accelerator_device()


 def get_quantization_functions(
-    do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
+    do_sparse: bool, do_int4: bool, device: str = _DEVICE, int4_zp_int: bool = False
 ):
     base_functions = [
         Int8WeightOnlyConfig(),
@@ -105,9 +107,9 @@ class TestAffineQuantized(TestCase):
         ["xpu"] if torch.xpu.is_available() else []
     )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_tensor_core_layout_transpose(self):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
         t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1)
@@ -169,7 +171,7 @@ def _apply(module, config_or_subclass_inserter):
         ql = _apply(linear, apply_quant)
         ql.to(device)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_register_new_dispatch(self):
         from torchao.dtypes import AffineQuantizedTensor
         from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -206,10 +208,10 @@ def apply_uint6_weight_only_quant(linear):
             )
             return linear

-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
         apply_uint6_weight_only_quant(linear)

-        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
         with self.assertRaisesRegex(
             AssertionError, "dispatching to my impl for uint6 weight only quant"
         ):
@@ -234,11 +236,11 @@ def test_print_quantized_module(self):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(False, True, "cuda", False)
+        "apply_quant", get_quantization_functions(False, True, _DEVICE, False)
     )
     def test_test_copy__apply(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
+        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)

         if isinstance(apply_quant, AOBaseConfig):
             quantize_(linear, apply_quant)
@@ -249,20 +251,20 @@ def test_test_copy__apply(self, apply_quant):
             ql = apply_quant(linear)
             ql2 = apply_quant(linear2)

-        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
         output = ql(example_input)
         ql2.weight.copy_(ql.weight)
         ql2.bias = ql.bias
         output2 = ql2(example_input)
         self.assertEqual(output, output2)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(False, True, "cuda", False)
+        "apply_quant", get_quantization_functions(False, True, _DEVICE, False)
     )
     def test_copy__mismatch_metadata(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
+        linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device=_DEVICE)

         if isinstance(apply_quant, AOBaseConfig):
             quantize_(linear, apply_quant)
@@ -336,7 +338,7 @@ def test_alias(self, device, dtype):
         quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
         _ = dummy.weight[...]

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     @skip_if_no_cuda()
     @skip_if_rocm("ROCm enablement in progress")
@@ -350,9 +352,9 @@ def test_slice_int4wo(self, device, dtype):
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
-    @skip_if_no_cuda()
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @skip_if_no_gemlite()
     def test_slice_gemlite(self, device, dtype):
         # in_feature not divisible by 1024
@@ -433,7 +435,7 @@ def dequant(input_layer, in_features, orig_shape):
         )
         self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     def test_matmul(self, device, dtype):
         x = torch.randn(53, 2048)
@@ -450,14 +452,14 @@ def test_matmul(self, device, dtype):
         # make sure it runs
         torch.matmul(x, w.t())

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     @skip_if_no_cuda()
     @skip_if_rocm("ROCm enablement in progress")
     def test_slice_and_copy_int4wo(self, device, dtype):
-        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
-            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE)
         )
         quantize_(l, Int4WeightOnlyConfig(version=1))
         param = l.weight
@@ -474,7 +476,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         assert param.data.dequantize()[0][0] == 0

         # dummy_l has random input (shouldn't be 0)
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        dummy_l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
         quantize_(dummy_l, Int4WeightOnlyConfig(version=1))
         quantized = dummy_l.weight
         quantized = quantized.narrow(0, 0, 512)
@@ -484,9 +486,9 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         # making sure param.data is updated
         assert param.data.dequantize()[0][0] != 0

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
-    @skip_if_no_cuda()
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @skip_if_rocm("ROCm enablement in progress")
     def test_mm_int4wo(self, device, dtype):
         weight = torch.randn(512, 1024).to(device).to(dtype)
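
For reference, a minimal sketch of what a helper like get_current_accelerator_device could look like, assuming the torch.accelerator API available in recent PyTorch releases; this is an assumption for illustration, not necessarily the actual torchao.utils implementation:

# Assumed sketch, not the actual torchao.utils implementation: resolve the
# current accelerator device type, falling back to CPU when no accelerator is present.
import torch


def get_current_accelerator_device() -> str:
    if torch.accelerator.is_available():
        # torch.accelerator.current_accelerator() returns a torch.device,
        # e.g. device(type="cuda") or device(type="xpu"); keep only its type string.
        return torch.accelerator.current_accelerator().type
    return "cpu"


_DEVICE = get_current_accelerator_device()  # e.g. "cuda", "xpu", or "cpu"

Either the returned type string or a torch.device object would work for the device= arguments and the parametrize lists used in the diff above.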