diff --git a/conftest.py b/conftest.py
index 910e9fcc1766..290088cadca6 100644
--- a/conftest.py
+++ b/conftest.py
@@ -31,6 +31,7 @@
     patch_testing_methods_to_collect_info,
     patch_torch_compile_force_graph,
 )
+from transformers.utils.import_utils import _set_tf32_mode


 NOT_DEVICE_TESTS = {
@@ -137,12 +138,9 @@ def check_output(self, want, got, optionflags):
 doctest.DocTestParser = HfDocTestParser

 if is_torch_available():
-    import torch
-
     # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
     # We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615
-    torch.backends.cudnn.allow_tf32 = False
-
+    _set_tf32_mode(False)
     # patch `torch.compile`: if `TORCH_COMPILE_FORCE_FULLGRAPH=1` (or values considered as true, e.g. yes, y, etc.),
     # the patched version will always run with `fullgraph=True`.
     patch_torch_compile_force_graph()
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index ede0a922cdb6..e41d2e25b1ec 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -53,7 +53,7 @@
     requires_backends,
 )
 from .utils.generic import strtobool
-from .utils.import_utils import is_optimum_neuron_available
+from .utils.import_utils import _set_tf32_mode, is_optimum_neuron_available


 logger = logging.get_logger(__name__)
@@ -379,7 +379,7 @@ class TrainingArguments:
             metric values.
         tf32 (`bool`, *optional*):
             Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends
-            on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to
+            on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32` (or, for PyTorch 2.9+, `torch.backends.cuda.matmul.fp32_precision`). For more details please refer to
             the [TF32](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) documentation. This is an
             experimental API and it may change.
         ddp_backend (`str`, *optional*):
@@ -1604,8 +1604,7 @@ def __post_init__(self):
                     if is_torch_musa_available():
                         torch.backends.mudnn.allow_tf32 = True
                     else:
-                        torch.backends.cuda.matmul.allow_tf32 = True
-                        torch.backends.cudnn.allow_tf32 = True
+                        _set_tf32_mode(True)
             else:
                 logger.warning(
                     "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here."
@@ -1616,8 +1615,7 @@ def __post_init__(self):
                     if is_torch_musa_available():
                         torch.backends.mudnn.allow_tf32 = True
                     else:
-                        torch.backends.cuda.matmul.allow_tf32 = True
-                        torch.backends.cudnn.allow_tf32 = True
+                        _set_tf32_mode(True)
                 else:
                     raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
             else:
@@ -1625,8 +1623,7 @@ def __post_init__(self):
                     if is_torch_musa_available():
                         torch.backends.mudnn.allow_tf32 = False
                     else:
-                        torch.backends.cuda.matmul.allow_tf32 = False
-                        torch.backends.cudnn.allow_tf32 = False
+                        _set_tf32_mode(False)
                 # no need to assert on else

         if self.report_to == "all" or self.report_to == ["all"]:
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index bf2fba35fd0e..8dff708e8bd8 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -58,7 +58,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
             # pick the first item of the list as best guess (it's almost always a list of length 1 anyway)
             distribution_name = pkg_name if pkg_name in distributions else distributions[0]
             package_version = importlib.metadata.version(distribution_name)
-        except (importlib.metadata.PackageNotFoundError, KeyError):
+        except importlib.metadata.PackageNotFoundError:
             # If we cannot find the metadata (because of editable install for example), try to import directly.
             # Note that this branch will almost never be run, so we do not import packages for nothing here
             package = importlib.import_module(pkg_name)
@@ -87,7 +87,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
 TORCHAO_MIN_VERSION = "0.4.0"
 AUTOROUND_MIN_VERSION = "0.5.0"
 TRITON_MIN_VERSION = "1.0.0"
-KERNELS_MIN_VERSION = "0.9.0"


 @lru_cache
@@ -503,6 +502,27 @@ def is_torch_tf32_available() -> bool:
     return True


+@lru_cache
+def _set_tf32_mode(enable: bool) -> None:
+    """
+    Set TF32 mode using the appropriate PyTorch API.
+    For PyTorch 2.9+, uses the new fp32_precision API.
+    For older versions, uses the legacy allow_tf32 flags.
+    Args:
+        enable: Whether to enable TF32 mode
+    """
+    import torch
+
+    pytorch_version = version.parse(get_torch_version())
+    if pytorch_version >= version.parse("2.9.0"):
+        precision_mode = "tf32" if enable else "ieee"
+        torch.backends.cuda.matmul.fp32_precision = precision_mode
+        torch.backends.cudnn.fp32_precision = precision_mode
+    else:
+        torch.backends.cuda.matmul.allow_tf32 = enable
+        torch.backends.cudnn.allow_tf32 = enable
+
+
 @lru_cache
 def is_torch_flex_attn_available() -> bool:
     return is_torch_available() and version.parse(get_torch_version()) >= version.parse("2.5.0")
@@ -514,9 +534,8 @@ def is_kenlm_available() -> bool:


 @lru_cache
-def is_kernels_available(MIN_VERSION: str = KERNELS_MIN_VERSION) -> bool:
-    is_available, kernels_version = _is_package_available("kernels", return_version=True)
-    return is_available and version.parse(kernels_version) >= version.parse(MIN_VERSION)
+def is_kernels_available() -> bool:
+    return _is_package_available("kernels")


 @lru_cache
@@ -973,13 +992,13 @@ def is_quark_available() -> bool:
 @lru_cache
 def is_fp_quant_available():
     is_available, fp_quant_version = _is_package_available("fp_quant", return_version=True)
-    return is_available and version.parse(fp_quant_version) >= version.parse("0.3.2")
+    return is_available and version.parse(fp_quant_version) >= version.parse("0.2.0")


 @lru_cache
 def is_qutlass_available():
     is_available, qutlass_version = _is_package_available("qutlass", return_version=True)
-    return is_available and version.parse(qutlass_version) >= version.parse("0.2.0")
+    return is_available and version.parse(qutlass_version) >= version.parse("0.1.0")


 @lru_cache
@@ -1178,12 +1197,9 @@ def is_mistral_common_available() -> bool:

 @lru_cache
 def is_opentelemetry_available() -> bool:
-    try:
-        return _is_package_available("opentelemetry") and version.parse(
-            importlib.metadata.version("opentelemetry-api")
-        ) >= version.parse("1.30.0")
-    except Exception as _:
-        return False
+    return _is_package_available("opentelemetry") and version.parse(
+        importlib.metadata.version("opentelemetry-api")
+    ) >= version.parse("1.30.0")


 def check_torch_load_is_safe() -> None:
diff --git a/tests/utils/test_import_utils.py b/tests/utils/test_import_utils.py
index fe616e9cfbe2..d18c6499400c 100644
--- a/tests/utils/test_import_utils.py
+++ b/tests/utils/test_import_utils.py
@@ -1,7 +1,11 @@
 import sys
+from unittest.mock import MagicMock, patch
+
+import pytest
+from packaging import version

 from transformers.testing_utils import run_test_using_subprocess
-from transformers.utils.import_utils import clear_import_cache
+from transformers.utils.import_utils import _set_tf32_mode, clear_import_cache


 @run_test_using_subprocess
@@ -24,3 +28,31 @@ def test_clear_import_cache():

     assert "transformers.models.auto.modeling_auto" in sys.modules
     assert modeling_auto.__name__ == "transformers.models.auto.modeling_auto"
+
+
+@pytest.mark.parametrize(
+    "torch_version,enable,expected",
+    [
+        ("2.9.0", False, "ieee"),
+        ("2.10.0", True, "tf32"),
+        ("2.10.0", False, "ieee"),
+        ("2.8.1", True, True),
+        ("2.8.1", False, False),
+        ("2.9.0", True, "tf32"),
+    ],
+)
+def test_set_tf32_mode(torch_version, enable, expected):
+    mock_torch = MagicMock()
+    # Patch the reported torch version, and route the function-local `import torch` inside
+    # `_set_tf32_mode` to a mock so no real backend state is touched.
+    with patch("transformers.utils.import_utils.get_torch_version", return_value=torch_version):
+        with patch.dict(sys.modules, {"torch": mock_torch}):
+            # `_set_tf32_mode` is wrapped in `lru_cache`, so reset it before each parametrization.
+            _set_tf32_mode.cache_clear()
+            _set_tf32_mode(enable)
+    if version.parse(torch_version) >= version.parse("2.9.0"):
+        assert mock_torch.backends.cuda.matmul.fp32_precision == expected
+        assert mock_torch.backends.cudnn.fp32_precision == expected
+    else:
+        assert mock_torch.backends.cuda.matmul.allow_tf32 == expected
+        assert mock_torch.backends.cudnn.allow_tf32 == expected
diff --git a/utils/modular_model_detector.py b/utils/modular_model_detector.py
index caf8c4dfba64..4f351cd8f0bc 100644
--- a/utils/modular_model_detector.py
+++ b/utils/modular_model_detector.py
@@ -118,6 +118,7 @@
 import transformers
 from transformers import AutoModel, AutoTokenizer
 from transformers.utils import logging as transformers_logging
+from transformers.utils.import_utils import _set_tf32_mode


 # ANSI color codes for CLI output styling
@@ -247,7 +248,7 @@ def __init__(self, hub_dataset: str):
             logging.getLogger(name).setLevel(logging.ERROR)
         huggingface_hub_logging.set_verbosity_error()
         transformers_logging.set_verbosity_error()
-        torch.backends.cuda.matmul.allow_tf32 = True
+        _set_tf32_mode(True)
         torch.set_grad_enabled(False)

         self.models_root = MODELS_ROOT
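
For reference, the snippet below is a minimal standalone sketch, not part of the patch, of the version-gated TF32 toggle that the new `_set_tf32_mode` helper implements. The `2.9.0` boundary and the `"tf32"`/`"ieee"` precision strings are taken from the diff above; the `set_tf32` name and the direct use of `torch.__version__` are illustrative assumptions only.

from packaging import version

import torch


def set_tf32(enable: bool) -> None:
    # Hypothetical helper mirroring `_set_tf32_mode` from the diff above.
    if version.parse(torch.__version__) >= version.parse("2.9.0"):
        # PyTorch 2.9+ exposes string-valued fp32 precision controls; "ieee" turns TF32 off.
        mode = "tf32" if enable else "ieee"
        torch.backends.cuda.matmul.fp32_precision = mode
        torch.backends.cudnn.fp32_precision = mode
    else:
        # Older releases only expose the boolean allow_tf32 flags.
        torch.backends.cuda.matmul.allow_tf32 = enable
        torch.backends.cudnn.allow_tf32 = enable


if __name__ == "__main__":
    set_tf32(True)   # allow TF32 matmul/cuDNN kernels on Ampere or newer GPUs
    set_tf32(False)  # restore full FP32 precision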