     quantize_int8_rowwise,
 )
 from torchao.quantization.quant_api import quantize_
-from torchao.utils import get_current_accelerator_device

 if common_utils.SEED is None:
     common_utils.SEED = 1234

-_DEVICES = (
-    ["cpu"]
-    + (["cuda"] if torch.cuda.is_available() else [])
-    + (["xpu"] if torch.xpu.is_available() else [])
-)
-_DEVICE = get_current_accelerator_device()
+_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


 def _reset():
@@ -188,14 +182,12 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(
-        not torch.accelerator.is_available(), reason="GPU not available"
-    )
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
-        device = _DEVICE
+        device = "cuda"

         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
@@ -229,9 +221,7 @@ def snr(ref, actual):

     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(
-        not torch.accelerator.is_available(), reason="GPU not available"
-    )
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_bitnet_training(self, compile):
         # reference implementation
         # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -256,7 +246,7 @@ def forward(self, x):
         _reset()
         bsize = 4
         embed_dim = 32
-        device = _DEVICE
+        device = "cuda"

         # only use 1 matmul shape to reduce triton autotune time
         model_ref = nn.Sequential(
@@ -352,7 +342,7 @@ def _run_subtest(self, args):
             dropout_p=0,
         )
         torch.manual_seed(42)
-        base_model = Transformer(model_args).to(_DEVICE)
+        base_model = Transformer(model_args).cuda()
         fsdp_model = copy.deepcopy(base_model)

         quantize_(base_model.layers, quantize_fn)
@@ -372,7 +362,7 @@ def _run_subtest(self, args):

         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
@@ -397,18 +387,14 @@ def _run_subtest(self, args):
         )

     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(
-        not torch.accelerator.is_available(), reason="GPU not available"
-    )
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )

-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(
-            _DEVICE
-        )
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)