[Quantization] Add TRT-ModelOpt as a Backend #11173
Merged
Commits (57)
d7ca877 initial commit (ishan-modi)
a016c56 update (ishan-modi)
eb73ab0 updates (ishan-modi)
7fdb79e update (ishan-modi)
a83bb98 update (ishan-modi)
9d9f0b9 update (ishan-modi)
7b09750 update (ishan-modi)
4fe06ee Merge branch 'main' into add-trtquant-backend (ishan-modi)
71d8a7e update (ishan-modi)
6c74c69 update (ishan-modi)
10fb9fe Merge branch 'main' into add-trtquant-backend (sayakpaul)
6c65138 addressed PR comments (ishan-modi)
4b32567 Merge remote-tracking branch 'origin/add-trtquant-backend' into add-t… (ishan-modi)
915dbf0 update (ishan-modi)
3336a08 Merge branch 'main' into add-trtquant-backend (sayakpaul)
1c470f2 Merge branch 'main' into add-trtquant-backend (sayakpaul)
f823a2c addressed PR comments (ishan-modi)
e78841e Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi)
8f88f29 update (ishan-modi)
212603f update (ishan-modi)
24f1bcb update (ishan-modi)
65097f1 update (ishan-modi)
97f94ae update (ishan-modi)
752544f update (ishan-modi)
415901f Merge branch 'main' into add-trtquant-backend (ishan-modi)
482fe78 updates (ishan-modi)
488282f Merge branch 'main' into add-trtquant-backend (ishan-modi)
88259c9 Merge branch 'huggingface:main' into add-trtquant-backend (ishan-modi)
e51be6a Merge branch 'main' into add-trtquant-backend (ishan-modi)
d48835d update (ishan-modi)
5c4a4ea Merge branch 'main' into add-trtquant-backend (ishan-modi)
670202d update (ishan-modi)
6dd903f Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi)
3f672d3 Merge branch 'main' into add-trtquant-backend (sayakpaul)
64d018c Merge branch 'main' into add-trtquant-backend (sayakpaul)
395e75b addressed PR comments (ishan-modi)
9034661 Merge branch 'main' into add-trtquant-backend (sayakpaul)
bbbc840 updates (ishan-modi)
2076783 Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi)
c53d251 code formatting (ishan-modi)
1ddcc9c update (ishan-modi)
5df6926 addressed PR comments (ishan-modi)
8439f01 Merge branch 'main' into add-trtquant-backend (ishan-modi)
b96da23 Merge branch 'main' into add-trtquant-backend (ishan-modi)
0bf90b0 addressed PR comments (ishan-modi)
b097f0f Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi)
cf054d2 addressed PR comments (ishan-modi)
0828f50 Merge branch 'main' into add-trtquant-backend (sayakpaul)
031298d Merge branch 'main' into add-trtquant-backend (sayakpaul)
f345325 Merge branch 'main' into add-trtquant-backend (sayakpaul)
dd39595 addressed PR comments (ishan-modi)
d66709b Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi)
81f4785 Merge branch 'main' into add-trtquant-backend (sayakpaul)
8f60186 fix docs and dependencies (ishan-modi)
8daf21d Merge branch 'add-trtquant-backend' of https://github.com/ishan-modi/… (ishan-modi)
1a8806f fixed dependency test (ishan-modi)
cb4e44b Merge branch 'main' into add-trtquant-backend (sayakpaul)
@@ -0,0 +1,141 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# NVIDIA ModelOpt

[NVIDIA-ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a unified library of state-of-the-art model optimization techniques such as quantization, pruning, distillation, and speculative decoding. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed.

Before you begin, make sure you have nvidia_modelopt installed.

```bash
pip install -U "nvidia_modelopt[hf]"
```

Quantize a model by passing [`NVIDIAModelOptConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

The example below only quantizes the weights to FP8.

```python
import torch
from diffusers import AutoModel, SanaPipeline, NVIDIAModelOptConfig

model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
dtype = torch.bfloat16

quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
transformer = AutoModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = SanaPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

> **Note:**
>
> The quantization methods in NVIDIA-ModelOpt are designed to reduce the memory footprint of model weights using various QAT (Quantization-Aware Training) and PTQ (Post-Training Quantization) techniques while maintaining model performance. However, the actual performance gain during inference depends on the deployment framework (e.g., TRT-LLM, TensorRT) and the specific hardware configuration.
>
> More details can be found [here](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples).

## NVIDIAModelOptConfig

The `NVIDIAModelOptConfig` class accepts the following parameters (a short configuration sketch follows the list):
- `quant_type`: A string naming one of the supported quantization types listed below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantization of the [`SD3Transformer2DModel`]'s pos_embed projection block, specify `modules_to_not_convert=["pos_embed.proj.weight"]`.
- `disable_conv_quantization`: A boolean which, when set to `True`, disables quantization for all convolutional layers in the model. This is useful because channel and block quantization (used with INT4, NF4, and NVFP4) generally don't work well with convolutional layers. To disable quantization only for specific convolutional layers, use `modules_to_not_convert` instead.
- `algorithm`: The algorithm used to determine the scale, defaults to `"max"`. Check the ModelOpt documentation for more algorithms and details.
- `forward_loop`: The forward-loop function used to calibrate activations during quantization. If not provided, scales are computed statically from the weights only.
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method, which is selected based on `quant_type`.

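The snippet below is a minimal sketch of how these options can be combined. The values are illustrative only: the checkpoint is the Sana model used elsewhere in this document and the excluded module name reuses the example above; substitute the names that apply to your own model.

```python
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig

# Illustrative values: FP8 weight-only quantization that skips one projection
# module and leaves every convolutional layer unquantized.
quantization_config = NVIDIAModelOptConfig(
    quant_type="FP8",
    quant_method="modelopt",
    modules_to_not_convert=["pos_embed.proj.weight"],  # full or partial module names
    disable_conv_quantization=True,                    # skip all conv layers
)

transformer = AutoModel.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```
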
## Supported quantization types

ModelOpt supports weight-only, per-channel, and block quantization in INT8, FP8, INT4, NF4, and NVFP4. These methods reduce the memory footprint of the model weights while maintaining model quality during inference.

Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.

The supported quantization methods are listed below; a short configuration example follows the table.

| **Quantization Type** | **Supported Schemes** | **Required Kwargs** | **Additional Notes** |
|-----------------------|-----------------------|---------------------|----------------------|
| **INT8** | `int8 weight only`, `int8 channel quantization`, `int8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |
| **FP8** | `fp8 weight only`, `fp8 channel quantization`, `fp8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |
| **INT4** | `int4 weight only`, `int4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | Only `channel_quantize = -1` is supported for now |
| **NF4** | `nf4 weight only`, `nf4 double block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize + scale_channel_quantize + scale_block_quantize` | Only `channel_quantize = -1` and `scale_channel_quantize = -1` are supported for now |
| **NVFP4** | `nvfp4 weight only`, `nvfp4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | Only `channel_quantize = -1` is supported for now |

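As a rough illustration of the table, the sketch below builds an FP8 weight-only config and an INT4 block-quantization config. It assumes the `channel_quantize` and `block_quantize` kwargs from the table are passed to `NVIDIAModelOptConfig` and forwarded to the underlying ModelOpt method; the block size of 128 is an arbitrary choice.

```python
from diffusers import NVIDIAModelOptConfig

# FP8 weight-only: quant_type alone is sufficient.
fp8_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")

# INT4 block quantization (illustrative): per the table, only channel_quantize=-1
# is supported for now, and block_quantize sets the block size.
int4_block_config = NVIDIAModelOptConfig(
    quant_type="INT4",
    quant_method="modelopt",
    channel_quantize=-1,
    block_quantize=128,
)
```
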
Refer to the [official ModelOpt documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) for a better understanding of the available quantization methods and an exhaustive list of configuration options.

## Serializing and Deserializing quantized models

To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.

```python
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig
from modelopt.torch.opt import enable_huggingface_checkpointing

enable_huggingface_checkpointing()

model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
quant_config_fp8 = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
model = AutoModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quant_config_fp8,
    torch_dtype=torch.bfloat16,
)
model.save_pretrained("path/to/sana_fp8", safe_serialization=False)
```

To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.

```python
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig, SanaPipeline
from modelopt.torch.opt import enable_huggingface_checkpointing

enable_huggingface_checkpointing()

quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
# The checkpoint saved above contains only the transformer, so it is loaded
# directly from "path/to/sana_fp8" rather than from a "transformer" subfolder.
transformer = AutoModel.from_pretrained(
    "path/to/sana_fp8",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```
@@ -0,0 +1 @@
from .modelopt_quantizer import NVIDIAModelOptQuantizer
Let's prefer using `PipelineQuantizationConfig`.
I have kept it similar to the other quantization docs (Quanto, TorchAO, etc.). Can we keep it consistent with them for now? Those docs also use the backend-specific quant config.
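For reference, a rough sketch of what the reviewer's suggestion could look like, assuming `PipelineQuantizationConfig`'s `quant_mapping` accepts the ModelOpt config for the transformer component (not verified against this PR):

```python
import torch
from diffusers import SanaPipeline, NVIDIAModelOptConfig
from diffusers.quantizers import PipelineQuantizationConfig

# Hypothetical: quantize only the transformer via a pipeline-level config.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt"),
    }
)

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
```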