6262 from torchao .dtypes import Int4CPULayout
6363 if version .parse (importlib .metadata .version ("torchao" )) >= version .parse ("0.11.0" ):
6464 from torchao .dtypes import Int4XPULayout
65+ if version .parse (importlib .metadata .version ("torchao" )) >= version .parse ("0.15.0" ):
66+ from torchao .quantization import FqnToConfig
6567
6668
6769def check_torchao_int4_wo_quantized (test_module , qlayer ):
@@ -378,6 +380,154 @@ def test_module_fqn_to_config_regex_precedence(self):
378380 ]
379381 self .assertTrue (tokenizer .decode (output [0 ], skip_special_tokens = True ) in EXPECTED_OUTPUT )
380382
383+ @require_torchao_version_greater_or_equal ("0.15.0" )
384+ def test_fqn_to_config_regex_precedence (self ):
385+ linear1_config = Int8WeightOnlyConfig ()
386+ linear2_config = Float8WeightOnlyConfig ()
387+ config = FqnToConfig (
388+ {
389+ r"re:model\.layers\..+\.self_attn\.q_proj.weight" : None ,
390+ "model.layers.3.self_attn.q_proj.weight" : linear2_config ,
391+ "_default" : linear1_config ,
392+ }
393+ )
394+ quant_config = TorchAoConfig (quant_type = config )
395+ quantized_model = AutoModelForCausalLM .from_pretrained (
396+ self .model_name ,
397+ device_map = self .device ,
398+ quantization_config = quant_config ,
399+ )
400+ self .assertTrue (isinstance (quantized_model .model .layers [3 ].self_attn .q_proj .weight , Float8Tensor ))
401+ self .assertTrue (not isinstance (quantized_model .model .layers [1 ].self_attn .q_proj .weight , AffineQuantizedTensor ))
402+ self .assertTrue (isinstance (quantized_model .model .layers [1 ].self_attn .k_proj .weight , AffineQuantizedTensor ))
403+ tokenizer = AutoTokenizer .from_pretrained (self .model_name )
404+
405+ input_ids = tokenizer (self .input_text , return_tensors = "pt" ).to (self .device )
406+
407+ output = quantized_model .generate (** input_ids , max_new_tokens = self .max_new_tokens )
408+ EXPECTED_OUTPUT = [
409+ "What are we having for dinner?\n \n Jessica: (smiling)" ,
410+ "What are we having for dinner?\n \n Jess: (smiling) I" ,
411+ ]
412+ self .assertTrue (tokenizer .decode (output [0 ], skip_special_tokens = True ) in EXPECTED_OUTPUT )
413+
414+ @require_torchao_version_greater_or_equal ("0.15.0" )
415+ def test_fqn_to_config_param_over_module_regex_precedence (self ):
416+ linear1_config = Int8WeightOnlyConfig ()
417+ linear2_config = Float8WeightOnlyConfig ()
418+ config = FqnToConfig (
419+ {
420+ r"re:model\.layers\..+\.self_attn\.q_proj.weight" : None ,
421+ r"re:model\.layers\..+\.self_attn\.q_proj" : linear2_config ,
422+ "_default" : linear1_config ,
423+ }
424+ )
425+ quant_config = TorchAoConfig (quant_type = config )
426+ quantized_model = AutoModelForCausalLM .from_pretrained (
427+ self .model_name ,
428+ device_map = self .device ,
429+ quantization_config = quant_config ,
430+ )
431+ self .assertTrue (not isinstance (quantized_model .model .layers [1 ].self_attn .q_proj .weight , AffineQuantizedTensor ))
432+ self .assertTrue (isinstance (quantized_model .model .layers [1 ].self_attn .k_proj .weight , AffineQuantizedTensor ))
433+ tokenizer = AutoTokenizer .from_pretrained (self .model_name )
434+
435+ input_ids = tokenizer (self .input_text , return_tensors = "pt" ).to (self .device )
436+
437+ output = quantized_model .generate (** input_ids , max_new_tokens = self .max_new_tokens )
438+ EXPECTED_OUTPUT = [
439+ "What are we having for dinner?\n \n Jessica: (smiling)" ,
440+ "What are we having for dinner?\n \n Jess: (smiling) I" ,
441+ ]
442+ self .assertTrue (tokenizer .decode (output [0 ], skip_special_tokens = True ) in EXPECTED_OUTPUT )
443+
444+ @require_torchao_version_greater_or_equal ("0.15.0" )
445+ def test_fqn_to_config_param_over_module_precedence (self ):
446+ linear1_config = Int8WeightOnlyConfig ()
447+ linear2_config = Float8WeightOnlyConfig ()
448+ config = FqnToConfig (
449+ {
450+ "model.layers.3.self_attn.q_proj.weight" : None ,
451+ "model.layers.3.self_attn.q_proj" : linear2_config ,
452+ "_default" : linear1_config ,
453+ }
454+ )
455+ quant_config = TorchAoConfig (quant_type = config )
456+ quantized_model = AutoModelForCausalLM .from_pretrained (
457+ self .model_name ,
458+ device_map = self .device ,
459+ quantization_config = quant_config ,
460+ )
461+ self .assertTrue (not isinstance (quantized_model .model .layers [3 ].self_attn .q_proj .weight , AffineQuantizedTensor ))
462+ self .assertTrue (isinstance (quantized_model .model .layers [3 ].self_attn .k_proj .weight , AffineQuantizedTensor ))
463+ tokenizer = AutoTokenizer .from_pretrained (self .model_name )
464+
465+ input_ids = tokenizer (self .input_text , return_tensors = "pt" ).to (self .device )
466+
467+ output = quantized_model .generate (** input_ids , max_new_tokens = self .max_new_tokens )
468+ EXPECTED_OUTPUT = [
469+ "What are we having for dinner?\n \n Jessica: (smiling)" ,
470+ "What are we having for dinner?\n \n Jess: (smiling) I" ,
471+ ]
472+ self .assertTrue (tokenizer .decode (output [0 ], skip_special_tokens = True ) in EXPECTED_OUTPUT )
473+
474+ @require_torchao_version_greater_or_equal ("0.15.0" )
475+ def test_fqn_to_config_exact_over_regex_precedence (self ):
476+ linear1_config = Int8WeightOnlyConfig ()
477+ linear2_config = Float8WeightOnlyConfig ()
478+ config = FqnToConfig (
479+ {
480+ "model.layers.3.self_attn.q_proj.weight" : None ,
481+ "model.layers.1.self_attn.q_proj" : linear1_config ,
482+ r"re:model\.layers\..+\.self_attn\.q_proj.weight" : linear2_config ,
483+ }
484+ )
485+ quant_config = TorchAoConfig (quant_type = config )
486+ quantized_model = AutoModelForCausalLM .from_pretrained (
487+ self .model_name ,
488+ device_map = self .device ,
489+ quantization_config = quant_config ,
490+ )
491+ self .assertTrue (not isinstance (quantized_model .model .layers [3 ].self_attn .q_proj .weight , AffineQuantizedTensor ))
492+ self .assertTrue (isinstance (quantized_model .model .layers [1 ].self_attn .q_proj .weight , AffineQuantizedTensor ))
493+ self .assertTrue (isinstance (quantized_model .model .layers [2 ].self_attn .q_proj .weight , Float8Tensor ))
494+
495+ tokenizer = AutoTokenizer .from_pretrained (self .model_name )
496+
497+ input_ids = tokenizer (self .input_text , return_tensors = "pt" ).to (self .device )
498+
499+ output = quantized_model .generate (** input_ids , max_new_tokens = self .max_new_tokens )
500+ EXPECTED_OUTPUT = [
501+ "What are we having for dinner?\n \n Jessica: (smiling)" ,
502+ "What are we having for dinner?\n \n Jess: (smiling) I" ,
503+ ]
504+ self .assertTrue (tokenizer .decode (output [0 ], skip_special_tokens = True ) in EXPECTED_OUTPUT )
505+
506+ @require_torchao_version_greater_or_equal ("0.15.0" )
507+ def test_fqn_to_config_non_weight_param (self ):
508+ linear1_config = Int8WeightOnlyConfig ()
509+ linear2_config = Float8WeightOnlyConfig ()
510+ config = FqnToConfig (
511+ {
512+ r"re:.*gate_up_proj" : linear2_config ,
513+ "model.layers.0.feed_forward.experts.gate_up_proj" : None ,
514+ "_default" : linear1_config ,
515+ }
516+ )
517+ quant_config = TorchAoConfig (quant_type = config )
518+ quantized_model = AutoModelForCausalLM .from_pretrained (
519+ "jcaip/Llama-4-Scout-17B-two-layers-only-testing" ,
520+ device_map = "auto" ,
521+ dtype = torch .bfloat16 ,
522+ quantization_config = quant_config ,
523+ )
524+
525+ self .assertTrue (isinstance (quantized_model .model .layers [1 ].feed_forward .experts .gate_up_proj , Float8Tensor ))
526+ self .assertTrue (
527+ not isinstance (quantized_model .model .layers [0 ].feed_forward .experts .gate_up_proj , Float8Tensor )
528+ )
529+ self .assertTrue (isinstance (quantized_model .model .layers [1 ].self_attn .q_proj .weight , AffineQuantizedTensor ))
530+
381531
382532@require_torch_accelerator
383533class TorchAoAcceleratorTest (TorchAoTest ):
@@ -580,6 +730,8 @@ class TorchAoSafeSerializationTest(TorchAoSerializationTest):
580730 def setUpClass (cls ):
581731 cls .tokenizer = AutoTokenizer .from_pretrained (cls .model_name )
582732 cls .EXPECTED_OUTPUT = "What are we having for dinner?\n - 1. What is the temperature outside"
733+ # placeholder
734+ cls .quant_scheme = torchao .quantization .Float8WeightOnlyConfig ()
583735
584736 def tearDown (self ):
585737 gc .collect ()
@@ -710,11 +862,10 @@ def setUpClass(cls):
710862
711863 from torchao .quantization import Float8WeightOnlyConfig
712864
865+ super ().setUpClass ()
713866 cls .quant_scheme = Float8WeightOnlyConfig ()
714867 cls .quant_scheme_kwargs = {}
715868
716- super ().setUpClass ()
717-
718869 cls .EXPECTED_OUTPUT = "What are we having for dinner?\n \n Jessica: (smiling)"
719870
720871
@@ -732,11 +883,10 @@ def setUpClass(cls):
732883
733884 from torchao .quantization import Int8DynamicActivationInt4WeightConfig
734885
886+ super ().setUpClass ()
735887 cls .quant_scheme = Int8DynamicActivationInt4WeightConfig ()
736888 cls .quant_scheme_kwargs = {}
737889
738- super ().setUpClass ()
739-
740890 cls .EXPECTED_OUTPUT = "What are we having for dinner?\n \n Jessica: (smiling)"
741891
742892
0 commit comments