
Commit 80134e6

jcaip and MekkCyber authored

Update transformers to support FqnToConfig (#41894)

* Update transformers to support `FqnToConfig`
* add case for modulefqn
* remove comment
* update tests
* cleanup
* update
* wip
* wip
* update quantizer_torchao for module default
* fix underscore
* update tests
* update
* fix import error
* fix import
* import change not included in previous commit
* Apply suggestion from @MekkCyber
* Update src/transformers/quantizers/quantizer_torchao.py
* update tests and add comment
* fix test

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

1 parent ce40ca0 commit 80134e6

3 files changed: +225 −13 lines changed

docs/source/en/quantization/torchao.md

Lines changed: 7 additions & 7 deletions

@@ -422,19 +422,19 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
 
 #### 1. Skip quantization for certain layers
 
-With `ModuleFqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
+With `FqnToConfig` we can specify a default configuration for all layers while skipping quantization for certain layers.
 
 ```py
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
 model_id = "meta-llama/Llama-3.1-8B-Instruct"
 
-from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig
+from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig
 config = Int4WeightOnlyConfig(group_size=128)
 
 # set default to int4 (for linears), and skip quantizing `model.layers.0.self_attn.q_proj`
-quant_config = ModuleFqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
+quant_config = FqnToConfig({"_default": config, "model.layers.0.self_attn.q_proj": None})
 quantization_config = TorchAoConfig(quant_type=quant_config)
 quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", dtype=torch.bfloat16, quantization_config=quantization_config)
 # lm_head is not quantized and model.layers.0.self_attn.q_proj is not quantized

@@ -459,7 +459,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
 model_id = "facebook/opt-125m"
 
-from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
+from torchao.quantization import Int4WeightOnlyConfig, FqnToConfig, Int8DynamicActivationInt4WeightConfig, IntxWeightOnlyConfig, PerAxis, MappingType
 
 weight_dtype = torch.int8
 granularity = PerAxis(0)

@@ -470,7 +470,7 @@ embedding_config = IntxWeightOnlyConfig(
     mapping_type=mapping_type,
 )
 linear_config = Int8DynamicActivationInt4WeightConfig(group_size=128)
-quant_config = ModuleFqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
+quant_config = FqnToConfig({"_default": linear_config, "model.decoder.embed_tokens": embedding_config, "model.decoder.embed_positions": None})
 # set `include_embedding` to True in order to include embedding in quantization
 # when `include_embedding` is True, we'll remove input embedding from `modules_not_to_convert` as well
 quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True)

@@ -521,7 +521,7 @@ from torchao.quantization import (
     IntxWeightOnlyConfig,
     PerRow,
     PerAxis,
-    ModuleFqnToConfig,
+    FqnToConfig,
     Float8Tensor,
     Int4TilePackedTo4dTensor,
     IntxUnpackedToInt8Tensor,

@@ -550,7 +550,7 @@ qconfig_dict = {
 
     "_default": intxwo,
 }
-quant_config = ModuleFqnToConfig(qconfig_dict)
+quant_config = FqnToConfig(qconfig_dict)
 quantization_config = TorchAoConfig(quant_type=quant_config)
 quantized_model = AutoModelForCausalLM.from_pretrained(
     model_id,
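Beyond the straight rename shown in the docs diff above, `FqnToConfig` also accepts regex keys prefixed with `re:` and exact per-parameter FQNs, which the new tests further down exercise. Here is a minimal sketch of that usage, assuming torchao >= 0.15.0 exports `FqnToConfig` and that exact FQN keys take precedence over regex keys, which in turn take precedence over `"_default"` (the ordering the new tests assert); the model id and layer names simply reuse the Llama example from the docs.

```py
import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig, FqnToConfig, Int8WeightOnlyConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # same model as the docs example above

# Exact FQN keys win over regex ("re:") keys, which win over "_default".
quant_config = FqnToConfig(
    {
        # skip this one projection entirely (exact parameter FQN)
        "model.layers.0.self_attn.q_proj.weight": None,
        # float8 weight-only for every other q_proj weight (regex key)
        r"re:model\.layers\..+\.self_attn\.q_proj\.weight": Float8WeightOnlyConfig(),
        # int8 weight-only for all remaining linear layers
        "_default": Int8WeightOnlyConfig(),
    }
)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", dtype=torch.bfloat16, quantization_config=quantization_config
)
```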

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 64 additions & 2 deletions

@@ -251,6 +251,23 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
         _QUANTIZABLE = [torch.nn.Linear]
         if self.quantization_config.include_input_output_embeddings:
             _QUANTIZABLE.append(torch.nn.Embedding)
+
+        # Handle FqnToConfig, introduced in torchao 0.15.0+
+        if self.quantization_config._get_ao_version() >= version.parse("0.15.0"):
+            from torchao.quantization import FqnToConfig, fqn_matches_fqn_config
+
+            if isinstance(self.quantization_config.quant_type, FqnToConfig):
+                module_fqn, param_name_fqn = param_name.rsplit(".", 1)
+                if (
+                    fqn_matches_fqn_config(module_fqn, self.quantization_config.quant_type)
+                    or fqn_matches_fqn_config(param_name, self.quantization_config.quant_type)
+                    or (
+                        "_default" in self.quantization_config.quant_type.fqn_to_config
+                        and isinstance(module, tuple(_QUANTIZABLE))
+                    )
+                ):
+                    return True
+
         return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
 
     def create_quantized_param(

@@ -319,8 +336,54 @@ def create_quantized_param(
             model.tie_weights()
             setattr(model.config.get_text_config(decoder=True), "tie_word_embeddings", False)
 
+        # handle FqnToConfig, introduced in torchao 0.15.0+
+        if self.quantization_config._get_ao_version() >= version.Version("0.15.0"):
+            from torchao.quantization import FqnToConfig
+
+            config = self.quantization_config.get_apply_tensor_subclass()
+            if isinstance(config, FqnToConfig):
+                module_fqn, top_level_param_name = param_name.rsplit(".", 1)
+                c = None
+                if param_name in config.fqn_to_config:
+                    assert not module_fqn.startswith("re:"), (
+                        "param fqn should not start with `re:`, which is used for specifying regex"
+                    )
+                    c = config.module_fqn_to_config[param_name]
+                elif module_fqn in config.fqn_to_config:
+                    assert not module_fqn.startswith("re:"), (
+                        "module fqn should not start with `re:`, which is used for specifying regex"
+                    )
+                    c = config.module_fqn_to_config[module_fqn]
+                # regex match module and param
+                else:
+                    for maybe_module_fqn_pattern in config.fqn_to_config:
+                        # if key doesn't start with re:, it is an exact fqn key, so we don't regex match
+                        if not maybe_module_fqn_pattern.startswith("re:"):
+                            continue
+                        # see if param matches first
+                        elif re.fullmatch(maybe_module_fqn_pattern[3:], param_name):
+                            c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                            break
+                        elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
+                            # we'll apply the config for first fully matched pattern
+                            c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                            break
+                    else:
+                        c = config.module_fqn_to_config.get("_default", None)
+
+                if c is not None:
+                    if top_level_param_name == "weight":
+                        # we can apply the module config directly
+                        quantize_(module, c, (lambda x, fqn: True))
+                    else:
+                        # need to apply to custom param name
+                        custom_param_fqn_config = FqnToConfig({top_level_param_name: c})
+                        quantize_(module, custom_param_fqn_config, filter_fn=None)
+                return
+
         # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
-        if self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
+        # TODO deprecate this when we deprecate ModuleFqnToConfig
+        elif self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
             from torchao.quantization import ModuleFqnToConfig
 
             config = self.quantization_config.get_apply_tensor_subclass()

@@ -342,7 +405,6 @@ def create_quantized_param(
                             break
                     else:
                         c = config.module_fqn_to_config.get("_default", None)
-
                 if c is not None:
                     # filter_fn: not filtering out any modules
                     quantize_(module, c, filter_fn=lambda x, fqn: True)
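To make the lookup order in the new `create_quantized_param` branch easier to follow, here is a small standalone sketch of the same resolution logic. This is illustrative only, not code from the commit: `resolve_config` is a hypothetical helper name, and it operates on a plain dict shaped like the mapping passed to `FqnToConfig`.

```py
import re


def resolve_config(fqn_to_config: dict, param_name: str):
    """Mirror the precedence above: exact param FQN > exact module FQN > regex key > "_default"."""
    module_fqn, _ = param_name.rsplit(".", 1)
    # 1) exact parameter FQN, e.g. "model.layers.3.self_attn.q_proj.weight"
    if param_name in fqn_to_config:
        return fqn_to_config[param_name]
    # 2) exact module FQN, e.g. "model.layers.3.self_attn.q_proj"
    if module_fqn in fqn_to_config:
        return fqn_to_config[module_fqn]
    # 3) regex keys ("re:..."), with the parameter FQN checked before the module FQN
    for key in fqn_to_config:
        if not key.startswith("re:"):
            continue
        if re.fullmatch(key[3:], param_name) or re.fullmatch(key[3:], module_fqn):
            return fqn_to_config[key]
    # 4) fall back to the default entry, if any (None means "do not quantize")
    return fqn_to_config.get("_default", None)
```

For example, with the dict from the docs snippet above, `resolve_config({"_default": config, "model.layers.0.self_attn.q_proj": None}, "model.layers.0.self_attn.q_proj.weight")` hits the exact module FQN and returns None, so that projection is left unquantized while every other linear gets the default config.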

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 154 additions & 4 deletions

@@ -62,6 +62,8 @@
         from torchao.dtypes import Int4CPULayout
     if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.11.0"):
         from torchao.dtypes import Int4XPULayout
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
+        from torchao.quantization import FqnToConfig
 
 
 def check_torchao_int4_wo_quantized(test_module, qlayer):

@@ -378,6 +380,154 @@ def test_module_fqn_to_config_regex_precedence(self):
         ]
         self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
 
+    @require_torchao_version_greater_or_equal("0.15.0")
+    def test_fqn_to_config_regex_precedence(self):
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        config = FqnToConfig(
+            {
+                r"re:model\.layers\..+\.self_attn\.q_proj.weight": None,
+                "model.layers.3.self_attn.q_proj.weight": linear2_config,
+                "_default": linear1_config,
+            }
+        )
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+    @require_torchao_version_greater_or_equal("0.15.0")
+    def test_fqn_to_config_param_over_module_regex_precedence(self):
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        config = FqnToConfig(
+            {
+                r"re:model\.layers\..+\.self_attn\.q_proj.weight": None,
+                r"re:model\.layers\..+\.self_attn\.q_proj": linear2_config,
+                "_default": linear1_config,
+            }
+        )
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+    @require_torchao_version_greater_or_equal("0.15.0")
+    def test_fqn_to_config_param_over_module_precedence(self):
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        config = FqnToConfig(
+            {
+                "model.layers.3.self_attn.q_proj.weight": None,
+                "model.layers.3.self_attn.q_proj": linear2_config,
+                "_default": linear1_config,
+            }
+        )
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.k_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+    @require_torchao_version_greater_or_equal("0.15.0")
+    def test_fqn_to_config_exact_over_regex_precedence(self):
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        config = FqnToConfig(
+            {
+                "model.layers.3.self_attn.q_proj.weight": None,
+                "model.layers.1.self_attn.q_proj": linear1_config,
+                r"re:model\.layers\..+\.self_attn\.q_proj.weight": linear2_config,
+            }
+        )
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        self.assertTrue(not isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        self.assertTrue(isinstance(quantized_model.model.layers[2].self_attn.q_proj.weight, Float8Tensor))
+
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+    @require_torchao_version_greater_or_equal("0.15.0")
+    def test_fqn_to_config_non_weight_param(self):
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        config = FqnToConfig(
+            {
+                r"re:.*gate_up_proj": linear2_config,
+                "model.layers.0.feed_forward.experts.gate_up_proj": None,
+                "_default": linear1_config,
+            }
+        )
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            "jcaip/Llama-4-Scout-17B-two-layers-only-testing",
+            device_map="auto",
+            dtype=torch.bfloat16,
+            quantization_config=quant_config,
+        )
+
+        self.assertTrue(isinstance(quantized_model.model.layers[1].feed_forward.experts.gate_up_proj, Float8Tensor))
+        self.assertTrue(
+            not isinstance(quantized_model.model.layers[0].feed_forward.experts.gate_up_proj, Float8Tensor)
+        )
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+
 
 @require_torch_accelerator
 class TorchAoAcceleratorTest(TorchAoTest):

@@ -580,6 +730,8 @@ class TorchAoSafeSerializationTest(TorchAoSerializationTest):
     def setUpClass(cls):
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
+        # placeholder
+        cls.quant_scheme = torchao.quantization.Float8WeightOnlyConfig()
 
     def tearDown(self):
         gc.collect()

@@ -710,11 +862,10 @@ def setUpClass(cls):
 
         from torchao.quantization import Float8WeightOnlyConfig
 
+        super().setUpClass()
         cls.quant_scheme = Float8WeightOnlyConfig()
         cls.quant_scheme_kwargs = {}
 
-        super().setUpClass()
-
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
 
 

@@ -732,11 +883,10 @@ def setUpClass(cls):
 
         from torchao.quantization import Int8DynamicActivationInt4WeightConfig
 
+        super().setUpClass()
         cls.quant_scheme = Int8DynamicActivationInt4WeightConfig()
         cls.quant_scheme_kwargs = {}
 
-        super().setUpClass()
-
         cls.EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
