-
Notifications
You must be signed in to change notification settings - Fork 31.2k
Fix tf32 api deprecation for Pytorch version #42410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| # Use the full module path for patch | ||
| with patch("transformers.utils.import_utils.get_torch_version", return_value=torch_version): | ||
| mock_torch = MagicMock() | ||
| with patch.dict("transformers.utils.import_utils.__dict__", {"torch": mock_torch}): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Rocketknight1 testcase passes with import torch at module level in import_utils file where _set_tf32_mode is defined, but that is not accepted so moved it back inside my new method. but that broke this line so fixed it aswell. But tests are having issue with mock. Keeping it to skip , may be can be worked on later , or if someone can help now that will be great.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can just remove this test! The only real way to fully test the function would be to have multiple versions of torch in the CI, which is quite hard. So in reality, we'll only be testing the installed version of torch anyway, and if the function is failing for that version then we'll see errors elsewhere anyway.
| distribution_name = pkg_name if pkg_name in distributions else distributions[0] | ||
| package_version = importlib.metadata.version(distribution_name) | ||
| except (importlib.metadata.PackageNotFoundError, KeyError): | ||
| except importlib.metadata.PackageNotFoundError: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changes due to running make style or main merge. How can I remove unrelated code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmn, we probably don't want these changes in the PR! It makes it hard to review and I'm a bit worried that it'll actually revert some code. Can you try getting rid of them, maybe with one of the following:
- Rebase/merge onto the latest
maincommit pip install -e .[quality]to get the latest style tools- Compare the edited files against the equivalent version in
mainand revert any of these unrelated changes?
| def _set_tf32_mode(enable: bool) -> None: | ||
| """ | ||
| Set TF32 mode using the appropriate PyTorch API. | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
main method that will help us set correct API use
Rocketknight1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Patch mostly looks good! I added some comments about the extra changes - those can be fiddly, and sometimes result from messy merges or rebases. If you can't get rid of them easily you may want to just start a new PR by branching from the latest main commit and tagging me there instead.
| distribution_name = pkg_name if pkg_name in distributions else distributions[0] | ||
| package_version = importlib.metadata.version(distribution_name) | ||
| except (importlib.metadata.PackageNotFoundError, KeyError): | ||
| except importlib.metadata.PackageNotFoundError: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmn, we probably don't want these changes in the PR! It makes it hard to review and I'm a bit worried that it'll actually revert some code. Can you try getting rid of them, maybe with one of the following:
- Rebase/merge onto the latest
maincommit pip install -e .[quality]to get the latest style tools- Compare the edited files against the equivalent version in
mainand revert any of these unrelated changes?
|
|
||
|
|
||
| @lru_cache | ||
| def _set_tf32_mode(enable: bool) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine to use a non-private name for the function, like enable_tf32 or even torch_enable_tensorfloat32 for clarity.
| # Use the full module path for patch | ||
| with patch("transformers.utils.import_utils.get_torch_version", return_value=torch_version): | ||
| mock_torch = MagicMock() | ||
| with patch.dict("transformers.utils.import_utils.__dict__", {"torch": mock_torch}): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can just remove this test! The only real way to fully test the function would be to have multiple versions of torch in the CI, which is quite hard. So in reality, we'll only be testing the installed version of torch anyway, and if the function is failing for that version then we'll see errors elsewhere anyway.
|
Sure I will create new PR, and remove tests, and update method name. |
What does this PR do?
The ask is to use fp32_precision instead of allow_tf32 for
Pytorch version >= 2.9.0
as pointed out in this doc mentioned in the issue
I have also added test cases
Fixes #42371 (issue)
Can you review @Rocketknight1 @ArthurZucker
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.