     to_nf4,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import torch_version_at_least
+from torchao.utils import get_current_accelerator_device, torch_version_at_least

 bnb_available = False

 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
+_DEVICE = get_current_accelerator_device()


 def _build_input_weight(embed_dim: int, device: torch.device, dtype: torch.dtype):
@@ -68,7 +69,7 @@ def _build_input_weight(embed_dim: int, device: torch.device, dtype: torch.dtype

 def _build_bnb_linear(input_weight, device):
     assert bnb_available, "Needs bitsandbytes support"
-    param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4").cuda(
+    param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4").to(
         device
     )
     bnb_linear = bnb.nn.LinearNF4(
@@ -121,7 +122,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
         assert nf4_tensor.grad is None

     @unittest.skipIf(not bnb_available, "Need bnb availble")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skipIf(
         torch_version_at_least("2.7.0"), reason="Failing in CI"
     )  # TODO: fix this
@@ -130,7 +131,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
     def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
         # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
         torch.manual_seed(0)
-        device = "cuda"
+        device = _DEVICE
         embed_dim = 512
         input_weight = _build_input_weight(embed_dim, device, dtype)
         nf4_weight = to_nf4(input_weight)
@@ -147,7 +148,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
         assert (nugs_diff - bnb_diff).abs() < 2e-1

     @unittest.skipIf(not bnb_available, "Need bnb availble")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @skip_if_rocm("ROCm enablement in progress")
     @unittest.skipIf(
         torch_version_at_least("2.7.0"), reason="Failing in CI"
@@ -160,12 +161,12 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype):
         """
         torch.manual_seed(0)
         dim = 512
-        device = "cuda"
+        device = _DEVICE
         input_weight = _build_input_weight(dim, device, dtype)
         nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)

-        inp = torch.randn(2, 512, dtype=dtype, device="cuda")
+        inp = torch.randn(2, 512, dtype=dtype, device=_DEVICE)

         out_nf4 = linear_nf4(inp, nf4_weight).sum()
         out_bnb = bnb_linear(inp).sum()
@@ -176,11 +177,11 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype):
         assert err_native < 0.5 * dim
         assert err_bnb < 0.5 * dim

-    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU for test")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_load_from_state_dicts(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        input_tensor = torch.rand(64, device="cuda", dtype=dtype)
+        input_tensor = torch.rand(64, device=_DEVICE, dtype=dtype)
         base_mod = self.TestMod(input_tensor, 32, 2)

         dummy_dict = {"param": input_tensor}
@@ -222,27 +223,27 @@ def test_to_copy(self, dtype: torch.dtype):
         nf4_to_dtype = input_tensor_nf4.to(dtype)
         torch.testing.assert_allclose(input_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

-        if torch.cuda.is_available():
-            input_tensor = torch.rand(128, device="cuda")
+        if torch.accelerator.is_available():
+            input_tensor = torch.rand(128, device=_DEVICE)
             input_tensor_nf4 = to_nf4(input_tensor, 32, 2)
             nf4_to_dtype = input_tensor_nf4.to(dtype)
             torch.testing.assert_allclose(
                 input_tensor, nf4_to_dtype, atol=0.13, rtol=0.13
             )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU for test")
     def test_to_copy_device(self):
         input_tensor = torch.rand(128, device="cpu")
         t = to_nf4(input_tensor, 32, 2)
         assert t.device == torch.device("cpu")
-        z = t.cuda()
-        assert z.device.type == "cuda"  # Because the device could be cuda:0
+        z = t.to(_DEVICE)
+        assert z.device.type == _DEVICE.type  # Because the device could be cuda:0
         x = z.cpu()
         assert x.device == torch.device("cpu")

-        input_tensor = torch.rand(128, device="cuda")
+        input_tensor = torch.rand(128, device=_DEVICE)
         t = to_nf4(input_tensor, 32, 2)
-        assert t.device.type == "cuda"
+        assert t.device.type == _DEVICE.type

     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_to_dtype(self, dtype: torch.dtype):
@@ -252,10 +253,10 @@ def test_to_dtype(self, dtype: torch.dtype):
         assert type(input_tensor_nf4.to(dtype)) is torch.Tensor
         assert input_tensor_nf4.to(dtype).dtype is dtype

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_smoketest_linear(self, dtype: torch.dtype):
-        a = torch.randn(32, 32, dtype=dtype, device="cuda")
+        a = torch.randn(32, 32, dtype=dtype, device=_DEVICE)
         a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
         inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
         _ = torch.nn.functional.linear(inp, a)
@@ -272,37 +273,37 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
             self.skipTest("test requires SM capability of at least (8, 0).")
         if version.parse(torch.__version__) < version.parse("2.3.0"):
             self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor")
-        a = torch.randn(32, 32, dtype=dtype, device="cuda")
+        a = torch.randn(32, 32, dtype=dtype, device=_DEVICE)
         a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
         inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
         _ = torch.compile(torch.nn.functional.linear, mode="max-autotune")(inp, a_nf4)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     @parametrize("shape", [(16, 16), (32, 16)])
     @parametrize("chunk_size", [8, 16, 32])
     def test_chunk_size_equivalence(self, dtype: torch.dtype, shape, chunk_size):
-        a = torch.randn(shape, device="cuda", dtype=dtype)
+        a = torch.randn(shape, device=_DEVICE, dtype=dtype)
         with unittest.mock.patch("torchao.dtypes.nf4tensor.CHUNK_SIZE", chunk_size):
             nf4_patched = to_nf4(a, 16, 2)
         # This will be essentially no chunking since the numel is alot smaller than default chunk_size
         nf4_base = to_nf4(a, 16, 2)

         torch.testing.assert_close(nf4_patched.quantized_data, nf4_base.quantized_data)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @parametrize("input_size", [(512 * 512,), (512, 512)])
     def test_empty_like(self, input_size: Union[Tuple[int], int]):
-        nf4_tensor = to_nf4(torch.rand(input_size, device="cuda"))
+        nf4_tensor = to_nf4(torch.rand(input_size, device=_DEVICE))
         new_tensor = torch.empty_like(nf4_tensor, device="cpu")
         self.assertTrue(isinstance(new_tensor, NF4Tensor))
         self.assertEqual(new_tensor.get_device(), -1)  # that it's on CPU
         self.assertEqual(new_tensor.size(), nf4_tensor.size())

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @parametrize("compile", [False, True])
     def test_quantize_api(self, compile):
-        nf4_linear = nn.Linear(512, 512, device="cuda")
+        nf4_linear = nn.Linear(512, 512, device=_DEVICE)
         torchao.quantize_(nf4_linear, nf4_weight_only())
         assert isinstance(nf4_linear.weight, NF4Tensor)

@@ -313,14 +314,14 @@ def test_quantize_api(self, compile):
             nf4_linear.compile()
             ref_linear.compile()

-        nf4_x = torch.randn(2, 512, device="cuda").requires_grad_()
+        nf4_x = torch.randn(2, 512, device=_DEVICE).requires_grad_()
         ref_x = nf4_x.detach().clone().requires_grad_()

         nf4_out = nf4_linear(nf4_x)
         ref_out = ref_linear(ref_x)
         self.assertEqual(nf4_out, ref_out)

-        grad_out = torch.randn(2, 512, device="cuda")
+        grad_out = torch.randn(2, 512, device=_DEVICE)
         nf4_out.backward(grad_out)
         ref_out.backward(grad_out)
         self.assertEqual(nf4_x.grad, ref_x.grad)
@@ -511,60 +512,60 @@ def test_tensor_as_strided_invalid(self, input_size: Union[Tuple[int], int]):
                 nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset()
             )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_pin_memory(self):
         nf4_tensor = to_nf4(torch.randn(512 * 512))
         self.assertFalse(nf4_tensor.is_pinned())

         nf4_tensor = nf4_tensor.pin_memory()
         self.assertTrue(nf4_tensor.is_pinned())

-        nf4_tensor = to_nf4(torch.randn(512 * 512, device="cuda"))
+        nf4_tensor = to_nf4(torch.randn(512 * 512, device=_DEVICE))
         self.assertFalse(nf4_tensor.is_pinned())

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_to_cuda(self):
         nf4_tensor = to_nf4(torch.randn(512 * 512))
         self.assertEqual(nf4_tensor.device.type, "cpu")
-        nf4_tensor = nf4_tensor.to("cuda", non_blocking=True)
-        self.assertEqual(nf4_tensor.device.type, "cuda")
+        nf4_tensor = nf4_tensor.to(_DEVICE, non_blocking=True)
+        self.assertEqual(nf4_tensor.device.type, _DEVICE.type)
         self.assertEqual(type(nf4_tensor), NF4Tensor)
         nf4_tensor.get_original_weight()  # make sure we can dequantize

         nf4_tensor = to_nf4(torch.randn(512 * 512))
         self.assertEqual(nf4_tensor.device.type, "cpu")
-        nf4_tensor = nf4_tensor.to("cuda")
-        self.assertEqual(nf4_tensor.device.type, "cuda")
+        nf4_tensor = nf4_tensor.to(_DEVICE)
+        self.assertEqual(nf4_tensor.device.type, _DEVICE.type)
         self.assertEqual(type(nf4_tensor), NF4Tensor)
         nf4_tensor.get_original_weight()

         nf4_tensor = to_nf4(torch.randn(512 * 512))
         self.assertEqual(nf4_tensor.device.type, "cpu")
-        nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16)
-        self.assertEqual(nf4_tensor.device.type, "cuda")
+        nf4_tensor = nf4_tensor.to(_DEVICE, torch.bfloat16)
+        self.assertEqual(nf4_tensor.device.type, _DEVICE.type)
         self.assertEqual(nf4_tensor.dtype, torch.bfloat16)
         self.assertEqual(type(nf4_tensor), torch.Tensor)  # dequantized

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_to_cpu(self):
-        nf4_tensor = to_nf4(torch.randn(512 * 512, device="cuda"))
+        nf4_tensor = to_nf4(torch.randn(512 * 512, device=_DEVICE))
         nf4_tensor = nf4_tensor.cpu()
         self.assertEqual(nf4_tensor.device.type, "cpu")
         for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
             inner_tensor = getattr(nf4_tensor, attr)
             self.assertEqual(inner_tensor.device.type, "cpu")
         nf4_tensor.get_original_weight()  # make sure we can dequantize

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_to_module(self):
         linear = nn.Linear(512, 512, bias=False)
         linear.weight = nn.Parameter(
             to_nf4(linear.weight.detach()), requires_grad=False
         )
-        linear.cuda()
-        self.assertEqual(linear.weight.device.type, "cuda")
+        linear.to(_DEVICE)
+        self.assertEqual(linear.weight.device.type, _DEVICE.type)
         weight = linear.weight.get_original_weight()
-        self.assertEqual(weight.device.type, "cuda")
+        self.assertEqual(weight.device.type, _DEVICE.type)

         linear.cpu()
         self.assertEqual(linear.weight.device.type, "cpu")
@@ -575,20 +576,20 @@ def test_to_module(self):
         linear.weight = nn.Parameter(
             to_nf4(linear.weight.detach()), requires_grad=False
         )
-        linear.to("cuda")
-        self.assertEqual(linear.weight.device.type, "cuda")
+        linear.to(_DEVICE)
+        self.assertEqual(linear.weight.device.type, _DEVICE.type)
         weight = linear.weight.get_original_weight()
-        self.assertEqual(weight.device.type, "cuda")
+        self.assertEqual(weight.device.type, _DEVICE.type)

         linear.to("cpu")
         self.assertEqual(linear.weight.device.type, "cpu")
         weight = linear.weight.get_original_weight()
         self.assertEqual(weight.device.type, "cpu")

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
     def test_tensor_deepcopy(self, input_size: Union[Tuple[int], int]):
-        nf4_orig = to_nf4(torch.randn(input_size, device="cuda"))
+        nf4_orig = to_nf4(torch.randn(input_size, device=_DEVICE))
         nf4_clone = copy.deepcopy(nf4_orig)
         self.assertEqual(
             nf4_clone.get_original_weight(), nf4_orig.get_original_weight()
@@ -678,7 +679,7 @@ def _test_qlora_fsdp2(
             dropout_p=0,
         )
         torch.manual_seed(42)
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             base_model = Transformer(model_args)
             for layer in base_model.layers:
                 # attention with lora adapters
@@ -732,7 +733,7 @@ def _test_qlora_fsdp2(

         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
@@ -756,7 +757,7 @@ def world_size(self) -> int:
         return 2

     @skip_if_lt_x_gpu(2)
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_comm(self):
         self.run_subtests(
             {"input_size": [512, 2048]},
@@ -767,7 +768,7 @@ def _test_comm(self, input_size: int):
         from torch.distributed._composable.fsdp import fully_shard
         from torch.distributed._tensor import distribute_tensor

-        model = nn.Linear(input_size, input_size, device="cuda")
+        model = nn.Linear(input_size, input_size, device=_DEVICE)
         origin_tensor = model.weight
         origin_nf4_tensor = to_nf4(origin_tensor)
         model = fully_shard(model)
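
For context, the change above replaces hard-coded "cuda" strings with a module-level _DEVICE so the same tests run on any PyTorch accelerator backend. The sketch below only illustrates the pattern; the helper name pick_test_device is hypothetical and is not torchao's actual get_current_accelerator_device implementation. It assumes the torch.accelerator API available in recent PyTorch releases.

import torch

def pick_test_device() -> torch.device:
    # Hypothetical helper mirroring what the tests expect from
    # get_current_accelerator_device(): prefer the active accelerator
    # (cuda, xpu, mps, ...) and fall back to CPU on host-only machines.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

_DEVICE = pick_test_device()
x = torch.randn(8, device=_DEVICE)
assert x.device.type == _DEVICE.type  # holds whether _DEVICE is cpu, cuda, xpu, ...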