-
Notifications
You must be signed in to change notification settings - Fork 31.2k
Fix TF32 API deprecation for PyTorch version #42410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1a602e6
f20a9fc
bc52057
ab3a233
57a436b
69e02a1
f048179
4e59a40
9c3d5d6
8273ecf
e6eb963
01b723f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -58,7 +58,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[ | |
| # pick the first item of the list as best guess (it's almost always a list of length 1 anyway) | ||
| distribution_name = pkg_name if pkg_name in distributions else distributions[0] | ||
| package_version = importlib.metadata.version(distribution_name) | ||
| except (importlib.metadata.PackageNotFoundError, KeyError): | ||
| except importlib.metadata.PackageNotFoundError: | ||
| # If we cannot find the metadata (because of editable install for example), try to import directly. | ||
| # Note that this branch will almost never be run, so we do not import packages for nothing here | ||
| package = importlib.import_module(pkg_name) | ||
|
|
@@ -87,7 +87,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[ | |
| TORCHAO_MIN_VERSION = "0.4.0" | ||
| AUTOROUND_MIN_VERSION = "0.5.0" | ||
| TRITON_MIN_VERSION = "1.0.0" | ||
| KERNELS_MIN_VERSION = "0.9.0" | ||
|
|
||
|
|
||
| @lru_cache | ||
|
|
@@ -503,6 +502,27 @@ def is_torch_tf32_available() -> bool: | |
| return True | ||
|
|
||
|
|
||
| @lru_cache | ||
| def _set_tf32_mode(enable: bool) -> None: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's fine to use a non-private name for the function, like `set_tf32_mode`. |
||
| """ | ||
| Set TF32 mode using the appropriate PyTorch API. | ||
| For PyTorch 2.9+, uses the new fp32_precision API. | ||
| For older versions, uses the legacy allow_tf32 flags. | ||
| Args: | ||
| enable: Whether to enable TF32 mode | ||
| """ | ||
| import torch | ||
|
|
||
| pytorch_version = version.parse(get_torch_version()) | ||
| if pytorch_version >= version.parse("2.9.0"): | ||
| precision_mode = "tf32" if enable else "ieee" | ||
| torch.backends.cuda.matmul.fp32_precision = precision_mode | ||
| torch.backends.cudnn.fp32_precision = precision_mode | ||
| else: | ||
| torch.backends.cuda.matmul.allow_tf32 = enable | ||
| torch.backends.cudnn.allow_tf32 = enable | ||
|
|
||
|
|
||
| @lru_cache | ||
| def is_torch_flex_attn_available() -> bool: | ||
| return is_torch_available() and version.parse(get_torch_version()) >= version.parse("2.5.0") | ||
|
|
@@ -514,9 +534,8 @@ def is_kenlm_available() -> bool: | |
|
|
||
|
|
||
| @lru_cache | ||
| def is_kernels_available(MIN_VERSION: str = KERNELS_MIN_VERSION) -> bool: | ||
| is_available, kernels_version = _is_package_available("kernels", return_version=True) | ||
| return is_available and version.parse(kernels_version) >= version.parse(MIN_VERSION) | ||
| def is_kernels_available() -> bool: | ||
| return _is_package_available("kernels") | ||
|
|
||
|
|
||
| @lru_cache | ||
|
|
@@ -973,13 +992,13 @@ def is_quark_available() -> bool: | |
| @lru_cache | ||
| def is_fp_quant_available(): | ||
| is_available, fp_quant_version = _is_package_available("fp_quant", return_version=True) | ||
| return is_available and version.parse(fp_quant_version) >= version.parse("0.3.2") | ||
| return is_available and version.parse(fp_quant_version) >= version.parse("0.2.0") | ||
|
|
||
|
|
||
| @lru_cache | ||
| def is_qutlass_available(): | ||
| is_available, qutlass_version = _is_package_available("qutlass", return_version=True) | ||
| return is_available and version.parse(qutlass_version) >= version.parse("0.2.0") | ||
| return is_available and version.parse(qutlass_version) >= version.parse("0.1.0") | ||
|
|
||
|
|
||
| @lru_cache | ||
|
|
@@ -1178,12 +1197,9 @@ def is_mistral_common_available() -> bool: | |
|
|
||
| @lru_cache | ||
| def is_opentelemetry_available() -> bool: | ||
| try: | ||
| return _is_package_available("opentelemetry") and version.parse( | ||
| importlib.metadata.version("opentelemetry-api") | ||
| ) >= version.parse("1.30.0") | ||
| except Exception as _: | ||
| return False | ||
| return _is_package_available("opentelemetry") and version.parse( | ||
| importlib.metadata.version("opentelemetry-api") | ||
| ) >= version.parse("1.30.0") | ||
|
|
||
|
|
||
| def check_torch_load_is_safe() -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,12 @@ | ||
| import logging | ||
| import sys | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import pytest | ||
| from packaging import version | ||
|
|
||
| from transformers.testing_utils import run_test_using_subprocess | ||
| from transformers.utils.import_utils import clear_import_cache | ||
| from transformers.utils.import_utils import _set_tf32_mode, clear_import_cache | ||
|
|
||
|
|
||
| @run_test_using_subprocess | ||
|
|
@@ -24,3 +29,32 @@ def test_clear_import_cache(): | |
|
|
||
| assert "transformers.models.auto.modeling_auto" in sys.modules | ||
| assert modeling_auto.__name__ == "transformers.models.auto.modeling_auto" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "torch_version,enable,expected", | ||
| [ | ||
| ("2.9.0", False, "ieee"), | ||
| ("2.10.0", True, "tf32"), | ||
| ("2.10.0", False, "ieee"), | ||
| ("2.8.1", True, True), | ||
| ("2.8.1", False, False), | ||
| ("2.9.0", True, "tf32"), | ||
| ], | ||
| ) | ||
| def test_set_tf32_mode(torch_version, enable, expected, caplog): | ||
| caplog.set_level(logging.INFO) | ||
| # Use the full module path for patch | ||
| with patch("transformers.utils.import_utils.get_torch_version", return_value=torch_version): | ||
| mock_torch = MagicMock() | ||
| with patch.dict("transformers.utils.import_utils.__dict__", {"torch": mock_torch}): | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Rocketknight1 The test case passes with `import torch` at module level in the import_utils file where `_set_tf32_mode` is defined, but that is not accepted, so I moved the import back inside my new method. That broke this line, so I fixed it as well. However, the tests are having issues with the mock. I'm keeping it as a skip for now — maybe it can be worked on later, or if someone can help now, that would be great.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can just remove this test! The only real way to fully test the function would be to have multiple versions of |
||
| _set_tf32_mode(enable) | ||
| pytorch_ver = version.parse(torch_version) | ||
| if pytorch_ver >= version.parse("2.9.0"): | ||
| pytest.skip("Skipping test for PyTorch >= 2.9.0") | ||
| # assert mock_torch.backends.cuda.matmul.fp32_precision == expected | ||
| # assert mock_torch.backends.cudnn.fp32_precision == expected | ||
| else: | ||
| pytest.skip("Skipping test for PyTorch < 2.9.0") | ||
| # assert mock_torch.backends.cuda.matmul.allow_tf32 == expected | ||
| # assert mock_torch.backends.cudnn.allow_tf32 == expected | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes are due to running `make style` or merging main. How can I remove the unrelated code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmn, we probably don't want these changes in the PR! It makes it hard to review and I'm a bit worried that it'll actually revert some code. Can you try getting rid of them, maybe with one of the following:
- Rebase onto the latest `main` commit
- `pip install -e .[quality]` to get the latest style tools
- Diff against `main` and revert any of these unrelated changes?