6 changes: 2 additions & 4 deletions conftest.py
@@ -31,6 +31,7 @@
patch_testing_methods_to_collect_info,
patch_torch_compile_force_graph,
)
from transformers.utils.import_utils import _set_tf32_mode


NOT_DEVICE_TESTS = {
@@ -137,12 +138,9 @@ def check_output(self, want, got, optionflags):
doctest.DocTestParser = HfDocTestParser

if is_torch_available():
import torch

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
# We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615
torch.backends.cudnn.allow_tf32 = False

_set_tf32_mode(False)
# patch `torch.compile`: if `TORCH_COMPILE_FORCE_FULLGRAPH=1` (or values considered as true, e.g. yes, y, etc.),
# the patched version will always run with `fullgraph=True`.
patch_torch_compile_force_graph()
13 changes: 5 additions & 8 deletions src/transformers/training_args.py
@@ -53,7 +53,7 @@
requires_backends,
)
from .utils.generic import strtobool
from .utils.import_utils import is_optimum_neuron_available
from .utils.import_utils import _set_tf32_mode, is_optimum_neuron_available


logger = logging.get_logger(__name__)
@@ -379,7 +379,7 @@ class TrainingArguments:
metric values.
tf32 (`bool`, *optional*):
Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends
on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to
on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32` (or, for PyTorch 2.9+, `torch.backends.cuda.matmul.fp32_precision`). For more details please refer to
the [TF32](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) documentation. This is an
experimental API and it may change.
ddp_backend (`str`, *optional*):
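
For concreteness, a hedged usage sketch of the `tf32` flag documented above (editorial, not part of the diff; assumes an Ampere-or-newer GPU, CUDA 11+ and a recent torch, otherwise `TrainingArguments` rejects `tf32=True`):

```python
# Hypothetical example: opt into TF32 matmul/cuDNN kernels via TrainingArguments.
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", tf32=True)  # raises ValueError on unsupported setups
```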
@@ -1604,8 +1604,7 @@ def __post_init__(self):
if is_torch_musa_available():
torch.backends.mudnn.allow_tf32 = True
else:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
_set_tf32_mode(True)
else:
logger.warning(
"The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here."
@@ -1616,17 +1615,15 @@
if is_torch_musa_available():
torch.backends.mudnn.allow_tf32 = True
else:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
_set_tf32_mode(True)
else:
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
else:
if is_torch_tf32_available():
if is_torch_musa_available():
torch.backends.mudnn.allow_tf32 = False
else:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
_set_tf32_mode(False)
# no need to assert on else

if self.report_to == "all" or self.report_to == ["all"]:
42 changes: 29 additions & 13 deletions src/transformers/utils/import_utils.py
@@ -58,7 +58,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
# pick the first item of the list as best guess (it's almost always a list of length 1 anyway)
distribution_name = pkg_name if pkg_name in distributions else distributions[0]
package_version = importlib.metadata.version(distribution_name)
except (importlib.metadata.PackageNotFoundError, KeyError):
except importlib.metadata.PackageNotFoundError:
@khushali9 (Author), Nov 26, 2025:

These changes came from running `make style` or merging `main`. How can I remove the unrelated code?

Reviewer (Member):

Hmm, we probably don't want these changes in the PR! It makes the diff hard to review, and I'm a bit worried that it'll actually revert some code. Can you try getting rid of them, maybe with one of the following:

  1. Rebase/merge onto the latest main commit
  2. `pip install -e .[quality]` to get the latest style tools
  3. Compare the edited files against the equivalent version in main and revert any of these unrelated changes?

# If we cannot find the metadata (because of editable install for example), try to import directly.
# Note that this branch will almost never be run, so we do not import packages for nothing here
package = importlib.import_module(pkg_name)
@@ -87,7 +87,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[
TORCHAO_MIN_VERSION = "0.4.0"
AUTOROUND_MIN_VERSION = "0.5.0"
TRITON_MIN_VERSION = "1.0.0"
KERNELS_MIN_VERSION = "0.9.0"


@lru_cache
@@ -503,6 +502,27 @@ def is_torch_tf32_available() -> bool:
return True


@lru_cache
def _set_tf32_mode(enable: bool) -> None:
Reviewer (Member):

I think it's fine to use a non-private name for the function, like `enable_tf32` or even `torch_enable_tensorfloat32` for clarity.

"""
Set TF32 mode using the appropriate PyTorch API.
For PyTorch 2.9+, uses the new fp32_precision API.
For older versions, uses the legacy allow_tf32 flags.
Args:
enable: Whether to enable TF32 mode
"""
import torch

pytorch_version = version.parse(get_torch_version())
if pytorch_version >= version.parse("2.9.0"):
precision_mode = "tf32" if enable else "ieee"
torch.backends.cuda.matmul.fp32_precision = precision_mode
torch.backends.cudnn.fp32_precision = precision_mode
else:
torch.backends.cuda.matmul.allow_tf32 = enable
torch.backends.cudnn.allow_tf32 = enable
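
If the reviewer's naming suggestion above were adopted, one possible shape is a thin public wrapper that delegates to the version-gated setter (a hypothetical sketch, not part of the PR):

```python
def enable_tf32(enable: bool = True) -> None:
    """Public-facing name for the TF32 toggle suggested in review; delegates to
    the version-gated `_set_tf32_mode` above."""
    _set_tf32_mode(enable)
```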


@lru_cache
def is_torch_flex_attn_available() -> bool:
return is_torch_available() and version.parse(get_torch_version()) >= version.parse("2.5.0")
@@ -514,9 +534,8 @@ def is_kenlm_available() -> bool:


@lru_cache
def is_kernels_available(MIN_VERSION: str = KERNELS_MIN_VERSION) -> bool:
is_available, kernels_version = _is_package_available("kernels", return_version=True)
return is_available and version.parse(kernels_version) >= version.parse(MIN_VERSION)
def is_kernels_available() -> bool:
return _is_package_available("kernels")


@lru_cache
@@ -973,13 +992,13 @@ def is_quark_available() -> bool:
@lru_cache
def is_fp_quant_available():
is_available, fp_quant_version = _is_package_available("fp_quant", return_version=True)
return is_available and version.parse(fp_quant_version) >= version.parse("0.3.2")
return is_available and version.parse(fp_quant_version) >= version.parse("0.2.0")


@lru_cache
def is_qutlass_available():
is_available, qutlass_version = _is_package_available("qutlass", return_version=True)
return is_available and version.parse(qutlass_version) >= version.parse("0.2.0")
return is_available and version.parse(qutlass_version) >= version.parse("0.1.0")


@lru_cache
@@ -1178,12 +1197,9 @@ def is_mistral_common_available() -> bool:

@lru_cache
def is_opentelemetry_available() -> bool:
try:
return _is_package_available("opentelemetry") and version.parse(
importlib.metadata.version("opentelemetry-api")
) >= version.parse("1.30.0")
except Exception as _:
return False
return _is_package_available("opentelemetry") and version.parse(
importlib.metadata.version("opentelemetry-api")
) >= version.parse("1.30.0")


def check_torch_load_is_safe() -> None:
36 changes: 35 additions & 1 deletion tests/utils/test_import_utils.py
@@ -1,7 +1,12 @@
import logging
import sys
from unittest.mock import MagicMock, patch

import pytest
from packaging import version

from transformers.testing_utils import run_test_using_subprocess
from transformers.utils.import_utils import clear_import_cache
from transformers.utils.import_utils import _set_tf32_mode, clear_import_cache


@run_test_using_subprocess
@@ -24,3 +29,32 @@ def test_clear_import_cache():

assert "transformers.models.auto.modeling_auto" in sys.modules
assert modeling_auto.__name__ == "transformers.models.auto.modeling_auto"


@pytest.mark.parametrize(
"torch_version,enable,expected",
[
("2.9.0", False, "ieee"),
("2.10.0", True, "tf32"),
("2.10.0", False, "ieee"),
("2.8.1", True, True),
("2.8.1", False, False),
("2.9.0", True, "tf32"),
],
)
def test_set_tf32_mode(torch_version, enable, expected, caplog):
caplog.set_level(logging.INFO)
# Use the full module path for patch
with patch("transformers.utils.import_utils.get_torch_version", return_value=torch_version):
mock_torch = MagicMock()
with patch.dict("transformers.utils.import_utils.__dict__", {"torch": mock_torch}):
@khushali9 (Author):

@Rocketknight1 The test case passes with `import torch` at module level in `import_utils.py`, where `_set_tf32_mode` is defined, but that isn't accepted, so I moved the import back inside my new method. That broke this line, so I fixed it as well. The tests still have issues with the mock, so I'm keeping them as skips for now; maybe they can be worked on later, or if someone can help now, that would be great.

Reviewer (Member):

I think you can just remove this test! The only real way to fully test the function would be to have multiple versions of torch in the CI, which is quite hard. So in reality, we'll only be testing the installed version of torch anyway, and if the function is failing for that version then we'll see errors elsewhere anyway.

_set_tf32_mode(enable)
pytorch_ver = version.parse(torch_version)
if pytorch_ver >= version.parse("2.9.0"):
pytest.skip("Skipping test for PyTorch >= 2.9.0")
# assert mock_torch.backends.cuda.matmul.fp32_precision == expected
# assert mock_torch.backends.cudnn.fp32_precision == expected
else:
pytest.skip("Skipping test for PyTorch < 2.9.0")
# assert mock_torch.backends.cuda.matmul.allow_tf32 == expected
# assert mock_torch.backends.cudnn.allow_tf32 == expected
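
Picking up the reviewer's suggestion above, a minimal sketch of a replacement test that only exercises the installed torch build, so no mocking of torch internals is needed (hypothetical, not part of the PR; assumes the helper keeps the name `_set_tf32_mode`):

```python
import torch
from packaging import version

from transformers.utils.import_utils import _set_tf32_mode, get_torch_version


def test_set_tf32_mode_installed_torch():
    # Apply the setting once, then check whichever backend attributes
    # the installed torch version actually exposes.
    _set_tf32_mode(True)
    if version.parse(get_torch_version()) >= version.parse("2.9.0"):
        assert torch.backends.cuda.matmul.fp32_precision == "tf32"
        assert torch.backends.cudnn.fp32_precision == "tf32"
    else:
        assert torch.backends.cuda.matmul.allow_tf32
        assert torch.backends.cudnn.allow_tf32
```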
3 changes: 2 additions & 1 deletion utils/modular_model_detector.py
@@ -118,6 +118,7 @@
import transformers
from transformers import AutoModel, AutoTokenizer
from transformers.utils import logging as transformers_logging
from transformers.utils.import_utils import _set_tf32_mode


# ANSI color codes for CLI output styling
@@ -247,7 +248,7 @@ def __init__(self, hub_dataset: str):
logging.getLogger(name).setLevel(logging.ERROR)
huggingface_hub_logging.set_verbosity_error()
transformers_logging.set_verbosity_error()
torch.backends.cuda.matmul.allow_tf32 = True
_set_tf32_mode(True)
torch.set_grad_enabled(False)

self.models_root = MODELS_ROOT