|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +A script demonstrating quantization of the routed experts of |
| 9 | +the `meta-llama/Llama-4-Scout-17B-16E-Instruct` model from HuggingFace |
| 10 | +to w8a8 with float8 rowwise weights and activations. |
| 11 | +""" |
| 12 | + |
| 13 | +import argparse |
| 14 | +import random |
| 15 | +from pathlib import Path |
| 16 | + |
| 17 | +import fbgemm_gpu |
| 18 | +import numpy as np |
| 19 | +import torch |
| 20 | +import transformers |
| 21 | +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig |
| 22 | + |
| 23 | +from torchao.quantization import ( |
| 24 | + Float8DynamicActivationFloat8WeightConfig, |
| 25 | + FqnToConfig, |
| 26 | + PerRow, |
| 27 | +) |
| 28 | +from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( |
| 29 | + Float8Tensor, |
| 30 | +) |
| 31 | + |
| 32 | + |
| 33 | +# Set seeds for reproducibility |
| 34 | +def set_seed(seed): |
| 35 | + random.seed(seed) |
| 36 | + np.random.seed(seed) |
| 37 | + torch.manual_seed(seed) |
| 38 | + torch.cuda.manual_seed_all(seed) |
| 39 | + |
| 40 | + |
| 41 | +def get_quantization_config(): |
| 42 | + expert_3d_weight_single_config = Float8DynamicActivationFloat8WeightConfig( |
| 43 | + # the weights of this model are stored in (B, K, N) layout, and we |
| 44 | + # need to quantize rowwise across the K axis, which is `PerRow(1)`. |
| 45 | + granularity=[PerRow(), PerRow(1)], |
| 46 | + # guard against activations with groups of all-zeroes |
| 47 | + activation_value_lb=1.0e-12, |
| 48 | + ) |
| 49 | + fqn_to_config = FqnToConfig( |
| 50 | + { |
| 51 | + # only quantize the routed experts, the rest of the model is left |
| 52 | + # in high precision |
| 53 | + r"re:.*\.feed_forward\.experts\.gate_up_proj": expert_3d_weight_single_config, |
| 54 | + r"re:.*\.feed_forward\.experts\.down_proj": expert_3d_weight_single_config, |
| 55 | + } |
| 56 | + ) |
| 57 | + return TorchAoConfig(quant_type=fqn_to_config) |
| 58 | + |
| 59 | + |
| 60 | +def parse_args(): |
| 61 | + parser = argparse.ArgumentParser(description="Quantize a model with TorchAO") |
| 62 | + parser.add_argument( |
| 63 | + "output_dir", |
| 64 | + type=str, |
| 65 | + help="Directory to save the quantized model", |
| 66 | + ) |
| 67 | + parser.add_argument( |
| 68 | + "--max_new_tokens", |
| 69 | + type=int, |
| 70 | + default=64, |
| 71 | + help="Max tokens to generate for testing (default: 64)", |
| 72 | + ) |
| 73 | + parser.add_argument( |
| 74 | + "--convert_llama_4_expert_weights_to_mnk", |
| 75 | + action="store_true", |
| 76 | + help="If set, converts LLaMa 4 Scout expert weights from MKN to MNK memory layout", |
| 77 | + ) |
| 78 | + parser.add_argument( |
| 79 | + "--no_save_model_to_disk", |
| 80 | + action="store_true", |
| 81 | + help="If set, skips saving quantized model to local disk", |
| 82 | + ) |
| 83 | + parser.add_argument( |
| 84 | + "--no_load_model_from_disk", |
| 85 | + action="store_true", |
| 86 | + help="If set, skips reloading model from disk to test it again", |
| 87 | + ) |
| 88 | + return parser.parse_args() |
| 89 | + |
| 90 | + |
| 91 | +def main(args): |
| 92 | + """ |
| 93 | + Args: |
| 94 | + args: Parsed command line arguments containing: |
| 95 | + output_dir: Directory to save the quantized model |
| 96 | + max_new_tokens: Max tokens to generate for testing |
| 97 | + convert_llama_4_expert_weights_to_mnk: if True, converts LLaMa 4 Scout expert weights from MKN to MNK memory layout |
| 98 | + no_save_model_to_disk: if True, skips saving quantized model to local disk |
| 99 | + no_load_model_from_disk: if True, skips reloading model from disk to test it again |
| 100 | + """ |
| 101 | + |
| 102 | + # ensure relevant dependency versions are satisfied |
| 103 | + t_v = str(transformers.__version__) |
| 104 | + assert t_v >= "4.58", ( |
| 105 | + f"transformers version {t_v} too old, please upgrade to a transformers version with https://github.com/huggingface/transformers/pull/41894" |
| 106 | + ) |
| 107 | + f_v = str(fbgemm_gpu.__version__) |
| 108 | + if f_v.startswith("202"): |
| 109 | + # nightly version, such as '2025.11.22+cu128' |
| 110 | + assert f_v >= "2025.11.22", ( |
| 111 | + f"fbgemm_gpu nightly version {f_v} too old, please upgrade to a nightly from 2025-11-22 or later" |
| 112 | + ) |
| 113 | + else: |
| 114 | + # stable version, such as '1.4.1' |
| 115 | + assert f_v >= "1.5", ( |
| 116 | + f"fbgemm_gpu stable version {f_v} too old, please upgrade to 1.5 or later" |
| 117 | + ) |
| 118 | + |
| 119 | + model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" |
| 120 | + device_map = "auto" |
| 121 | + |
| 122 | + # Test prompts |
| 123 | + prompts = [ |
| 124 | + "Why is Pytorch 2.0 the best machine learning compiler?", |
| 125 | + ] |
| 126 | + |
| 127 | + # Set seed before creating the model |
| 128 | + set_seed(42) |
| 129 | + |
| 130 | + # Create output directory |
| 131 | + output_dir = Path(args.output_dir) |
| 132 | + output_dir.mkdir(parents=True, exist_ok=True) |
| 133 | + |
| 134 | + # Get quantization config |
| 135 | + quantization_config = get_quantization_config() |
| 136 | + |
| 137 | + # Load tokenizer |
| 138 | + tokenizer = AutoTokenizer.from_pretrained(model_name) |
| 139 | + |
| 140 | + # Load and quantize model |
| 141 | + print("Loading and quantizing model...") |
| 142 | + quantized_model = AutoModelForCausalLM.from_pretrained( |
| 143 | + model_name, |
| 144 | + torch_dtype="bfloat16", |
| 145 | + device_map=device_map, |
| 146 | + quantization_config=quantization_config, |
| 147 | + ) |
| 148 | + print(quantized_model) |
| 149 | + |
| 150 | + # Test generation |
| 151 | + print("\nTesting quantized model generation...") |
| 152 | + input_ids = tokenizer(prompts, return_tensors="pt", padding=True).to( |
| 153 | + quantized_model.device |
| 154 | + ) |
| 155 | + outputs = quantized_model.generate(**input_ids, max_new_tokens=args.max_new_tokens) |
| 156 | + |
| 157 | + for i, (prompt, output) in enumerate(zip(prompts, outputs, strict=False)): |
| 158 | + generated_text = tokenizer.decode(output, skip_special_tokens=True) |
| 159 | + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
| 160 | + |
| 161 | + save_model_to_disk = not args.no_save_model_to_disk |
| 162 | + load_model_from_disk = not args.no_load_model_from_disk |
| 163 | + |
| 164 | + if save_model_to_disk: |
| 165 | + # Save quantized model |
| 166 | + print(f"\nSaving quantized model to: {output_dir}") |
| 167 | + |
| 168 | + if args.convert_llama_4_expert_weights_to_mnk: |
| 169 | + print("\nConverting LLaMa 4 expert weights from MKN to MNK layout") |
| 170 | + |
| 171 | + # source: https://github.com/huggingface/transformers/blob/6f6095e0cf509f7384d3ce0c1804013ef6cafd5f/src/transformers/modeling_utils.py#L3466 |
| 172 | + def save_function(shard, filename): |
| 173 | + # `save_pretrained` default logic calls tensor.contiguous() before |
| 174 | + # saving, so if we do mkn -> mnk before saving it will be |
| 175 | + # converted back to mkn. |
| 176 | + # We undo this in the custom save_function, which runs after |
| 177 | + # the contiguous call in `save_pretrained`.:) |
| 178 | + for k, v in shard.items(): |
| 179 | + # hacky check for LLaMa 4 experts |
| 180 | + if isinstance(v, Float8Tensor) and len(v.shape) == 3: |
| 181 | + v.qdata = ( |
| 182 | + v.qdata.transpose(-2, -1).contiguous().transpose(-2, -1) |
| 183 | + ) |
| 184 | + torch.save(shard, filename) |
| 185 | + |
| 186 | + else: |
| 187 | + save_function = torch.save |
| 188 | + |
| 189 | + quantized_model.save_pretrained( |
| 190 | + output_dir, |
| 191 | + safe_serialization=False, |
| 192 | + save_function=save_function, |
| 193 | + ) |
| 194 | + tokenizer.save_pretrained(output_dir) |
| 195 | + |
| 196 | + if load_model_from_disk: |
| 197 | + assert save_model_to_disk, "unimplemented" |
| 198 | + # Load saved model to verify |
| 199 | + # TODO: do we really need `weights_only=False` here? |
| 200 | + loaded_model = AutoModelForCausalLM.from_pretrained( |
| 201 | + output_dir, |
| 202 | + device_map=device_map, |
| 203 | + torch_dtype="auto", |
| 204 | + weights_only=False, |
| 205 | + ) |
| 206 | + |
| 207 | + # Test loaded model with first prompt |
| 208 | + test_prompt = prompts[0] |
| 209 | + input_ids = tokenizer(test_prompt, return_tensors="pt").to(loaded_model.device) |
| 210 | + output = loaded_model.generate(**input_ids, max_new_tokens=args.max_new_tokens) |
| 211 | + generated_text = tokenizer.decode(output[0], skip_special_tokens=True) |
| 212 | + print( |
| 213 | + f"Verification - Prompt: {test_prompt!r}, Generated text: {generated_text!r}" |
| 214 | + ) |
| 215 | + |
| 216 | + print("\nQuantization process completed successfully.") |
| 217 | + |
| 218 | + |
| 219 | +if __name__ == "__main__": |
| 220 | + args = parse_args() |
| 221 | + main(args) |
0 commit comments