@@ -34,11 +34,17 @@
     quantize_int8_rowwise,
 )
 from torchao.quantization.quant_api import quantize_
+from torchao.utils import get_current_accelerator_device

 if common_utils.SEED is None:
     common_utils.SEED = 1234

-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+_DEVICES = (
+    ["cpu"]
+    + (["cuda"] if torch.cuda.is_available() else [])
+    + (["xpu"] if torch.xpu.is_available() else [])
+)
+_DEVICE = get_current_accelerator_device()


 def _reset():
@@ -182,12 +188,14 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
-        device = "cuda"
+        device = _DEVICE

         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
@@ -221,7 +229,9 @@ def snr(ref, actual):

     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_bitnet_training(self, compile):
         # reference implementation
         # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -246,7 +256,7 @@ def forward(self, x):
         _reset()
         bsize = 4
         embed_dim = 32
-        device = "cuda"
+        device = _DEVICE

         # only use 1 matmul shape to reduce triton autotune time
         model_ref = nn.Sequential(
@@ -342,7 +352,7 @@ def _run_subtest(self, args):
             dropout_p=0,
         )
         torch.manual_seed(42)
-        base_model = Transformer(model_args).cuda()
+        base_model = Transformer(model_args).to(_DEVICE)
         fsdp_model = copy.deepcopy(base_model)

         quantize_(base_model.layers, quantize_fn)
@@ -362,7 +372,7 @@ def _run_subtest(self, args):

         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
@@ -387,14 +397,18 @@ def _run_subtest(self, args):
         )

     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )

-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(
+            _DEVICE
+        )
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)
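
Note: the diff relies on torchao.utils.get_current_accelerator_device() to pick the device that replaces the hard-coded "cuda". Below is a minimal sketch of the behavior these tests appear to assume, written against the public torch.accelerator API (PyTorch >= 2.6); the fallback logic and the helper name _current_accelerator_or_cpu are assumptions for illustration, not the actual torchao implementation.

# Sketch only: approximates the device-selection behavior assumed above,
# not the torchao.utils.get_current_accelerator_device implementation.
import torch

def _current_accelerator_or_cpu() -> torch.device:
    # torch.accelerator.is_available() covers CUDA, XPU, and other accelerator backends
    if torch.accelerator.is_available():
        # Returns the active accelerator, e.g. device("cuda") or device("xpu")
        return torch.accelerator.current_accelerator()
    # No accelerator present (e.g. CPU-only CI): fall back to CPU
    return torch.device("cpu")

_DEVICE = _current_accelerator_or_cpu()

On a CUDA machine this resolves to device("cuda"), on an Intel GPU machine to device("xpu"), and to device("cpu") otherwise, which is consistent with the accelerator-based skipif guards introduced in the diff.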