Commit 2872c9b
[xpu][test] Port 2 test/dtypes_{affine_quantized, affine_quantized_float} UT files to intel XPU (#3366)
* enable test/dtypes/test_affine_quantized.py on intel XPU
* enable test/dtypes/test_affine_quantized_float.py on intel XPU
* fix format issue
1 parent b0a668c commit 2872c9b

File tree

2 files changed: +61 −51 lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 26 additions & 24 deletions
@@ -39,6 +39,7 @@
 from torchao.utils import (
     check_cpu_version,
     check_xpu_version,
+    get_current_accelerator_device,
     is_fbcode,
     is_ROCM,
     is_sm_at_least_89,
@@ -47,10 +48,11 @@
 is_cusparselt_available = (
     hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available()
 )
+_DEVICE = get_current_accelerator_device()


 def get_quantization_functions(
-    do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
+    do_sparse: bool, do_int4: bool, device: str = _DEVICE, int4_zp_int: bool = False
 ):
     base_functions = [
         Int8WeightOnlyConfig(),
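For context, `get_current_accelerator_device()` is imported from `torchao.utils` and its body is not shown in this diff. A minimal sketch of what such a helper can look like, assuming the `torch.accelerator` API from PyTorch 2.6+; this is an illustration, not the actual torchao implementation:

```python
import torch

def get_current_accelerator_device() -> str:
    # Hypothetical sketch: torch.accelerator.current_accelerator() (PyTorch >= 2.6)
    # returns a torch.device for the active backend ("cuda", "xpu", ...) or None.
    acc = torch.accelerator.current_accelerator()
    return acc.type if acc is not None else "cpu"
```

With a helper like this, `_DEVICE` resolves to `"cuda"` on NVIDIA machines and `"xpu"` on Intel GPUs, so the default `device` argument of `get_quantization_functions` no longer hardcodes `"cuda"`.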
@@ -105,9 +107,9 @@ class TestAffineQuantized(TestCase):
         ["xpu"] if torch.xpu.is_available() else []
     )

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_tensor_core_layout_transpose(self):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
         t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1)
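The decorator swap above is the core of the port: `torch.cuda.is_available()` reports only NVIDIA GPUs, while `torch.accelerator.is_available()` (PyTorch 2.6+) reports any supported accelerator backend, including XPU. A small sketch of this device-agnostic skip pattern; the test class is hypothetical, for illustration only:

```python
import unittest
import torch

@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
class ExampleAcceleratorTest(unittest.TestCase):
    def test_runs_on_any_accelerator(self):
        # current_accelerator() returns whichever backend is present (CUDA, XPU, ...).
        device = torch.accelerator.current_accelerator()
        x = torch.ones(2, 2, device=device)
        self.assertEqual(x.sum().item(), 4.0)
```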
@@ -169,7 +171,7 @@ def _apply(module, config_or_subclass_inserter):
         ql = _apply(linear, apply_quant)
         ql.to(device)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_register_new_dispatch(self):
         from torchao.dtypes import AffineQuantizedTensor
         from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -206,10 +208,10 @@ def apply_uint6_weight_only_quant(linear):
             )
             return linear

-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
         apply_uint6_weight_only_quant(linear)

-        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
         with self.assertRaisesRegex(
             AssertionError, "dispatching to my impl for uint6 weight only quant"
         ):
@@ -234,11 +236,11 @@ def test_print_quantized_module(self):

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(False, True, "cuda", False)
+        "apply_quant", get_quantization_functions(False, True, _DEVICE, False)
     )
     def test_test_copy__apply(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
+        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)

         if isinstance(apply_quant, AOBaseConfig):
             quantize_(linear, apply_quant)
@@ -249,20 +251,20 @@ def test_test_copy__apply(self, apply_quant):
             ql = apply_quant(linear)
             ql2 = apply_quant(linear2)

-        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device=_DEVICE)
         output = ql(example_input)
         ql2.weight.copy_(ql.weight)
         ql2.bias = ql.bias
         output2 = ql2(example_input)
         self.assertEqual(output, output2)

-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(False, True, "cuda", False)
+        "apply_quant", get_quantization_functions(False, True, _DEVICE, False)
     )
     def test_copy__mismatch_metadata(self, apply_quant):
-        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device="cuda")
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
+        linear2 = torch.nn.Linear(128, 512, dtype=torch.bfloat16, device=_DEVICE)

         if isinstance(apply_quant, AOBaseConfig):
             quantize_(linear, apply_quant)
@@ -336,7 +338,7 @@ def test_alias(self, device, dtype):
         quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
         _ = dummy.weight[...]

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     @skip_if_no_cuda()
     @skip_if_rocm("ROCm enablement in progress")
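In hunks like this one, only the parametrized device list changes; CUDA-only guards such as `@skip_if_no_cuda()` stay in place where the underlying kernel is still CUDA-only. A sketch of the parametrization pattern, assuming PyTorch's internal test utilities (`torch.testing._internal.common_utils`, an internal API); the test class and names here are hypothetical:

```python
import torch
from torch.testing._internal import common_utils

# Mirrors the diff: detect the device once, then parametrize over it.
_DEVICE = (
    torch.accelerator.current_accelerator().type
    if torch.accelerator.is_available()
    else "cpu"
)

class ExampleDeviceTest(common_utils.TestCase):  # hypothetical test class
    @common_utils.parametrize("device", [_DEVICE])
    @common_utils.parametrize("dtype", [torch.bfloat16])
    def test_linear_runs(self, device, dtype):
        linear = torch.nn.Linear(8, 8, device=device, dtype=dtype)
        _ = linear(torch.randn(2, 8, device=device, dtype=dtype))

common_utils.instantiate_parametrized_tests(ExampleDeviceTest)
```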
@@ -350,9 +352,9 @@ def test_slice_int4wo(self, device, dtype):
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
-    @skip_if_no_cuda()
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @skip_if_no_gemlite()
     def test_slice_gemlite(self, device, dtype):
         # in_feature not divisible by 1024
@@ -433,7 +435,7 @@ def dequant(input_layer, in_features, orig_shape):
         )
         self.assertEqual((W_slice_ref - W_slice).abs().mean().item(), 0)

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     def test_matmul(self, device, dtype):
         x = torch.randn(53, 2048)
@@ -450,14 +452,14 @@ def test_matmul(self, device, dtype):
         # make sure it runs
         torch.matmul(x, w.t())

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
     @skip_if_no_cuda()
     @skip_if_rocm("ROCm enablement in progress")
     def test_slice_and_copy_int4wo(self, device, dtype):
-        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
-            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE)
         )
         quantize_(l, Int4WeightOnlyConfig(version=1))
         param = l.weight
@@ -474,7 +476,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         assert param.data.dequantize()[0][0] == 0

         # dummy_l has random input (shouldn't be 0)
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        dummy_l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
         quantize_(dummy_l, Int4WeightOnlyConfig(version=1))
         quantized = dummy_l.weight
         quantized = quantized.narrow(0, 0, 512)
@@ -484,9 +486,9 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         # making sure param.data is updated
         assert param.data.dequantize()[0][0] != 0

-    @common_utils.parametrize("device", ["cuda"])
+    @common_utils.parametrize("device", [_DEVICE])
     @common_utils.parametrize("dtype", [torch.bfloat16])
-    @skip_if_no_cuda()
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     @skip_if_rocm("ROCm enablement in progress")
     def test_mm_int4wo(self, device, dtype):
         weight = torch.randn(512, 1024).to(device).to(dtype)
