Skip to content

Conversation

@SunMarc
Copy link
Member

@SunMarc SunMarc commented Nov 27, 2025

What does this PR do?

This PR fixes a bunch of code related to fp8 + some enhancement to make the code simpler to maintain.
Related issue #42442
Thanks to @YangKai0616 for spotting those.

@SunMarc SunMarc requested a review from MekkCyber November 27, 2025 14:19
SunMarc and others added 2 commits November 27, 2025 14:20
Co-authored-by: Yang Kai <kai.yang@intel.com>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@YangKai0616
Copy link
Contributor

There's one more question:
when running the tests/quantization/finegrained_fp8/test_fp8.py::FP8QuantizerTest::test_quantized_model_multi_accelerator test, we get result AssertionError: False is not true. The root cause is that when setUpClass loads the FP8 quantized model, PyTorch's caching allocator reserves significantly more reserved memory than the final allocated memory. As a result, Accelerate calculates a larger unused_memory, causing the entire model to be placed on GPU0. But this behavior actually seems reasonable.

For this test, should we explicitly add:

self.__class__.quantized_model = None
backend_empty_cache(torch_device)

Or, considering the case test_save_pretrained_multi_accelerators, should we add , max_memory={0: "3GB", 1: "3GB"} to both of them?

Could I have your thoughts on this? Thanks!

I wrote a simple reproduction script to observe this situation:

import torch
from transformers import FineGrainedFP8Config, AutoModelForCausalLM
from transformers.testing_utils import (
    backend_empty_cache,
    torch_device,
)

if __name__ == "__main__":
    print(f"torch.cuda.memory_reserved(0) before loading model: {torch.cuda.memory_reserved(0)}, torch.cuda.memory_allocated(0): {torch.cuda.memory_allocated(0)}")
    print(f"torch.cuda.memory_reserved(1) before loading model: {torch.cuda.memory_reserved(1)}, torch.cuda.memory_allocated(1): {torch.cuda.memory_allocated(1)}")
    quantization_config = FineGrainedFP8Config()
    quantized_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B", device_map=torch_device, quantization_config=quantization_config
        )
    print(f"torch.cuda.memory_reserved(0) after loading model: {torch.cuda.memory_reserved(0)}, torch.cuda.memory_allocated(0): {torch.cuda.memory_allocated(0)}")
    print(f"torch.cuda.memory_reserved(1) after loading model: {torch.cuda.memory_reserved(1)}, torch.cuda.memory_allocated(1): {torch.cuda.memory_allocated(1)}")

    print("##################################################################################################################################################################")
    quantized_model = None
    backend_empty_cache(torch_device)


    print(f"torch.cuda.memory_reserved(0) before loading model: {torch.cuda.memory_reserved(0)}, torch.cuda.memory_allocated(0): {torch.cuda.memory_allocated(0)}")
    print(f"torch.cuda.memory_reserved(1) before loading model: {torch.cuda.memory_reserved(1)}, torch.cuda.memory_allocated(1): {torch.cuda.memory_allocated(1)}")
    model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B", device_map=torch_device, dtype=torch.float32
        )
    print(f"torch.cuda.memory_reserved(0) after loading model: {torch.cuda.memory_reserved(0)}, torch.cuda.memory_allocated(0): {torch.cuda.memory_allocated(0)}")
    print(f"torch.cuda.memory_reserved(1) after loading model: {torch.cuda.memory_reserved(1)}, torch.cuda.memory_allocated(1): {torch.cuda.memory_allocated(1)}")

The script output is:

