From e77586dd85375b47594c3f9beff6d9d48478d37d Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Sun, 30 Nov 2025 22:39:47 +0800 Subject: [PATCH 1/7] add test/quantization/pt2e/test_quantize_pt2e.py --- test/quantization/pt2e/test_quantize_pt2e.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 0b5fd64120..21c9365df3 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -30,6 +30,7 @@ ) from torch.testing._internal.common_utils import ( TEST_CUDA, + TEST_XPU, TemporaryFileName, instantiate_parametrized_tests, parametrize, @@ -68,9 +69,10 @@ QuantizationConfig, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase -from torchao.utils import torch_version_at_least +from torchao.utils import torch_version_at_least, get_current_accelerator_device -DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) +DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["xpu"] if TEST_XPU else []) +_DEVICE = get_current_accelerator_device() if torch_version_at_least("2.7.0"): from torch.testing._internal.common_utils import ( @@ -2057,7 +2059,7 @@ def __init__(self) -> None: def forward(self, x): return self.bn(x) - if TEST_CUDA or TEST_HPU: + if TEST_CUDA or TEST_HPU or TEST_XPU: m = M().train().to(device) example_inputs = (torch.randn((1, 3, 3, 3), device=device),) @@ -2132,9 +2134,9 @@ def forward(self, x): x = self.dropout(x) return x - if TEST_CUDA: - m = M().train().cuda() - example_inputs = (torch.randn(1, 3, 3, 3).cuda(),) + if TEST_CUDA or TEST_XPU: + m = M().train().to(_assert_ops_are_correct) + example_inputs = (torch.randn(1, 3, 3, 3).to(_DEVICE),) else: m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) @@ -2146,7 +2148,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool): bn_op = bn_train_op if train else bn_eval_op bn_node = self._get_node(m, bn_op) self.assertTrue(bn_node is not None) - if TEST_CUDA: + if TEST_CUDA or TEST_XPU: self.assertEqual(bn_node.args[5], train) dropout_node = self._get_node(m, torch.ops.aten.dropout.default) self.assertEqual(dropout_node.args[2], train) From f331d2e31cb25781f0c9350454186fa5a49bc7ff Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Sun, 30 Nov 2025 22:42:45 +0800 Subject: [PATCH 2/7] add test/quantization/pt2e/test_quantize_pt2e.py --- test/quantization/pt2e/test_quantize_pt2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 21c9365df3..94d8fc1b0b 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -2134,7 +2134,7 @@ def forward(self, x): x = self.dropout(x) return x - if TEST_CUDA or TEST_XPU: + if TEST_CUDA: m = M().train().to(_assert_ops_are_correct) example_inputs = (torch.randn(1, 3, 3, 3).to(_DEVICE),) else: From 72467111fdbe2ac904f3225b3cc6449d326fa501 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Sun, 30 Nov 2025 22:48:01 +0800 Subject: [PATCH 3/7] test/quantization/pt2e/test_quantize_pt2e_qat.py --- .../quantization/pt2e/test_quantize_pt2e_qat.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index 293d243f6a..4b9197cf8e 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -28,7 +28,7 @@ skipIfNoQNNPACK, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests, TEST_XPU from torchao.quantization.pt2e import ( FusedMovingAvgObsFakeQuantize, @@ -52,8 +52,9 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import torch_version_at_least +from torchao.utils import torch_version_at_least, get_current_accelerator_device +_DEVICE = get_current_accelerator_device() class PT2EQATTestCase(QuantizationTestCase): """ @@ -453,10 +454,10 @@ def test_qat_conv_bn_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @unittest.skipIf(not TEST_CUDA or not TEST_XPU, "GPU unavailable") def test_qat_conv_bn_fusion_cuda(self): - m = self._get_conv_bn_model().cuda() - example_inputs = (self.example_inputs[0].cuda(),) + m = self._get_conv_bn_model().to(_DEVICE) + example_inputs = (self.example_inputs[0].to(_DEVICE),) self._verify_symmetric_xnnpack_qat_graph( m, example_inputs, @@ -540,10 +541,10 @@ def test_qat_conv_bn_relu_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @unittest.skipIf(not TEST_CUDA or not TEST_XPU, "GPU unavailable") def test_qat_conv_bn_relu_fusion_cuda(self): - m = self._get_conv_bn_model(has_relu=True).cuda() - example_inputs = (self.example_inputs[0].cuda(),) + m = self._get_conv_bn_model(has_relu=True).to(_DEVICE) + example_inputs = (self.example_inputs[0].to(_DEVICE),) self._verify_symmetric_xnnpack_qat_graph( m, example_inputs, From 62f656589c65c38eb30606ce6d7f2a1f3c630475 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Sun, 30 Nov 2025 22:50:51 +0800 Subject: [PATCH 4/7] test/quantization/pt2e/test_quantize_pt2e_qat.py --- test/quantization/pt2e/test_quantize_pt2e_qat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index 4b9197cf8e..3edb159b63 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -454,7 +454,7 @@ def test_qat_conv_bn_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) - @unittest.skipIf(not TEST_CUDA or not TEST_XPU, "GPU unavailable") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable") def test_qat_conv_bn_fusion_cuda(self): m = self._get_conv_bn_model().to(_DEVICE) example_inputs = (self.example_inputs[0].to(_DEVICE),) @@ -541,7 +541,7 @@ def test_qat_conv_bn_relu_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) - @unittest.skipIf(not TEST_CUDA or not TEST_XPU, "GPU unavailable") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable") def test_qat_conv_bn_relu_fusion_cuda(self): m = self._get_conv_bn_model(has_relu=True).to(_DEVICE) example_inputs = (self.example_inputs[0].to(_DEVICE),) From 6100c8be192ad4e58e588ef00b05d253972b278f Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Sun, 30 Nov 2025 23:12:23 +0800 Subject: [PATCH 5/7] fix format issue --- test/quantization/pt2e/test_quantize_pt2e.py | 2 +- test/quantization/pt2e/test_quantize_pt2e_qat.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 94d8fc1b0b..6e0e7a9601 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -69,7 +69,7 @@ QuantizationConfig, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase -from torchao.utils import torch_version_at_least, get_current_accelerator_device +from torchao.utils import get_current_accelerator_device, torch_version_at_least DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["xpu"] if TEST_XPU else []) _DEVICE = get_current_accelerator_device() diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index 3edb159b63..2004c9e04a 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -28,7 +28,7 @@ skipIfNoQNNPACK, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import run_tests, TEST_XPU +from torch.testing._internal.common_utils import TEST_XPU, run_tests from torchao.quantization.pt2e import ( FusedMovingAvgObsFakeQuantize, @@ -52,10 +52,11 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import torch_version_at_least, get_current_accelerator_device +from torchao.utils import get_current_accelerator_device, torch_version_at_least _DEVICE = get_current_accelerator_device() + class PT2EQATTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2E QAT with some helper methods. From 4d72f8a139130a0b96293f237d5ab0f9d4442ea0 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Sun, 30 Nov 2025 23:17:31 +0800 Subject: [PATCH 6/7] update format --- test/quantization/pt2e/test_quantize_pt2e.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 6e0e7a9601..4b188ecf35 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -2134,8 +2134,8 @@ def forward(self, x): x = self.dropout(x) return x - if TEST_CUDA: - m = M().train().to(_assert_ops_are_correct) + if TEST_CUDA or TEST_XPU: + m = M().train().to(_DEVICE) example_inputs = (torch.randn(1, 3, 3, 3).to(_DEVICE),) else: m = M().train() From 736d32011c08282bf316df73759090bf264b3e57 Mon Sep 17 00:00:00 2001 From: "Zeng, Xiangdong" Date: Thu, 4 Dec 2025 11:51:05 +0800 Subject: [PATCH 7/7] increase timeout for xpu --- .github/workflows/xpu_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/xpu_test.yml b/.github/workflows/xpu_test.yml index 3f7d1c7171..32420951e4 100644 --- a/.github/workflows/xpu_test.yml +++ b/.github/workflows/xpu_test.yml @@ -21,7 +21,7 @@ jobs: test: # Don't run on forked repos or empty test matrix # if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]' - timeout-minutes: 60 + timeout-minutes: 120 runs-on: linux.idc.xpu env: DOCKER_IMAGE: ci-image:pytorch-linux-noble-xpu-n-py3 @@ -166,7 +166,7 @@ jobs: GITHUB_RUN_NUMBER: ${{ github.run_number }} GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - timeout-minutes: 60 + timeout-minutes: 120 run: | set -x