diff --git a/.github/workflows/xpu_test.yml b/.github/workflows/xpu_test.yml
index 3f7d1c7171..32420951e4 100644
--- a/.github/workflows/xpu_test.yml
+++ b/.github/workflows/xpu_test.yml
@@ -21,7 +21,7 @@ jobs:
   test:
     # Don't run on forked repos or empty test matrix
     # if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]'
-    timeout-minutes: 60
+    timeout-minutes: 120
     runs-on: linux.idc.xpu
     env:
       DOCKER_IMAGE: ci-image:pytorch-linux-noble-xpu-n-py3
@@ -166,7 +166,7 @@ jobs:
           GITHUB_RUN_NUMBER: ${{ github.run_number }}
           GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }}
           SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
-        timeout-minutes: 60
+        timeout-minutes: 120
         run: |
           set -x

diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 1f9ae19346..70d8eda08f 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -44,6 +44,7 @@
 from torchao.testing.training.test_utils import get_test_float8_linear_config
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
+    get_current_accelerator_device,
     is_MI300,
     is_ROCM,
     is_sm_at_least_89,
@@ -52,6 +53,7 @@
 random.seed(0)
 torch.manual_seed(0)

+_DEVICE = get_current_accelerator_device()


 def bitwise_identical(a: Float8TrainingTensor, b: Float8TrainingTensor) -> bool:
@@ -231,11 +233,14 @@ def test_axiswise_reshape(self):
             (ScalingGranularity.TENSORWISE, ScalingGranularity.AXISWISE),
         ],
     )
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
+    @unittest.skipIf(
+        torch.accelerator.is_available() and not is_sm_at_least_90(),
+        "Requires CUDA capability >= 9.0",
+    )
     def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
-        a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
-        b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
+        a = torch.randn(*a_shape, dtype=torch.bfloat16, device=_DEVICE)
+        b = torch.randn(64, 32, dtype=torch.bfloat16, device=_DEVICE)

         linear_mm_config = LinearMMConfig()

@@ -264,7 +269,7 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
         sqnr = compute_error(c_ref, c_fp8_compute)
         assert sqnr >= 25.0

-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
     def test_fp8_dtype(
         self,
     ):
@@ -329,7 +334,7 @@ def _test_linear_impl(
     @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
     @pytest.mark.parametrize("linear_bias", [False, True])
     @pytest.mark.parametrize("use_ac", [False, True])
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
     def test_linear_from_config_params(
         self,
         x_shape,
@@ -341,8 +346,8 @@ def test_linear_from_config_params(
         linear_bias: bool,
         use_ac: bool,
     ):
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+        x = torch.randn(*x_shape, device=_DEVICE, dtype=linear_dtype)
+        m_ref = nn.Linear(16, 32, bias=linear_bias, device=_DEVICE, dtype=linear_dtype)

         config = get_test_float8_linear_config(
             scaling_type_input,
@@ -386,8 +391,8 @@ def test_linear_from_recipe(
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+        x = torch.randn(*x_shape, device=_DEVICE, dtype=linear_dtype)
+        m_ref = nn.Linear(16, 32, bias=linear_bias, device=_DEVICE, dtype=linear_dtype)
         config = Float8LinearConfig.from_recipe_name(recipe_name)
         self._test_linear_impl(
             x,
@@ -409,7 +414,7 @@ def test_linear_from_recipe(
             Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
         ],
     )
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
     def test_autocast_outputs(
         self,
         emulate: bool,
@@ -417,8 +422,8 @@ def test_autocast_outputs(
         recipe_name: Float8LinearRecipeName,
     ):
         m_ref = nn.Sequential(
-            nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
-            nn.Linear(32, 32, device="cuda", dtype=linear_dtype),
+            nn.Linear(32, 32, device=_DEVICE, dtype=linear_dtype),
+            nn.Linear(32, 32, device=_DEVICE, dtype=linear_dtype),
         )
         config = Float8LinearConfig.from_recipe_name(recipe_name)
         # work around config being frozen
@@ -427,16 +432,16 @@ def test_autocast_outputs(
         m = convert_to_float8_training(copy.deepcopy(m_ref), config=config)

         # autocast off
-        x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
+        x = torch.randn(16, 32, device=_DEVICE, dtype=linear_dtype)
         y = m(x)
         assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

         # autocast on
-        with torch.autocast("cuda"):
+        with torch.autocast(_DEVICE.type):
             y = m(x)
         assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

-        with torch.autocast("cuda", dtype=torch.bfloat16):
+        with torch.autocast(_DEVICE.type, dtype=torch.bfloat16):
             y = m(x)
         assert y.dtype == torch.bfloat16, (
             f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
@@ -454,18 +459,25 @@ def test_repr(self):
         s = m.__repr__()
         assert "i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2" in s

-    @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA 8.9 not available"
+    )
     def test_inference_mode(self):
-        x = torch.randn(32, 32, device="cuda")
-        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        x = torch.randn(32, 32, device=_DEVICE)
+        m = nn.Sequential(nn.Linear(32, 32)).to(_DEVICE)
         m = convert_to_float8_training(m)
         with torch.inference_mode(mode=True):
             m(x)

-    @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_89(),
+        "CUDA arch 8.9 not available",
+    )
     def test_quantize(self):
-        x = torch.randn(32, 32, device="cuda")
-        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        x = torch.randn(32, 32, device=_DEVICE)
+        m = nn.Sequential(nn.Linear(32, 32)).to(_DEVICE)
         m = convert_to_float8_training(m)
         assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
         from torchao.quantization import Float8WeightOnlyConfig, quantize_
@@ -479,8 +491,9 @@ def test_quantize(self):


 class TestScaledMM:
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
     @unittest.skipIf(
-        not is_sm_at_least_89(),
+        torch.cuda.is_available() and not is_sm_at_least_89(),
         "CUDA not available",
     )
     @pytest.mark.parametrize(
@@ -493,8 +506,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         output_dtype = base_dtype
         compare_type = torch.float32

-        a = torch.randn(16, 16, device="cuda", dtype=base_dtype)
-        b = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
+        a = torch.randn(16, 16, device=_DEVICE, dtype=base_dtype)
+        b = torch.randn(32, 16, device=_DEVICE, dtype=base_dtype).t()

         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
@@ -525,10 +538,13 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
             atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

-    @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
+    @unittest.skipIf(
+        torch.cuda.is_available() and not is_sm_at_least_89(), "CUDA not available"
+    )
     def test_different_configs_error(self):
-        x_fp32 = torch.randn(16, 16, device="cuda")
-        x_scale = torch.tensor(1.0, device="cuda")
+        x_fp32 = torch.randn(16, 16, device=_DEVICE)
+        x_scale = torch.tensor(1.0, device=_DEVICE)
         fp8_dtype = e4m3_dtype
         linear_config_a = LinearMMConfig(
             ScaledMMConfig(False, True, False, False),
@@ -573,8 +589,8 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         input_dtype = e4m3_dtype
         compare_type = torch.float32

-        a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
-        b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
+        a = torch.randn(16, 41, device=_DEVICE, dtype=base_dtype)
+        b = torch.randn(41, 128, device=_DEVICE, dtype=base_dtype)

         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
@@ -652,7 +668,7 @@ class TestNumerics:
             torch.float8_e5m2fnuz,
         ],
     )
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,
         # the result may not be representable in fp16. Verify that
@@ -671,7 +687,7 @@ def test_small_amax_float16(self, float8_dtype):
         FP16_MAX_POS = torch.finfo(torch.float16).max

         target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12)
-        x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
+        x = torch.tensor([target_amax], dtype=torch.float16, device=_DEVICE)
         scale = tensor_to_scale(x, float8_dtype)
         assert not torch.any(torch.isinf(scale))

diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index 04f03bb0ee..691447ed8d 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -33,10 +33,13 @@
 )
 from torchao.testing.training.test_utils import get_test_float8_linear_config
 from torchao.utils import (
+    get_current_accelerator_device,
     is_sm_at_least_89,
     is_sm_at_least_90,
 )

+_DEVICE = get_current_accelerator_device()
+

 def _test_compile_base(
     backend: str,
@@ -49,9 +52,9 @@ def _test_compile_base(
     x_shape = (16, 16)
     linear_dtype = torch.bfloat16

-    x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_()
+    x = torch.randn(*x_shape, device=_DEVICE, dtype=linear_dtype).requires_grad_()
     x_ref = copy.deepcopy(x)
-    m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+    m_ref = nn.Linear(16, 32, bias=True, device=_DEVICE, dtype=linear_dtype)

     m_fp8 = Float8Linear.from_float(
         copy.deepcopy(m_ref),
@@ -86,7 +89,7 @@ def _test_compile_base(
 )
 @pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+@unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
 def test_eager_only(
     fullgraph,
     emulate: bool,
@@ -122,7 +125,7 @@ def test_eager_only(
     [ScalingType.DYNAMIC],
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+@unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
 def test_aot_eager(
     fullgraph,
     emulate: bool,
@@ -157,8 +160,9 @@ def test_aot_eager(
     "scaling_type_grad_output",
     [ScalingType.DYNAMIC],
 )
+@unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
 @unittest.skipIf(
-    not torch.cuda.is_available() or not is_sm_at_least_89(),
+    torch.accelerator.is_available() and not is_sm_at_least_89(),
     "CUDA with float8 support not available",
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -196,8 +200,10 @@ def test_inductor_from_config_params(
         Float8LinearRecipeName.ROWWISE_WITH_GW_HP,
     ],
 )
+@unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
 @unittest.skipIf(
-    not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
+    torch.accelerator.is_available() and not is_sm_at_least_90(),
+    "CUDA with capability 9.0 or greater not available",
 )
 def test_inductor_from_recipe(recipe_name):
     torch._dynamo.reset()
@@ -231,24 +237,26 @@ def forward(self, x):
             return x_fp8

     # TODO(future): figure out why the test below fails on CUDA capability 8.9
+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_at_least_90(),
+        torch.cuda.is_available() and not is_sm_at_least_90(),
         "CUDA with capability 9.0 or greater not available",
     )
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8TrainingTensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
-        mod = self.MockLinear(graph_break=True).cuda()
+        mod = self.MockLinear(graph_break=True).to(_DEVICE)
         compiled_mod = copy.deepcopy(mod)
         compiled_mod = torch.compile(compiled_mod, backend=cnts)
-        x = torch.randn(16, 16, device="cuda")
+        x = torch.randn(16, 16, device=_DEVICE)
         y_eager = mod(x)
         y_compiled = compiled_mod(x)
         self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
         torch.testing.assert_close(y_eager, y_compiled)

+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_at_least_89(),
+        torch.cuda.is_available() and not is_sm_at_least_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_input(self):
@@ -258,8 +266,8 @@ def to_float(x):
             return x.to_original_precision()

         cnts = CompileCounterWithBackend("inductor")
-        mod = self.MockLinear(graph_break=False).cuda()
-        x = torch.randn(2, 2, device="cuda")
+        mod = self.MockLinear(graph_break=False).to(_DEVICE)
+        x = torch.randn(2, 2, device=_DEVICE)
         compiled_to_float = torch.compile(to_float, backend=cnts)
         y = mod(x)
         y2_eager = to_float(y)
@@ -271,16 +279,17 @@ def to_float(x):
         )
         torch.testing.assert_close(y2_eager, y2_compiled)

+    @unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
     @unittest.skipIf(
-        not torch.cuda.is_available() or not is_sm_at_least_89(),
+        torch.cuda.is_available() and not is_sm_at_least_89(),
         "CUDA with float8 support not available",
     )
     def test_float8_graph_output(self):
         """Test that having Float8TrainingTensor object as a graph output works"""
         cnts = CompileCounterWithBackend("inductor")
-        mod = self.MockLinear(graph_break=False).cuda()
+        mod = self.MockLinear(graph_break=False).to(_DEVICE)
         compiled_mod = torch.compile(mod, backend=cnts)
-        x = torch.randn(16, 16, device="cuda")
+        x = torch.randn(16, 16, device=_DEVICE)
         y_compiled = compiled_mod(x)

         self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
@@ -318,8 +327,9 @@ def __exit__(self, *args):
         sys.stderr = self.sys_stderr


+@unittest.skipIf(not torch.accelerator.is_available(), "GPU not available")
 @unittest.skipIf(
-    not is_sm_at_least_89(),
+    torch.cuda.is_available() and not is_sm_at_least_89(),
     "CUDA not available",
 )
 @pytest.mark.parametrize(
@@ -342,7 +352,7 @@ def test_dynamic_scale_numeric_parity(
 ):
     scaling_type_weight = ScalingType.DYNAMIC
     torch.manual_seed(42)
-    hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
+    hp_tensor1 = torch.randn(16, 16, device=_DEVICE, dtype=dtype)
     hp_tensor2 = hp_tensor1.detach().clone()
     float8_config = Float8LinearConfig(
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
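Note: the test changes above depend on `get_current_accelerator_device()` from `torchao.utils` and on the `torch.accelerator` API; the helper's implementation is not part of this diff. A minimal sketch of what such a helper might look like, assuming the `torch.accelerator` namespace available in recent PyTorch releases (the actual torchao code may differ):

import torch

def get_current_accelerator_device() -> torch.device:
    # Illustrative sketch only, not the torchao implementation: return the
    # current accelerator (CUDA, XPU, MPS, ...) when one is present, otherwise
    # fall back to CPU so the tests can still construct tensors.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

The tests then use the returned device uniformly, e.g. `device=_DEVICE` for tensor creation and `torch.autocast(_DEVICE.type)` for autocast regions, so the same test body runs on both CUDA and XPU runners.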