torch.cuda.memory_reserved(0) before loading model: 0, torch.cuda.memory_allocated(0): 0
torch.cuda.memory_reserved(1) before loading model: 0, torch.cuda.memory_allocated(1): 0
Loading weights: 100%|██████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 218.13it/s, Materializing param=model.norm.weight]
torch.cuda.memory_reserved(0) after loading model: 6054477824, torch.cuda.memory_allocated(0): 2024268288
torch.cuda.memory_reserved(1) after loading model: 0, torch.cuda.memory_allocated(1): 0
##################################################################################################################################################################
torch.cuda.memory_reserved(0) before loading model: 0, torch.cuda.memory_allocated(0): 0
torch.cuda.memory_reserved(1) before loading model: 0, torch.cuda.memory_allocated(1): 0
Loading weights: 100%|██████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 226.73it/s, Materializing param=model.norm.weight]
torch.cuda.memory_reserved(0) after loading model: 4947181568, torch.cuda.memory_allocated(0): 4943258112
torch.cuda.memory_reserved(1) after loading model: 0, torch.cuda.memory_allocated(1): 0

@SunMarc
Copy link
Member Author

SunMarc commented Nov 28, 2025

There's one more question:
when running the tests/quantization/finegrained_fp8/test_fp8.py::FP8QuantizerTest::test_quantized_model_multi_accelerator test, we get result AssertionError: False is not true. The root cause is that when setUpClass loads the FP8 quantized model, PyTorch's caching allocator reserves significantly more reserved memory than the final allocated memory. As a result, Accelerate calculates a larger unused_memory, causing the entire model to be placed on GPU0. But this behavior actually seems reasonable.

Any idea where the Pytorch catching allocator happens ? We have our own caching allocator but it happens after _get_device_map. Btw, our caching allocator needs some fix as we changed a bit the modeling of the fp8 method

@YangKai0616
Copy link
Contributor

Any idea where the Pytorch catching allocator happens ? We have our own caching allocator but it happens after _get_device_map. Btw, our caching allocator needs some fix as we changed a bit the modeling of the fp8 method

Sorry for the confusion. Regarding torch, I was referring to here. Understood, I'll wait for the fix. Thanks!

@SunMarc
Copy link
Member Author

SunMarc commented Nov 28, 2025

btw @YangKai0616, even when setting _dtype = torch.float32, I don't get the expected output. Can you try this PR to see what results you get ? Even with older version of transformers, I get
Once upon a time, there was a little girl who loved to play

@YangKai0616
Copy link
Contributor

btw @YangKai0616, even when setting _dtype = torch.float32, I don't get the expected output. Can you try this PR to see what results you get ? Even with older version of transformers, I get Once upon a time, there was a little girl who loved to play

Using this PR, I can get the expected output as follows:

================================================================================ FAILURES ================================================================================
_________________________________________________________________ FP8QuantizerTest.test_quantized_model __________________________________________________________________

self = <finegrained_fp8.test_fp8.FP8QuantizerTest testMethod=test_quantized_model>

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
    
        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
        output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
>       self.assertEqual(output_tokens, self.EXPECTED_OUTPUT)

tests/quantization/finegrained_fp8/test_fp8.py:159: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/transformers/testing_utils.py:651: in wrapper
    return test_case(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
src/transformers/testing_utils.py:651: in wrapper
    return test_case(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
src/transformers/testing_utils.py:651: in wrapper
    return test_case(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
E   AssertionError: 'Once upon a time, there was a man who was very rich.' != 'Once upon a time, there was a little girl who loved to play'
E   - Once upon a time, there was a man who was very rich.
E   + Once upon a time, there was a little girl who loved to play

My testing environment is:

transformers 5.0.0.dev0 # branch fix-fp8
torch 2.9.1+xpu
2 cards Intel(R) Data Center GPU Max 1550

But I don't have a 4090 or H100, so I can't test the CUDA performance...

@github-actions
Copy link
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: finegrained_fp8, mxfp4

@SunMarc
Copy link
Member Author

SunMarc commented Nov 28, 2025

Thanks for confirming that it works on your hardware ! I will update it so that it doesn't fail on your side too

@SunMarc
Copy link
Member Author

SunMarc commented Nov 28, 2025

for the multi-gpu tests, I will probably fix this in a follow-up PR as I will need to update a lot of methods

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants