diff --git a/ci/pytorch.sh b/ci/pytorch.sh index bb972dfa6..30bbb0d2a 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -93,6 +93,7 @@ run_test_config_mgpu(){ #run in parallel on CI and it affects timing run_default_fa 1 test_gemm_sm_count.py run_default_fa 3 test_sanity_import.py + run_default_fa 3 distributed/test_cast_master_weights_to_fp8.py run_default_fa 2 distributed/test_fusible_ops.py run_default_fa 2 distributed/test_numerics.py run_default_fa 1 distributed/test_torch_fsdp2.py diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 0ff98e6cb..a8fca0577 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -1,4 +1,6 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# This file was modified for portability to AMDGPU. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -717,7 +719,9 @@ def run_parallel_tests() -> None: @pytest.mark.parametrize("world_size", [2]) def test_cast_master_weights_to_fp8(world_size: int) -> None: """Launch parallel job that runs parallel tests""" - python_exe = pathlib.Path(sys.executable).resolve() + # ROCm: Use the executable as-is; do not resolve(), or a venv symlink may resolve to the system + # Python, which does not have torch in its site-packages. + python_exe = pathlib.Path(sys.executable) current_file = pathlib.Path(__file__).resolve() command = [ python_exe, diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 9773e17e6..00b2e5dc1 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -1,4 +1,6 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# This file was modified for portability to AMDGPU. +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -15,7 +17,7 @@ from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply import multi_tensor_applier -from ..utils import is_non_tn_fp8_gemm_supported +from ..utils import is_non_tn_fp8_gemm_supported, is_fp8_fnuz def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): @@ -293,7 +295,7 @@ def _cast_master_weights_to_fp8_current_scaling( # Step 3: Update scales and scale_invs. # --------------------------------------------------------------------------------------------- if fp8_dtype == tex.DType.kFloat8E4M3: - max_fp8 = 448.0 + max_fp8 = 240.0 if is_fp8_fnuz() else 448.0 elif fp8_dtype == tex.DType.kFloat8E5M2: max_fp8 = 57344.0 else: @@ -424,7 +426,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Step 3: Update scales and scale_invs. # --------------------------------------------------------------------------------------------- if fp8_dtype == tex.DType.kFloat8E4M3: - max_fp8 = 448.0 + max_fp8 = 240.0 if is_fp8_fnuz() else 448.0 elif fp8_dtype == tex.DType.kFloat8E5M2: max_fp8 = 57344.0 else: