Skip to content

Commit b9416e9

Browse files
committed
test/hqq/test_hqq_affine.py
1 parent 5111111 commit b9416e9

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/hqq/test_hqq_affine.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
quantize_,
1616
)
1717
from torchao.testing.utils import skip_if_rocm
18+
from torchao.utils import get_current_accelerator_device
1819

19-
cuda_available = torch.cuda.is_available()
20+
cuda_available = torch.accelerator.is_available()
21+
_DEVICE = get_current_accelerator_device()
2022

2123
# Parameters
22-
device = "cuda:0"
24+
device = f"{_DEVICE}:0"
2325
compute_dtype = torch.bfloat16
2426
group_size = 64
2527
mapping_type = MappingType.ASYMMETRIC

0 commit comments

Comments
 (0)