Skip to content

Commit 1f1cf7f

Browse files
committed
fp8/nvfp4 quantization support
1 parent 5dce761 commit 1f1cf7f

File tree

3 files changed

+329
-3
lines changed

3 files changed

+329
-3
lines changed

tools/llm/README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ This directory provides utilities and scripts for compiling, optimizing, and ben
66

77
- **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc.
88
- **Precision Modes:** Supports FP16, BF16, and FP32.
9+
- **Quantization:** Supports FP8 and NVFP4 quantization formats for reduced memory usage and improved inference speed.
910
- **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding.
1011
- **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends.
1112
- **Custom Attention:** Registers and converts custom scaled dot-product attention (SDPA) for compatibility with TensorRT.
@@ -39,11 +40,39 @@ python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is par
3940
- `--tokenizer`: (Optional) Tokenizer name; defaults to model.
4041
- `--prompt`: Input prompt for generation.
4142
- `--precision`: Precision mode (`FP16`, `FP32`).
43+
- `--qformat`: Quantization format (`fp8`, `nvfp4`) to apply.
44+
- `--pre_quantized`: Flag to use pre-quantized models from HuggingFace.
4245
- `--num_tokens`: Number of output tokens to generate.
4346
- `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching).
4447
- `--benchmark`: Enable benchmarking mode.
4548
- `--enable_pytorch_run`: Also run and compare PyTorch baseline.
4649

50+
### Quantization
51+
52+
Torch-TensorRT supports quantization to reduce model memory footprint and improve inference performance:
53+
54+
#### Using Pre-quantized Models
55+
56+
To use pre-quantized models from HuggingFace:
57+
58+
```bash
59+
python run_llm.py --model nvidia/Llama-3.1-8B-Instruct-FP8 --pre_quantized --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
60+
```
61+
62+
#### Applying Quantization with ModelOpt
63+
64+
Apply FP8 quantization to a HuggingFace model using ModelOpt:
65+
66+
```bash
67+
python run_llm.py --model meta-llama/Llama-3.1-8B --qformat fp8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
68+
```
69+
70+
#### Quantization Requirements
71+
72+
- **ModelOpt Library**: Required for quantization operations
73+
- **FP8**: Supported on Hopper and Blackwell-generation GPUs.
74+
- **NVFP4**: Supported on Blackwell-generation GPUs.
75+
4776
### Caching Strategies
4877

4978
- **Static Cache v1/v2:** Adds static KV cache tensors as model inputs/outputs for efficient reuse.

tools/llm/quantize_utils.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import json
2+
import logging
3+
import os
4+
5+
import huggingface_hub
6+
import torch
7+
from huggingface_hub import snapshot_download
8+
9+
logger = logging.getLogger(__name__)
10+
11+
try:
12+
import modelopt.torch.quantization as mtq # noqa: F401f
13+
14+
assert torch.ops.tensorrt.quantize_op.default
15+
except Exception:
16+
logger.warning("Unable to import quantization op. Please install modelopt library")
17+
18+
from modelopt.core.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor
19+
from modelopt.torch.quantization.config import QuantizerAttributeConfig
20+
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
21+
from modelopt.torch.utils.dataset_utils import (
22+
create_forward_loop,
23+
get_dataset_dataloader,
24+
)
25+
from safetensors import safe_open
26+
27+
28+
def quantize_model(model, args, tokenizer):
    """
    Quantize a PyTorch model using ModelOpt post-training quantization (PTQ).

    The quantization config is selected from ``args.qformat`` and validated
    *before* any calibration data is prepared, so an unsupported format fails
    immediately. A calibration dataloader is then built from the tokenizer and
    used as the forward loop for ``mtq.quantize``.

    Args:
        model: PyTorch model to quantize
        args: Arguments providing ``qformat`` ("fp8" or "nvfp4") and ``debug``
        tokenizer: Tokenizer for creating the calibration dataloader

    Returns:
        The model returned by ``mtq.quantize``

    Raises:
        RuntimeError: If unsupported quantization format is specified
    """
    # Pick the config first: building the calibration dataloader is the
    # expensive step and is pointless if the format is going to be rejected.
    if args.qformat == "fp8":
        quant_cfg = mtq.FP8_DEFAULT_CFG
    elif args.qformat == "nvfp4":
        quant_cfg = mtq.NVFP4_DEFAULT_CFG
    else:
        raise RuntimeError("Unsupported quantization format")

    # Create calibration dataloader for quantization
    calib_dataloader = get_dataset_dataloader(
        tokenizer=tokenizer,
        batch_size=32,
        num_samples=512,
        device="cuda:0",
    )
    calibrate_loop = create_forward_loop(dataloader=calib_dataloader)

    model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    if args.debug:
        mtq.print_quant_summary(model)

    return model
67+
68+
69+
class TensorRTQuantizedLinear(torch.nn.Module):
    """
    TensorRT quantized linear layer that applies quantization to both input and weight tensors.

    The wrapped ``torch.nn.Linear`` is kept as a submodule and its weight is
    passed through ``weight_quantizer`` on every forward call; the input is
    passed through ``input_quantizer``.
    """

    def __init__(
        self, original_linear: torch.nn.Linear, input_amax, weight_amax, quant_cfg
    ):
        """
        Initialize quantized linear layer.

        Args:
            original_linear: Original PyTorch linear layer to quantize
            input_amax: Maximum absolute value for input quantization scaling
            weight_amax: Maximum absolute value for weight quantization scaling
            quant_cfg: Quantization configuration for TensorQuantizer
        """
        super().__init__()

        # Store reference to original linear layer for weight access
        self.original_linear = original_linear

        # Copy bias from original layer if it exists.
        # The copy is registered as a real Parameter on the same device as the
        # source layer. Chaining `.cuda()` onto a Parameter returns a plain
        # Tensor whenever a device copy actually happens, which would leave the
        # bias unregistered (ignored by `.to()`, `.cuda()` and `state_dict()`).
        if original_linear.bias is not None:
            self.bias = torch.nn.Parameter(original_linear.bias.detach().clone())
        else:
            self.register_parameter("bias", None)

        # Create quantizers for input and weight tensors
        self.input_quantizer = TensorQuantizer(
            quant_attribute_cfg=quant_cfg, amax=input_amax
        )
        self.weight_quantizer = TensorQuantizer(
            quant_attribute_cfg=quant_cfg, amax=weight_amax
        )

    def forward(self, input):
        """Quantize ``input`` and the wrapped layer's weight, then apply the linear op."""
        input = self.input_quantizer(input)
        weight = self.weight_quantizer(self.original_linear.weight)
        return torch.nn.functional.linear(input, weight, self.bias)
109+
110+
111+
def _read_hf_quant_algo(hf_folder):
    """Return the ``quant_algo`` ("FP8" or "NVFP4") from ``hf_quant_config.json``."""
    hf_quant_config_path = os.path.join(hf_folder, "hf_quant_config.json")
    if not os.path.exists(hf_quant_config_path):
        raise RuntimeError("No quantization config found")

    with open(hf_quant_config_path, "r") as f:
        hf_quant_config = json.load(f)["quantization"]

    hf_quant_algo = hf_quant_config.get("quant_algo")
    if hf_quant_algo not in ("FP8", "NVFP4"):
        raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
    return hf_quant_algo


def _load_safetensors(hf_folder):
    """Load every tensor from the ``*.safetensors`` files in ``hf_folder`` onto the CPU."""
    tensors = {}
    for file in os.listdir(hf_folder):
        if not file.endswith(".safetensors"):
            continue
        with safe_open(
            os.path.join(hf_folder, file), framework="pt", device="cpu"
        ) as f:
            for name in f.keys():
                tensors[name] = f.get_tensor(name)
    return tensors


def convert_linear_to_tensorrt_quantized(model, model_name):
    """
    Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights.

    This function is specifically designed for Hugging Face quantized models and only
    applies quantization to linear operations. It supports both FP8 and NVFP4
    quantization formats.

    The function:
    1. Resolves the checkpoint folder (local directory or Hugging Face Hub download)
    2. Parses and validates the quantization config from hf_quant_config.json
    3. Loads the checkpoint tensors (including scales) from the SafeTensors files
    4. Replaces each ``torch.nn.Linear`` that has both a ``weight_scale`` and an
       ``input_scale`` tensor with a TensorRTQuantizedLinear layer; linear layers
       missing either scale are left untouched and a warning is logged

    The config is validated before any tensors are loaded so that a missing or
    unsupported config fails without reading the checkpoint shards.

    Note: This function only quantizes linear operations and is intended for use
    with pre-quantized Hugging Face models that have been quantized using ModelOpt.

    Args:
        model: PyTorch model to quantize (modified in place)
        model_name: Path to Hugging Face model directory or model identifier

    Returns:
        Model with quantized linear layers

    Raises:
        RuntimeError: If quantization config is not found or unsupported format
    """
    # Determine if model_name is a local directory or needs to be downloaded
    if os.path.isdir(model_name):
        hf_folder = model_name
    else:
        # Download model from Hugging Face Hub
        hf_folder = snapshot_download(
            model_name,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_patterns=["original/**/*"],
            revision=None,
        )

    # Fail fast on a missing/unsupported config before touching the weights
    hf_quant_algo = _read_hf_quant_algo(hf_folder)

    # Load all tensors from SafeTensors files
    tensors = _load_safetensors(hf_folder)

    # Snapshot the module list: modules are replaced inside the loop, and the
    # tree should not be mutated while the named_modules() generator is live.
    for name, module in list(model.named_modules()):
        if not isinstance(module, torch.nn.Linear):
            continue

        # Scale tensors follow the naming convention:
        # module_name.weight_scale and module_name.input_scale
        weight_scale_name = name + ".weight_scale"
        input_scale_name = name + ".input_scale"

        if weight_scale_name not in tensors:
            logger.warning("Weight scale tensor %s not found", weight_scale_name)
            continue
        if input_scale_name not in tensors:
            logger.warning("Input scale tensor %s not found", input_scale_name)
            continue

        if hf_quant_algo == "FP8":
            # FP8 E4M3 format has a maximum representable value of 448.0
            # Scale the quantization parameters accordingly
            weight_scale = tensors.pop(weight_scale_name)
            weight_amax = weight_scale * 448.0
            input_amax = tensors.pop(input_scale_name) * 448.0

            # Dequantize the weight using the scale factor. The scale was loaded
            # on the CPU, so move it next to the weight before multiplying.
            dequantized_weight_data = module.weight.to(
                torch.float32
            ) * weight_scale.to(module.weight.device)

            # Configure quantizer for FP8 format (4 exponent bits, 3 mantissa bits)
            quantizer_attribute_config = QuantizerAttributeConfig(
                num_bits=(4, 3), axis=None
            )

        elif hf_quant_algo == "NVFP4":
            # NVFP4 format requires additional scale tensor and different configuration
            weight_name = name + ".weight"
            weight_scale2_name = name + ".weight_scale_2"
            weight_scale = tensors.pop(weight_scale_name)
            input_scale = tensors.pop(input_scale_name)
            weight_scale2 = tensors.pop(weight_scale2_name)

            # Calculate amax values with additional scaling factor for NVFP4
            # (448.0 as for FP8 E4M3; 6.0 is presumably the largest FP4 E2M1
            # magnitude -- confirm against the ModelOpt NVFP4 definition)
            input_amax = input_scale * 448.0 * 6.0
            weight_amax = weight_scale2 * 448.0 * 6.0

            # Handle NVFP4 tensor format
            weight_data = tensors.pop(weight_name)
            original_shape = list(weight_data.shape)
            original_shape[-1] *= 2  # NVFP4 packs 2 values per element
            nvfp4_tensor = NVFP4QTensor(
                torch.Size(original_shape), torch.float32, weight_data
            )

            # Dequantize using both scales and block size configuration
            dequantized_weight_data = nvfp4_tensor.dequantize(
                scale=weight_scale, double_scale=weight_scale2, block_sizes={-1: 16}
            )

            # Configure quantizer for NVFP4 format with dynamic block quantization
            quantizer_attribute_config = QuantizerAttributeConfig(
                num_bits=(2, 1),
                axis=None,
                block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
                enable=True,
            )

        # Restore the weight to its original full-precision format so that QDQ nodes
        # can be properly inserted and optimized during TensorRT compilation
        module.weight.data = dequantized_weight_data

        # Create the quantized linear layer with calculated amax values
        quantized_module = TensorRTQuantizedLinear(
            module, input_amax, weight_amax, quantizer_attribute_config
        )

        # Replace the original module with the quantized version: split the
        # dotted name into parent path and attribute name; an empty parent path
        # means the layer is a direct child of the model.
        parent_name, _, child_name = name.rpartition(".")
        parent_module = model.get_submodule(parent_name) if parent_name else model
        setattr(parent_module, child_name, quantized_module)

    # Log any unused tensors for debugging
    if tensors:
        logger.debug("%d tensors not used", len(tensors))
        for key in tensors:
            logger.debug("  %s", key)
    return model

tools/llm/run_llm.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import argparse
1111
import copy
12+
import json
1213
import os
1314
import timeit
1415
from contextlib import nullcontext
@@ -54,10 +55,13 @@ def get_model(args):
5455
args.model,
5556
use_cache=False,
5657
attn_implementation="sdpa",
58+
ignore_mismatched_sizes=True,
5759
)
5860
.eval()
5961
.cuda()
6062
)
63+
if args.pre_quantized:
64+
model = convert_linear_to_tensorrt_quantized(model, args.model).cuda()
6165

6266
if args.precision == "FP16":
6367
model = model.to(torch.float16)
@@ -91,7 +95,8 @@ def compile_torchtrt(model, input_ids, args):
9195
for optimized inference
9296
"""
9397
max_seq_len = input_ids.shape[1] + args.num_tokens
94-
ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
98+
with export_torch_mode() if args.qformat or args.pre_quantized else nullcontext():
99+
ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
95100
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
96101
# Set precision specific flags
97102
use_fp32_acc = False
@@ -234,13 +239,36 @@ def measure_perf(trt_model, input_signature, backend_name):
234239
arg_parser.add_argument(
235240
"--benchmark", action="store_true", help="Enable benchmark (default: False)"
236241
)
237-
242+
arg_parser.add_argument(
243+
"--qformat",
244+
help=("Apply quantization format. Options: fp8, nvfp4 (default: None)"),
245+
default=None,
246+
)
247+
arg_parser.add_argument(
248+
"--pre_quantized",
249+
action="store_true",
250+
help="Use pre-quantized hf model weights (default: False)",
251+
)
238252
args = arg_parser.parse_args()
253+
254+
if args.qformat and args.pre_quantized:
255+
print("Error: --qformat and --pre_quantized cannot be used together")
256+
exit()
257+
258+
if args.qformat or args.pre_quantized:
259+
from modelopt.torch.quantization.utils import export_torch_mode
260+
from quantize_utils import (
261+
convert_linear_to_tensorrt_quantized,
262+
quantize_model,
263+
)
264+
239265
with torch.inference_mode():
240266
model = get_model(args)
241267

242268
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
243-
269+
# Set pad token
270+
if tokenizer.pad_token is None:
271+
tokenizer.pad_token = tokenizer.eos_token
244272
# Prepare input for benchmarking or evaluation
245273
if args.benchmark:
246274
input_ids = torch.randint(
@@ -258,6 +286,8 @@ def measure_perf(trt_model, input_signature, backend_name):
258286
pyt_timings = None
259287
pyt_stats = None
260288

289+
if args.qformat != None:
290+
model = quantize_model(model, args, tokenizer)
261291
if args.enable_pytorch_run:
262292
pyt_gen_tokens = generate(
263293
model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id

0 commit comments

Comments
 (0)