diff --git a/.github/workflows/xpu_test.yml b/.github/workflows/xpu_test.yml index 3f7d1c7171..32420951e4 100644 --- a/.github/workflows/xpu_test.yml +++ b/.github/workflows/xpu_test.yml @@ -21,7 +21,7 @@ jobs: test: # Don't run on forked repos or empty test matrix # if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]' - timeout-minutes: 60 + timeout-minutes: 120 runs-on: linux.idc.xpu env: DOCKER_IMAGE: ci-image:pytorch-linux-noble-xpu-n-py3 @@ -166,7 +166,7 @@ jobs: GITHUB_RUN_NUMBER: ${{ github.run_number }} GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - timeout-minutes: 60 + timeout-minutes: 120 run: | set -x diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 0b5fd64120..4b188ecf35 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -30,6 +30,7 @@ ) from torch.testing._internal.common_utils import ( TEST_CUDA, + TEST_XPU, TemporaryFileName, instantiate_parametrized_tests, parametrize, @@ -68,9 +69,10 @@ QuantizationConfig, ) from torchao.testing.pt2e.utils import PT2EQuantizationTestCase -from torchao.utils import torch_version_at_least +from torchao.utils import get_current_accelerator_device, torch_version_at_least -DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) +DEVICE_LIST = ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["xpu"] if TEST_XPU else []) +_DEVICE = get_current_accelerator_device() if torch_version_at_least("2.7.0"): from torch.testing._internal.common_utils import ( @@ -2057,7 +2059,7 @@ def __init__(self) -> None: def forward(self, x): return self.bn(x) - if TEST_CUDA or TEST_HPU: + if TEST_CUDA or TEST_HPU or TEST_XPU: m = M().train().to(device) example_inputs = (torch.randn((1, 3, 3, 3), device=device),) @@ -2132,9 +2134,9 @@ def forward(self, x): x = self.dropout(x) return x - if TEST_CUDA: - m = M().train().cuda() - example_inputs = (torch.randn(1, 3, 3, 3).cuda(),) + if TEST_CUDA or TEST_XPU: + m = M().train().to(_DEVICE) + example_inputs = (torch.randn(1, 3, 3, 3).to(_DEVICE),) else: m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) @@ -2146,7 +2148,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool): bn_op = bn_train_op if train else bn_eval_op bn_node = self._get_node(m, bn_op) self.assertTrue(bn_node is not None) - if TEST_CUDA: + if TEST_CUDA or TEST_XPU: self.assertEqual(bn_node.args[5], train) dropout_node = self._get_node(m, torch.ops.aten.dropout.default) self.assertEqual(dropout_node.args[2], train) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index 293d243f6a..2004c9e04a 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -28,7 +28,7 @@ skipIfNoQNNPACK, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TEST_XPU, run_tests from torchao.quantization.pt2e import ( FusedMovingAvgObsFakeQuantize, @@ -52,7 +52,9 @@ XNNPACKQuantizer, get_symmetric_quantization_config, ) -from torchao.utils import torch_version_at_least +from torchao.utils import get_current_accelerator_device, torch_version_at_least + +_DEVICE = get_current_accelerator_device() class PT2EQATTestCase(QuantizationTestCase): @@ -453,10 +455,10 @@ def test_qat_conv_bn_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=False) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable") def test_qat_conv_bn_fusion_cuda(self): - m = self._get_conv_bn_model().cuda() - example_inputs = (self.example_inputs[0].cuda(),) + m = self._get_conv_bn_model().to(_DEVICE) + example_inputs = (self.example_inputs[0].to(_DEVICE),) self._verify_symmetric_xnnpack_qat_graph( m, example_inputs, @@ -540,10 +542,10 @@ def test_qat_conv_bn_relu_fusion(self): self._verify_symmetric_xnnpack_qat_graph(m, self.example_inputs, has_relu=True) self._verify_symmetric_xnnpack_qat_numerics(m, self.example_inputs) - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "GPU unavailable") def test_qat_conv_bn_relu_fusion_cuda(self): - m = self._get_conv_bn_model(has_relu=True).cuda() - example_inputs = (self.example_inputs[0].cuda(),) + m = self._get_conv_bn_model(has_relu=True).to(_DEVICE) + example_inputs = (self.example_inputs[0].to(_DEVICE),) self._verify_symmetric_xnnpack_qat_graph( m, example_inputs,