diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py
index c253af55ea..be5ff46f5c 100644
--- a/test/float8/test_float8_utils.py
+++ b/test/float8/test_float8_utils.py
@@ -10,11 +10,14 @@
 from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
 from torchao.testing.utils import skip_if_rocm
+from torchao.utils import get_current_accelerator_device
+
+_DEVICE = get_current_accelerator_device()
 
 # source for notable single-precision cases:
 # https://en.wikipedia.org/wiki/Single-precision_floating-point_format
-@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+@unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -38,8 +41,8 @@ def test_round_scale_down_to_power_of_2_valid_inputs(
 ):
     test_case_name, input, expected_result = test_case
     input_tensor, expected_tensor = (
-        torch.tensor(input, dtype=torch.float32).cuda(),
-        torch.tensor(expected_result, dtype=torch.float32).cuda(),
+        torch.tensor(input, dtype=torch.float32).to(_DEVICE),
+        torch.tensor(expected_result, dtype=torch.float32).to(_DEVICE),
     )
 
     result = _round_scale_down_to_power_of_2(input_tensor)
diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 09bdfa8e61..b103e8ccf0 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -14,12 +14,14 @@
     ZeroPointDomain,
     quantize_,
 )
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import skip_if_rocm, skip_if_xpu
+from torchao.utils import get_current_accelerator_device
 
-cuda_available = torch.cuda.is_available()
+cuda_available = torch.accelerator.is_available()
+_DEVICE = get_current_accelerator_device()
 
 # Parameters
-device = "cuda:0"
+device = f"{_DEVICE}:0"
 compute_dtype = torch.bfloat16
 group_size = 64
 mapping_type = MappingType.ASYMMETRIC
@@ -114,6 +116,7 @@ def test_hqq_plain_5bit(self):
         )
 
     @skip_if_rocm("ROCm enablement in progress")
+    @skip_if_xpu("XPU enablement in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,