Commit 6a4e65c

test/integration/test_integration.py

1 parent 61195ee commit 6a4e65c
1 file changed: +22 -7 lines changed

test/integration/test_integration.py

Lines changed: 22 additions & 7 deletions
@@ -74,12 +74,12 @@
     benchmark_model,
     check_cpu_version,
     check_xpu_version,
+    get_current_accelerator_device,
     is_fbcode,
     is_sm_at_least_89,
     is_sm_at_least_90,
     torch_version_at_least,
     unwrap_tensor_subclass,
-    get_current_accelerator_device,
 )

 try:
@@ -1053,7 +1053,11 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
             self.skipTest(
                 f"weight_only_quant_force_mixed_mm can't be constructed on {device}"
             )
-        if torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
+        if (
+            torch.cuda.is_available()
+            and dtype == torch.bfloat16
+            and torch.cuda.get_device_capability() < (8, 0)
+        ):
             self.skipTest("test requires SM capability of at least (8, 0).")
         from torch._inductor import config

@@ -1085,7 +1089,11 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
             self.skipTest(
                 f"weight_only_quant_force_mixed_mm can't be constructed on {device}"
             )
-        if torch.cuda.is_available() and dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
+        if (
+            torch.cuda.is_available()
+            and dtype == torch.bfloat16
+            and torch.cuda.get_device_capability() < (8, 0)
+        ):
             self.skipTest("test requires SM capability of at least (8, 0).")
         torch.manual_seed(0)
         from torch._inductor import config
@@ -1255,7 +1263,11 @@ class SmoothquantIntegrationTest(unittest.TestCase):
     @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @unittest.skip("Seg fault?")
     def test_non_dynamically_quantizable_linear(self):
-        if torch.cuda.is_available() and torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+        if (
+            torch.cuda.is_available()
+            and torch.cuda.is_available()
+            and torch.cuda.get_device_capability() < (8, 0)
+        ):
             self.skipTest("test requires SM capability of at least (8, 0).")
         model = (
             torch.nn.Sequential(
@@ -1695,7 +1707,10 @@ def test_autoquant_int4wo(self, device, dtype):
         self.assertGreater(compute_error(ref, out), 20)

     @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @unittest.skipIf(torch.cuda.is_available() and not is_sm_at_least_90(), "Need cuda arch greater than SM90")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_90(),
+        "Need cuda arch greater than SM90",
+    )
     @unittest.skipIf(
         True, "Skipping for now, do to lowering bug in inductor"
     )  # TODO unblock when fixed
@@ -1938,9 +1953,9 @@ def run_benchmark_model(self, device):
         num_runs = 1
         return benchmark_model(m_bf16, num_runs, example_inputs)

-    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_benchmark_model_cuda(self):
-        assert self.run_benchmark_model(_DEVICE) is not None
+        assert self.run_benchmark_model("cuda") is not None

     def test_benchmark_model_cpu(self):
         assert self.run_benchmark_model("cpu") is not None
