Commit afd039d
[quantization] Dequant fp8 when cuda or xpu not available (#42511)
* up
* style
* add tests
* update
1 parent fa3cf83 commit afd039d
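In practical terms, a pre-quantized fine-grained FP8 checkpoint can now be loaded on a machine without CUDA or XPU: the quantizer warns and dequantizes the weights at load time instead of raising. A minimal sketch of the user-facing behaviour, based on the new tests (the checkpoint name is the one they use; any fine-grained FP8 checkpoint should behave the same way):

from transformers import AutoModelForCausalLM, AutoTokenizer

# On a CPU-only machine this used to raise RuntimeError; after this change the
# quantizer warns and dequantizes the FP8 weights while loading.
model_id = "hf-internal-testing/Llama-3.2-1B-Instruct-fp8"  # checkpoint used in the new tests
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Once upon a time", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=10, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))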

File tree: 5 files changed, +134 -32 lines changed

src/transformers/integrations/finegrained_fp8.py

Lines changed: 34 additions & 28 deletions
@@ -14,8 +14,7 @@
 # limitations under the License.
 
 import re
-from collections.abc import Sequence
-from typing import Any
+from typing import Optional
 
 from ..core_model_loading import ConversionOps
 from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
@@ -549,6 +548,9 @@ def replace_with_fp8_linear(
     quantization_config=None,
 ):
     """Helper function to replace model layers with FP8 versions."""
+    if quantization_config.dequantize:
+        return model
+
     if modules_to_not_convert is None:
         modules_to_not_convert = []
     modules_to_not_convert += ["lm_head"]
@@ -652,41 +654,45 @@ def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]
 class Fp8Dequantize(ConversionOps):
     """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
 
-    def __init__(self, block_size: tuple[int, int] | None = None):
-        self.block_size = block_size
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
 
     def convert(
         self,
-        value: Sequence[torch.Tensor] | dict[str, torch.Tensor],
-        *,
-        context: dict[str, Any],
-    ) -> torch.Tensor:
-        if isinstance(value, dict):
-            tensors = list(value.values())
-        else:
-            tensors = list(value) if isinstance(value, Sequence) else [value]
-        if len(tensors) != 2:
-            raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
-        quantized, scales = tensors
-        if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
-            raise TypeError("Fp8Dequantize expects tensors as inputs.")
-
-        quantized_fp32 = quantized.to(torch.float32)
-        rows, cols = quantized_fp32.shape[-2:]
-        block_size = self.block_size
-        if block_size is None:
-            quant_config = context.get("quantization_config")
-            block_size = getattr(quant_config, "weight_block_size", None)
-        if block_size is None:
-            block_size = (rows, cols)
+        input_dict: dict[str, torch.Tensor],
+        model: Optional[torch.nn.Module] = None,
+        full_layer_name: str | None = None,
+        missing_keys=None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        if len(input_dict) != 2:
+            # in case of no scales, the weights are not quantized, so we return the weights as is
+            return {
+                full_layer_name: input_dict["weight$"][0]
+                if isinstance(input_dict["weight$"], list)
+                else input_dict["weight$"]
+            }
+        quantized = input_dict["weight$"][0] if isinstance(input_dict["weight$"], list) else input_dict["weight$"]
+        scales = (
+            input_dict["weight_scale_inv"][0]
+            if isinstance(input_dict["weight_scale_inv"], list)
+            else input_dict["weight_scale_inv"]
+        )
+
+        rows, cols = quantized.shape[-2:]
+        block_size = self.hf_quantizer.quantization_config.weight_block_size
+
         block_m, block_n = block_size
         if rows % block_m != 0 or cols % block_n != 0:
             raise ValueError(
                 f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
             )
 
-        reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
+        reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
         expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
         expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
         dequantized = reshaped * expanded_scales
-        return dequantized.reshape(quantized_fp32.shape)
+
+        return {
+            full_layer_name: dequantized.reshape(quantized.shape),
+        }
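For readers unfamiliar with the reshape/broadcast trick used in `convert` above: each (block_m, block_n) tile of the weight matrix shares one scale, so dequantization is a tile-wise multiply. A standalone sketch of that math (an illustrative helper, not the library code):

import torch


def blockwise_dequantize(quantized: torch.Tensor, scales: torch.Tensor, block_size=(128, 128)) -> torch.Tensor:
    """Multiply each (block_m, block_n) tile of `quantized` by its per-tile scale."""
    block_m, block_n = block_size
    rows, cols = quantized.shape[-2:]
    if rows % block_m or cols % block_n:
        raise ValueError("matrix dimensions must be divisible by the block sizes")
    # Split the matrix into tiles: (-1, rows//block_m, block_m, cols//block_n, block_n)
    tiles = quantized.to(torch.float32).reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
    # One scale per tile, broadcast over the tile dims: (-1, rows//block_m, 1, cols//block_n, 1)
    s = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n).unsqueeze(-1).unsqueeze(2)
    return (tiles * s).reshape(quantized.shape)


# A 256x256 weight with 128x128 blocks has a 2x2 grid of scales; the top-left
# tile of the result is the top-left tile of the input times scales[0, 0].
w, s = torch.randn(256, 256), torch.rand(2, 2)
assert torch.allclose(blockwise_dequantize(w, s)[:128, :128], w[:128, :128] * s[0, 0])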

src/transformers/quantizers/auto.py

Lines changed: 10 additions & 2 deletions
@@ -224,7 +224,15 @@ def merge_quantization_configs(
         if (
             isinstance(
                 quantization_config,
-                (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig, Mxfp4Config),
+                (
+                    GPTQConfig,
+                    AwqConfig,
+                    AutoRoundConfig,
+                    FbgemmFp8Config,
+                    CompressedTensorsConfig,
+                    Mxfp4Config,
+                    FineGrainedFP8Config,
+                ),
             )
             and quantization_config_from_args is not None
         ):
@@ -234,7 +242,7 @@
 
             warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
 
-            if warning_msg != "" and not isinstance(quantization_config, Mxfp4Config):
+            if warning_msg != "" and not isinstance(quantization_config, (Mxfp4Config, FineGrainedFP8Config)):
                 warnings.warn(warning_msg)
             else:
                 # in the case of mxfp4, we don't want to print the warning message, bit confusing for users
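The effect of adding `FineGrainedFP8Config` to this tuple: when a checkpoint already carries a fine-grained FP8 config and the caller passes one to `from_pretrained`, the loading attributes of the user config (currently just `dequantize`, see `get_loading_attributes` below) override the serialized ones, and the override warning is suppressed, mirroring the existing Mxfp4 behaviour. Roughly, this is the call it enables (a sketch, reusing the checkpoint from the tests):

from transformers import AutoModelForCausalLM, FineGrainedFP8Config

# The checkpoint ships its own FineGrainedFP8Config; the `dequantize` flag from
# the config passed here overrides the serialized value without a warning.
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/Llama-3.2-1B-Instruct-fp8",  # FP8 checkpoint used in the tests
    quantization_config=FineGrainedFP8Config(dequantize=True),
)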

src/transformers/quantizers/quantizer_finegrained_fp8.py

Lines changed: 24 additions & 2 deletions
@@ -38,8 +38,15 @@ def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
             raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")
 
-        if not (torch.cuda.is_available() or is_torch_xpu_available()):
-            raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")
+        if (not (torch.cuda.is_available() or is_torch_xpu_available())) and not self.quantization_config.dequantize:
+            if self.pre_quantized:
+                logger.warning_once(
+                    "Using FP8 quantized models requires a GPU or XPU, we will default to dequantizing the model to bf16 since no GPU or XPU is available"
+                )
+                self.quantization_config.dequantize = True
+                return
+            else:
+                raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")
 
         if torch.cuda.is_available():
             compute_capability = torch.cuda.get_device_capability()
@@ -231,3 +238,18 @@ def get_quantize_ops(self):
         from ..integrations.finegrained_fp8 import Fp8Quantize
 
         return Fp8Quantize(self)
+
+    def get_weight_conversions(self):
+        from ..core_model_loading import WeightConverter
+        from ..integrations.finegrained_fp8 import Fp8Dequantize
+
+        if self.pre_quantized and self.quantization_config.dequantize:
+            return [
+                # either use the dollar sign, or permute the source patterns to start matching against the scales first
+                WeightConverter(
+                    source_patterns=["weight$", "weight_scale_inv"],
+                    target_patterns="weight",
+                    operations=[Fp8Dequantize(self)],
+                )
+            ]
+        return []
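To make the new control flow in `validate_environment` explicit, here is a plain-Python paraphrase of the decision it now makes (illustrative only, not the actual method; the function name is made up):

def fp8_loading_mode(has_cuda_or_xpu: bool, pre_quantized: bool, dequantize: bool) -> str:
    """Paraphrase of the branch added to validate_environment above."""
    if dequantize:
        # Explicitly requested via FineGrainedFP8Config(dequantize=True): skip the hardware check.
        return "dequantize while loading"
    if has_cuda_or_xpu:
        return "keep FP8 weights"  # normal path, FP8 kernels on GPU/XPU
    if pre_quantized:
        # New fallback: warn, set quantization_config.dequantize = True, and load in higher precision.
        return "dequantize while loading"
    # Quantizing on the fly still needs an accelerator.
    raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")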

src/transformers/utils/quantization_config.py

Lines changed: 7 additions & 0 deletions
@@ -1981,6 +1981,8 @@ class FineGrainedFP8Config(QuantizationConfigMixin):
            The scheme used for activation, the defaults and only support scheme for now is "dynamic".
        weight_block_size (`typing.tuple[int, int]`, *optional*, defaults to `(128, 128)`):
            The size of the weight blocks for quantization, default is (128, 128).
+        dequantize (`bool`, *optional*, defaults to `False`):
+            Whether to dequantize the model during loading.
        modules_to_not_convert (`list`, *optional*):
            A list of module names that should not be converted during quantization.
    """
@@ -1989,13 +1991,15 @@ def __init__(
         self,
         activation_scheme: str = "dynamic",
         weight_block_size: tuple[int, int] = (128, 128),
+        dequantize: bool = False,
         modules_to_not_convert: list | None = None,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.FP8
         self.modules_to_not_convert = modules_to_not_convert
         self.activation_scheme = activation_scheme
         self.weight_block_size = weight_block_size
+        self.dequantize = dequantize
         self.post_init()
 
     def post_init(self):
@@ -2010,6 +2014,9 @@ def post_init(self):
         if self.weight_block_size[0] <= 0 or self.weight_block_size[1] <= 0:
             raise ValueError("weight_block_size must be a tuple of two positive integers")
 
+    def get_loading_attributes(self):
+        return {"dequantize": self.dequantize}
+
 
 class QuarkConfig(QuantizationConfigMixin):
     def __init__(
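A quick sanity check of the new config field and the loading attribute it exposes (the expected values follow directly from the hunk above):

from transformers import FineGrainedFP8Config

config = FineGrainedFP8Config(dequantize=True)
print(config.dequantize)                  # True
print(config.get_loading_attributes())    # {'dequantize': True}
print(FineGrainedFP8Config().dequantize)  # False (the default)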

tests/quantization/finegrained_fp8/test_fp8.py

Lines changed: 59 additions & 0 deletions
@@ -15,8 +15,11 @@
 import gc
 import tempfile
 import unittest
+from contextlib import ExitStack, contextmanager
+from unittest.mock import patch
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM
+from transformers.quantizers.quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
 from transformers.testing_utils import (
     backend_empty_cache,
     get_device_properties,
@@ -37,6 +40,15 @@
     from accelerate import init_empty_weights
 
 
+@contextmanager
+def _patch_no_accelerator():
+    with ExitStack() as stack:
+        stack.enter_context(patch("torch.cuda.is_available", return_value=False))
+        if hasattr(torch, "xpu"):
+            stack.enter_context(patch("torch.xpu.is_available", return_value=False))
+        yield
+
+
 @require_torch_accelerator
 class FineGrainedFP8ConfigTest(unittest.TestCase):
     def test_to_dict(self):
@@ -71,9 +83,11 @@ def test_from_dict(self):
 )
 class FP8QuantizerTest(unittest.TestCase):
     model_name = "meta-llama/Llama-3.2-1B"
+    quantized_model_name = "hf-internal-testing/Llama-3.2-1B-Instruct-fp8"
     input_text = "Once upon a time"
     max_new_tokens = 10
     EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich."
+    EXPECTED_DEQUANTIZED_OUTPUT = "Once upon a time, in a small village nestled in the rolling hills"
     device_map = torch_device
     offload_device_map = {
         "model.embed_tokens": 0,
@@ -152,6 +166,25 @@ def test_quantized_model_conversion(self):
 
         self.assertEqual(nb_linears - 25, nb_fp8_linear)
 
+    def test_quantizer_validation_no_accelerator(self):
+        """Test quantizer validation when CUDA/XPU is not available"""
+        with _patch_no_accelerator():
+            config = FineGrainedFP8Config()
+            quantizer = FineGrainedFP8HfQuantizer(config)
+            quantizer.pre_quantized = False
+
+            with self.assertRaises(RuntimeError):
+                quantizer.validate_environment()
+
+    def test_dequantization_no_accelerator(self):
+        """Test dequantization when CUDA/XPU is not available"""
+        with _patch_no_accelerator():
+            config = FineGrainedFP8Config()
+            quantizer = FineGrainedFP8HfQuantizer(config)
+            quantizer.pre_quantized = True
+            quantizer.validate_environment()
+            self.assertTrue(quantizer.quantization_config.dequantize)
+
     def test_quantized_model(self):
         """
         Simple test that checks if the quantized model is working properly
@@ -162,6 +195,32 @@ def test_quantized_model(self):
         output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
         self.assertEqual(output_tokens, self.EXPECTED_OUTPUT)
 
+    def test_dequantized_model(self):
+        """
+        Simple test that checks if the dequantized model is working properly
+        """
+        quantization_config = FineGrainedFP8Config(dequantize=True)
+        dequantized_model = AutoModelForCausalLM.from_pretrained(
+            self.quantized_model_name, device_map=self.device_map, quantization_config=quantization_config
+        )
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map)
+        output = dequantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
+        output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
+        self.assertEqual(output_tokens, self.EXPECTED_DEQUANTIZED_OUTPUT)
+        del dequantized_model
+
+    def test_dequantize_when_no_accelerator(self):
+        """
+        Simple test that checks if the dequantized model is working properly when no accelerator is available
+        """
+        with _patch_no_accelerator():
+            dequantized_model = AutoModelForCausalLM.from_pretrained(self.quantized_model_name, device_map="cpu")
+            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to("cpu")
+            output = dequantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
+            output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True)
+            self.assertEqual(output_tokens, self.EXPECTED_DEQUANTIZED_OUTPUT)
+            del dequantized_model
+
     def test_save_pretrained(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
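For reference, the new tests can be run on their own with, for example, `python -m pytest tests/quantization/finegrained_fp8/test_fp8.py -k "dequant or no_accelerator" -v`; note that the model-level tests download the FP8 checkpoint from the Hub.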
