diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py
index 119bfb8d05..9e802ee92c 100644
--- a/test/prototype/test_parq.py
+++ b/test/prototype/test_parq.py
@@ -51,7 +51,12 @@
     torch_version_at_least,
 )
 
-_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    _DEVICE = "cuda"
+elif torch.xpu.is_available():
+    _DEVICE = "xpu"
+else:
+    _DEVICE = "cpu"
 
 
 class M(nn.Module):
diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py
index ab25a38bb3..d3b14f0f94 100644
--- a/test/prototype/test_quantized_training.py
+++ b/test/prototype/test_quantized_training.py
@@ -34,11 +34,13 @@
     quantize_int8_rowwise,
 )
 from torchao.quantization.quant_api import quantize_
+from torchao.utils import get_current_accelerator_device
 
 if common_utils.SEED is None:
     common_utils.SEED = 1234
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+_DEVICE = get_current_accelerator_device()
+_DEVICES = ["cpu"] + ([_DEVICE] if torch.accelerator.is_available() else [])
 
 
 def _reset():
@@ -182,12 +184,14 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
-        device = "cuda"
+        device = _DEVICE
 
         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
@@ -221,7 +225,9 @@ def snr(ref, actual):
 
     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_bitnet_training(self, compile):
         # reference implementation
         # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -246,7 +252,7 @@ def forward(self, x):
         _reset()
         bsize = 4
         embed_dim = 32
-        device = "cuda"
+        device = _DEVICE
 
         # only use 1 matmul shape to reduce triton autotune time
         model_ref = nn.Sequential(
@@ -342,7 +348,7 @@ def _run_subtest(self, args):
             dropout_p=0,
         )
         torch.manual_seed(42)
-        base_model = Transformer(model_args).cuda()
+        base_model = Transformer(model_args).to(_DEVICE)
         fsdp_model = copy.deepcopy(base_model)
 
         quantize_(base_model.layers, quantize_fn)
@@ -362,7 +368,7 @@ def _run_subtest(self, args):
 
         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
@@ -387,14 +393,18 @@ def _run_subtest(self, args):
         )
 
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )
 
-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(
+            _DEVICE
+        )
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)
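The device selection these tests now rely on reduces to a CUDA, then XPU, then CPU fallback. The snippet below is a minimal standalone sketch of that pattern; `pick_test_device` is a hypothetical helper written for illustration, not the actual `torchao.utils.get_current_accelerator_device`, which is assumed to encapsulate similar logic.

```python
import torch


def pick_test_device() -> str:
    # Hypothetical helper mirroring the cuda -> xpu -> cpu fallback used in
    # test_parq.py above; torchao.utils.get_current_accelerator_device is
    # assumed to provide equivalent behavior for test_quantized_training.py.
    if torch.cuda.is_available():
        return "cuda"
    # Guard with hasattr so the sketch also runs on PyTorch builds that do not
    # ship the XPU backend.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"


if __name__ == "__main__":
    device = pick_test_device()
    # Tensors created in the tests can then target whichever backend was found.
    x = torch.randn(4, 4, device=device)
    print(f"running on {device}: {x.sum().item():.3f}")
```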