
Conversation


@khushali9 khushali9 commented Nov 25, 2025

What does this PR do?

The ask is to use fp32_precision instead of allow_tf32 for PyTorch version >= 2.9.0, as pointed out in the doc mentioned in the issue.
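For reference, the API difference looks roughly like this (a hedged sketch based on the PyTorch docs linked from the issue; the fp32_precision attribute names are taken from those docs):

import torch

# PyTorch < 2.9: legacy boolean flags
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# PyTorch >= 2.9: new fp32_precision API (accepts "tf32" or "ieee")
torch.backends.cuda.matmul.fp32_precision = "tf32"
torch.backends.cudnn.conv.fp32_precision = "tf32"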

I have also added test cases.

Fixes #42371 (issue)

Can you review, @Rocketknight1 @ArthurZucker?

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

# Use the full module path for patch
with patch("transformers.utils.import_utils.get_torch_version", return_value=torch_version):
    mock_torch = MagicMock()
    with patch.dict("transformers.utils.import_utils.__dict__", {"torch": mock_torch}):
Author

@Rocketknight1 The test case passes with import torch at module level in the import_utils file where _set_tf32_mode is defined, but that isn't accepted, so I moved the import back inside my new method. That broke this line, so I fixed it as well. However, the tests are still having issues with the mock. I'm marking it as skipped for now; it can be worked on later, or if someone can help now, that would be great.

Member

I think you can just remove this test! The only real way to fully test the function would be to have multiple versions of torch in the CI, which is quite hard. So in reality we'll only be testing the installed version of torch, and if the function is failing for that version we'll see errors elsewhere anyway.

        distribution_name = pkg_name if pkg_name in distributions else distributions[0]
        package_version = importlib.metadata.version(distribution_name)
-    except (importlib.metadata.PackageNotFoundError, KeyError):
+    except importlib.metadata.PackageNotFoundError:
Author

@khushali9 khushali9 Nov 26, 2025

These changes are due to running make style or the merge from main. How can I remove this unrelated code?

Member

Hmn, we probably don't want these changes in the PR! It makes it hard to review and I'm a bit worried that it'll actually revert some code. Can you try getting rid of them, maybe with one of the following:

  1. Rebase/merge onto the latest main commit
  2. pip install -e .[quality] to get the latest style tools
  3. Compare the edited files against the equivalent version in main and revert any of these unrelated changes? (A sketch of this approach is shown after the list.)
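For option 3, a sketch of one way to do that with git (assuming a remote named upstream pointing at the main repository, and that src/transformers/utils/import_utils.py is the affected file):

# Fetch the latest main from the upstream remote
git fetch upstream main

# Inspect what this branch changes in the file relative to main
git diff upstream/main -- src/transformers/utils/import_utils.py

# Restore the file to the version on main (this discards ALL local edits
# to it), then re-apply only the intended changes by hand and commit
git checkout upstream/main -- src/transformers/utils/import_utils.py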

def _set_tf32_mode(enable: bool) -> None:
    """
    Set TF32 mode using the appropriate PyTorch API.
Author

This is the main method that selects the correct API for the installed PyTorch version.
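A minimal sketch of what the helper could look like (assuming the version gate described above and the fp32_precision attribute names from the PyTorch docs; this is not the exact code from the PR):

from functools import lru_cache

from packaging import version


@lru_cache
def _set_tf32_mode(enable: bool) -> None:
    """Set TF32 mode using the appropriate PyTorch API for the installed version."""
    # Imported inside the function so the module stays importable without torch
    import torch

    if version.parse(torch.__version__) >= version.parse("2.9.0"):
        # New API: fp32_precision takes "tf32" or "ieee"
        precision = "tf32" if enable else "ieee"
        torch.backends.cuda.matmul.fp32_precision = precision
        torch.backends.cudnn.conv.fp32_precision = precision
    else:
        # Legacy boolean flags
        torch.backends.cuda.matmul.allow_tf32 = enable
        torch.backends.cudnn.allow_tf32 = enable

With lru_cache, repeated calls with the same flag skip the function body, so the setting is only applied once per value.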

Member

@Rocketknight1 Rocketknight1 left a comment

Patch mostly looks good! I added some comments about the extra changes - those can be fiddly, and sometimes result from messy merges or rebases. If you can't get rid of them easily you may want to just start a new PR by branching from the latest main commit and tagging me there instead.




@lru_cache
def _set_tf32_mode(enable: bool) -> None:
Member

I think it's fine to use a non-private name for the function, like enable_tf32 or even torch_enable_tensorfloat32 for clarity.


@khushali9 (Author)

Sure, I will create a new PR, remove the tests, and update the method name.
Thanks for your help.
Also, I have torch > 2.9, and only the test for that version passed, which makes sense.
