Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ run_test_config_mgpu(){
#run in parallel on CI and it affects timing
run_default_fa 1 test_gemm_sm_count.py
run_default_fa 3 test_sanity_import.py
run_default_fa 3 distributed/test_cast_master_weights_to_fp8.py
run_default_fa 2 distributed/test_fusible_ops.py
run_default_fa 2 distributed/test_numerics.py
run_default_fa 1 distributed/test_torch_fsdp2.py
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# This file was modified for portability to AMDGPU.
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

Expand Down Expand Up @@ -717,7 +719,9 @@ def run_parallel_tests() -> None:
@pytest.mark.parametrize("world_size", [2])
def test_cast_master_weights_to_fp8(world_size: int) -> None:
"""Launch parallel job that runs parallel tests"""
python_exe = pathlib.Path(sys.executable).resolve()
# ROCm: Use the executable as-is; do not call resolve(), or a venv symlink may
# resolve to the system Python, which does not have torch/site-packages.
python_exe = pathlib.Path(sys.executable)
current_file = pathlib.Path(__file__).resolve()
command = [
python_exe,
Expand Down
10 changes: 6 additions & 4 deletions transformer_engine/pytorch/tensor/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# This file was modified for portability to AMDGPU.
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

Expand All @@ -15,7 +17,7 @@
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
from ..optimizers.multi_tensor_apply import multi_tensor_applier
from ..utils import is_non_tn_fp8_gemm_supported
from ..utils import is_non_tn_fp8_gemm_supported, is_fp8_fnuz


def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
Expand Down Expand Up @@ -293,7 +295,7 @@ def _cast_master_weights_to_fp8_current_scaling(
# Step 3: Update scales and scale_invs.
# ---------------------------------------------------------------------------------------------
if fp8_dtype == tex.DType.kFloat8E4M3:
max_fp8 = 448.0
max_fp8 = 240.0 if is_fp8_fnuz() else 448.0
elif fp8_dtype == tex.DType.kFloat8E5M2:
max_fp8 = 57344.0
else:
Expand Down Expand Up @@ -424,7 +426,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
# Step 3: Update scales and scale_invs.
# ---------------------------------------------------------------------------------------------
if fp8_dtype == tex.DType.kFloat8E4M3:
max_fp8 = 448.0
max_fp8 = 240.0 if is_fp8_fnuz() else 448.0
elif fp8_dtype == tex.DType.kFloat8E5M2:
max_fp8 = 57344.0
else:
Expand Down