Skip to content

Commit 61195ee

Browse files
committed
test/integration/test_integration.py
1 parent 0bac63c commit 61195ee

File tree

1 file changed: 8 additions, 7 deletions

test/integration/test_integration.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ def _test_lin_weight_subclass_impl(
682682
test_dtype=torch.bfloat16,
683683
test_shape=(32, 64, 32),
684684
):
685-
if not _DEVICE in test_device:
685+
if not torch.accelerator.is_available():
686686
self.skipTest("test requires gpu")
687687
with torch.no_grad():
688688
m, k, n = test_shape
@@ -1403,12 +1403,13 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
14031403
if (
14041404
is_supported_device and torch.version.hip is None
14051405
): # Only apply to CUDA, not ROCm
1406-
device_capability = torch.cuda.get_device_capability()
1407-
if torch.cuda.is_available() and device_capability < (8, 0):
1408-
if dtype == torch.bfloat16:
1409-
self.skipTest("bfloat16 requires sm80+")
1410-
if m1 == 1 or m2 == 1:
1411-
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
1406+
if torch.cuda.is_available():
1407+
device_capability = torch.cuda.get_device_capability()
1408+
if device_capability < (8, 0):
1409+
if dtype == torch.bfloat16:
1410+
self.skipTest("bfloat16 requires sm80+")
1411+
if m1 == 1 or m2 == 1:
1412+
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
14121413

14131414
# TODO remove this once https://github.com/pytorch/pytorch/issues/155838 is resolved
14141415
if m1 == 1 or m2 == 1:

Comments (0)