Conversation

ishan-modi
Contributor

@ishan-modi ishan-modi commented Mar 30, 2025

What does this PR do?

WIP, aimed at adding a new quantization backend (#11032). For now, this PR only supports on-the-fly quantization. Loading pre-quantized models errors out; this will be fixed by the NVIDIA team in their next release (early May).

Depends on

  • this to support latest diffusers
  • this to enable INT8 quantization
  • this to enable NF4 quantization
Code
# !pip install "git+https://github.com/ishan-modi/diffusers.git@add-trtquant-backend#egg=diffusers[nvidia_modelopt]"

import torch
from tqdm import tqdm
from diffusers import SanaPipeline, SanaTransformer2DModel, StableDiffusion3Pipeline, SD3Transformer2DModel
from diffusers.quantizers.quantization_config import NVIDIAModelOptConfig

checkpoint = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
model_cls = SanaTransformer2DModel
pipe_cls = SanaPipeline
# checkpoint = "stabilityai/stable-diffusion-3-medium-diffusers"
# model_cls = SD3Transformer2DModel
# pipe_cls = StableDiffusion3Pipeline

input = {"prompt":"A capybara holding a sign that reads Hello World", "num_inference_steps":28, "guidance_scale":3.5}

quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"}
quant_config_int8 = {"quant_type": "INT8", "quant_method": "modelopt"}
quant_config_int4 = {"quant_type": "INT4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1, "modules_to_not_convert": ["conv", "patch_embed"]}
quant_config_nf4 = {"quant_type": "NF4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1, "scale_block_quantize": 8, "scale_channel_quantize": -1, "modules_to_not_convert": ["conv"]}
quant_config_nvfp4 = {"quant_type": "NVFP4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1, "modules_to_not_convert": ["conv"]}


def test_quantization(config, checkpoint, model_cls):
    quant_config = NVIDIAModelOptConfig(**config)
    print(quant_config.get_config_from_quant_type())
    quant_model = model_cls.from_pretrained(checkpoint, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16).to('cuda')
    return quant_model

def test_quant_inference(model, input, pipe_cls, iter=1):
    # Build the pipeline once around the quantized transformer, then measure peak CUDA memory per run.
    pipe = pipe_cls.from_pretrained(checkpoint, transformer=model, torch_dtype=torch.bfloat16).to('cuda')
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    inference_memory = 0
    for _ in tqdm(range(iter)):
        with torch.no_grad():
            output = pipe(**input).images[0]
        inference_memory += torch.cuda.max_memory_allocated()
    inference_memory /= iter
    output.save("test.png")
    print("Inference Memory: ", inference_memory / 1e6)

test_quant_inference(test_quantization(quant_config_fp8, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(test_quantization(quant_config_int8, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(test_quantization(quant_config_int4, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(test_quantization(quant_config_nf4, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(test_quantization(quant_config_nvfp4, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(model_cls.from_pretrained(checkpoint, subfolder="transformer", torch_dtype=torch.bfloat16).to('cuda'), input, pipe_cls)

A discussion with the NVIDIA team on speedups when using real_quant is here.

@ishan-modi ishan-modi marked this pull request as draft March 30, 2025 07:45
@ishan-modi
Contributor Author

@sayakpaul, would you mind taking a quick look and sharing suggestions?

@sayakpaul
Member

Thanks for getting started on this. I guess there is a problem here: NVIDIA/TensorRT-Model-Optimizer#165? Additionally, the API should have a TRTConfig in place of just a dict as the quantization config.

@ishan-modi
Contributor Author

I think the problem has been fixed in the newest release; I just need to bump the version in the diffusers requirements. We can also do the following to pass a config class:

from diffusers.quantizers.quantization_config import ModelOptConfig

quant_config = ModelOptConfig(quant_type="FP8_WO", modules_to_not_convert=["conv"])
model = SanaTransformer2DModel.from_pretrained(checkpoint, subfolder="transformer", quantization_config=quant_config...

By TRTConfig, did you mean including the config classes from ModelOptimizer here?

@sayakpaul
Member

We use names like BitsAndBytesConfig depending on the backend. See here:
https://github.com/huggingface/diffusers/blob/fb54499614f9603bfaa4c026202c5783841b3a80/src/diffusers/quantizers/quantization_config.py#L177C7-L177C25

So, in this case, we should be using TRTConfig or something similar.

@sayakpaul
Member

I think the problem has been fixed in the newest release; I just need to bump the version in the diffusers requirements

Alright, let's try with the latest fixes then.

@ishan-modi
Contributor Author

The newer version wasn't backward compatible, hence the issues; I have fixed it.

Related to naming: the package name is nvidia_modelopt, hence ModelOpt, but I can make it TRTModelOpt if you'd like?

@sayakpaul
Member

Doesn't it have any reliance on tensorrt?

@ishan-modi
Contributor Author

No it doesn't, we can use TRT to compile the quantized model

@sayakpaul
Member

No it doesn't, we can use TRT to compile the quantized model

Could you elaborate what you mean by this?

Member

@sayakpaul sayakpaul left a comment

This looks nice. Could you demonstrate some memory savings and any speedups when using modelopt, please? We can then add tests, docs, etc.

@sayakpaul sayakpaul requested a review from DN6 April 9, 2025 05:31
@ishan-modi
Contributor Author

Could you elaborate what you mean by this?

Yeah, so for quantizing the model we don't use TensorRT, but once the model is quantized we can compile it using TensorRT.
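
To illustrate the split, here is a minimal, hypothetical sketch of that second step (not part of this PR): export the quantized transformer to ONNX and build an engine offline with trtexec. The file names, input shapes, and opset are assumptions (the shapes mirror the Sana benchmark later in this thread), and whether the export succeeds for a given modelopt-quantized module depends on the modelopt/TensorRT versions.

# Hypothetical deployment sketch (assumptions, not a tested recipe):
# quantize with modelopt inside diffusers, then hand the model to TensorRT for compilation.
import torch

quant_model.eval()
# Dummy inputs mirroring the Sana transformer benchmark shapes (assumed).
dummy_inputs = (
    torch.randn(2, 32, 32, 32, dtype=torch.bfloat16, device="cuda"),
    torch.randn(2, 10, 300, 2304, dtype=torch.bfloat16, device="cuda"),
    torch.tensor([0.0, 0.0], device="cuda"),
)
# bf16 ONNX export support varies; casting the model and inputs to fp16 may be needed here.
torch.onnx.export(quant_model, dummy_inputs, "sana_transformer_quant.onnx", opset_version=17)

# Then build a TensorRT engine offline, e.g.:
#   trtexec --onnx=sana_transformer_quant.onnx --saveEngine=sana_transformer.plan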

@ishan-modi
Contributor Author

ishan-modi commented Apr 23, 2025

💾 Model & Inference Memory (in MB)

| Quantization Type | Sana Model Size (MB) | Sana Inference (MB) | SD3 Model Size (MB) | SD3 Inference (MB) | Flux Model Size (MB) | Flux Inference (MB) |
|---|---|---|---|---|---|---|
| FP8 | 592.48 | 2060.80 | 2142.78 | 4405.61 | 13697.76 | 14438.87 |
| INT4 | 306.20 | 2072.30 | 1160.89 | 3443.98 | 8803.12 | 9559.62 |
| NVFP4 | 721.99 | Err | 1148.20 | Err | 8724.27 | Err |
| Original (BF16) | 1183.50 | 2642.37 | 4169.90 | 6457.41 | 23802.82 | - |

Following is the code

import torch
from tqdm import tqdm
from diffusers import SanaTransformer2DModel, SD3Transformer2DModel, FluxTransformer2DModel
from diffusers.quantizers.quantization_config import NVIDIAModelOptConfig

checkpoint = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
model_cls = SanaTransformer2DModel
# checkpoint = "stabilityai/stable-diffusion-3-medium-diffusers"
# model_cls = SD3Transformer2DModel
# checkpoint = "black-forest-labs/FLUX.1-dev"
# model_cls = FluxTransformer2DModel

input = lambda: (torch.randn((2, 32, 32, 32), dtype=torch.bfloat16).to('cuda'), torch.randn((2,10,300,2304), dtype=torch.bfloat16).to('cuda'), torch.Tensor([0,0]).to('cuda'))  # SanaTransformer2DModel inputs
# input = lambda: (torch.randn((1,16,96,96), dtype=torch.bfloat16).to('cuda'), torch.randn((1,300,4096), dtype=torch.bfloat16).to('cuda'), torch.randn((1, 2048), dtype=torch.bfloat16).to('cuda'), torch.Tensor([0]).to('cuda'))  # SD3Transformer2DModel inputs
# input = lambda: (torch.randn((1,1024, 64), dtype=torch.bfloat16).to('cuda'), torch.randn((1,300,4096), dtype=torch.bfloat16).to('cuda'), torch.randn((1, 768), dtype=torch.bfloat16).to('cuda'), torch.Tensor([0]).to('cuda'), torch.randn((300, 3)).to('cuda'), torch.randn((1024, 3)).to('cuda'), torch.Tensor([0]).to('cuda'))  # FluxTransformer2DModel inputs

quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"}
quant_config_int4 = {"quant_type": "INT4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1}
quant_config_nvfp4 = {"quant_type": "NVFP4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1, 'modules_to_not_convert' : ['conv']}

def test_quantization(config, checkpoint, model_cls):
    quant_config = NVIDIAModelOptConfig(**config)
    print(quant_config.get_config_from_quant_type())
    quant_model = model_cls.from_pretrained(checkpoint, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16, device_map="balanced").to('cuda')
    print(f"Quant {config['quant_type']} Model Memory Footprint: ", quant_model.get_memory_footprint() / 1e6)
    return quant_model

def test_quant_inference(model, input, iter=10):
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    inference_memory = 0
    for _ in tqdm(range(iter)):
        with torch.no_grad():
            output = model(*input())
        inference_memory += torch.cuda.max_memory_allocated()
    inference_memory /= iter
    print("Inference Memory: ", inference_memory / 1e6)

test_quant_inference(test_quantization(quant_config_fp8, checkpoint, model_cls), input)
# test_quant_inference(test_quantization(quant_config_int4, checkpoint, model_cls), input)
# test_quant_inference(test_quantization(quant_config_nvfp4, checkpoint, model_cls), input)
# test_quant_inference(model_cls.from_pretrained(checkpoint, subfolder="transformer", torch_dtype=torch.bfloat16).to('cuda'), input)

Speed Ups

There is no significant speedup across the different quantization types because modelopt internally still uses high-precision arithmetic (float32).
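
For anyone reproducing the latency comparison, here is a minimal timing sketch (not from the PR) using CUDA events around the quantized forward pass; it reuses the model and input factory from the script above, and the warmup/iteration counts are arbitrary.

# Minimal latency sketch: time the quantized forward pass with CUDA events.
# `model` and `input` are the objects from the benchmark script above.
def time_forward(model, input, warmup=3, iters=10):
    for _ in range(warmup):  # warm up kernels before timing
        with torch.no_grad():
            model(*input())
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        with torch.no_grad():
            model(*input())
    end.record()
    torch.cuda.synchronize()
    print("Avg latency (ms): ", start.elapsed_time(end) / iters)

# e.g. time_forward(test_quantization(quant_config_fp8, checkpoint, model_cls), input)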

Sorry for being a bit late on this. @sayakpaul, let me know the next steps!

@ishan-modi ishan-modi marked this pull request as ready for review April 23, 2025 23:28
@ishan-modi ishan-modi changed the title [WIP] Add TRT as a Backend [Quantization] Add TRT as a Backend Apr 24, 2025
@sayakpaul
Member

@ishan-modi let us know if this is ready to be reviewed.

@ishan-modi
Contributor Author

ishan-modi commented Apr 25, 2025

@sayakpaul, I think it is ready for a preliminary review; on-the-fly quantization works fine. But loading pre-quantized models errors out and will be fixed by the NVIDIA team in the next release here (early May).

@jingyu-ml, just so that you are in the loop

Member

@sayakpaul sayakpaul left a comment

Looking good so far!

Could you also demonstrate some memory and timing numbers with the modelopt toolkit and some visual results?

No need, just saw #11173 (comment). But it doesn't measure the inference memory, which is usually done via torch.cuda.max_memory_allocated(). Could we also see those numbers? Would it be possible to make it clear in the PR description that

on-the-fly quantization works fine. But loading pre-quantized models errors out and will be fixed in the next release NVIDIA/TensorRT-Model-Optimizer#185 (early May) by the NVIDIA team.

@jingyu-ml is it expected to not see any speedups in latency?

Member

@SunMarc SunMarc left a comment

Thanks for this! Just some nits. It would be nice to add this quantization scheme to transformers after this gets merged!

@sayakpaul sayakpaul requested review from DN6 and realAsma August 27, 2025 08:17
@sayakpaul
Member

@ishan-modi just a quick question. Do we know if the nunchaku SVDQuant method is supported through modelopt? From https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#quantization-examples-docs, it seems like it is supported. But could you confirm?

@ishan-modi
Contributor Author

@sayakpaul, yes, modelopt does support SVDQuant, but in this integration we support only min-max based calibration (see here). I think we should iteratively add advanced quantization methods like svd_quant and awq once we have the base going; let me know if you think otherwise.
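
For context, a rough sketch of what that looks like at the modelopt level (illustrative only; the config name is assumed from the TensorRT-Model-Optimizer docs and calib_inputs is user-provided): min-max calibration simply runs representative inputs through the model so each quantizer records its amax statistics.

# Illustrative modelopt calibration flow (config name assumed; not diffusers code).
import modelopt.torch.quantization as mtq

def forward_loop(model):
    for batch in calib_inputs:  # a handful of representative inputs supplied by the user
        model(*batch)

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)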

@sayakpaul
Member

That's fine. I asked because if we can support svd_quant through our modelopt backend, I am happy to drop #12207. Hence I wanted to check.

@sayakpaul
Member

Will merge after @DN6 has had a chance to review. @ishan-modi can we also include a note in the docs that just performing the conversion step with modelopt won't lead to speed improvements (as pointed out here)?

@realAsma @jingyu-ml after this PR is merged, we could plan to write a post/guide on how to take a modelopt-converted diffusers pipeline and use it in deployment settings to realize the actual speed gains.

Collaborator

@DN6 DN6 left a comment

Excellent work @ishan-modi 👍🏽 Thank you 🙏🏽

@sayakpaul
Member

@ishan-modi, can we fix the remaining CI problems? Then we should be good to go.

@ishan-modi
Contributor Author

@sayakpaul, should be fixed now.

@sayakpaul sayakpaul merged commit 4acbfbf into huggingface:main Sep 3, 2025
12 of 13 checks passed
@sayakpaul
Member

Congratulations on shipping this thing, @ishan-modi! Thank you!

Let's maybe now focus on the following things to maximize the potential impact:

  • SVDQuant Support
  • Guide to actually benefit from speedups

Happy to help.

@ishan-modi ishan-modi deleted the add-trtquant-backend branch September 3, 2025 05:33