
Feature suggestion: Option to enable on-the-fly model quantization for faster generation on low vram #349

@klosax

Description


Updates:

  • 4-bit quantization is working. The 8-bit integer quantization described here does not work; see the post below.
  • The bitsandbytes package needs to be installed or added to requirements.txt (a small availability-check sketch follows this list).
  • Works with LoRAs in the 0.5.1 release version.
  • Works with both the F1 and Original generation models.
  • Ghosting appears during fast movements and seems worse at lower resolutions. It could possibly be mitigated by adjusting or disabling MagCache.
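
A minimal sketch (not part of the original post) of how an on-the-fly quantization option could silently fall back to the normal BF16 path when bitsandbytes is missing; the helper name is made up for illustration:

    import importlib.util

    def bitsandbytes_available() -> bool:
        # True only if the bitsandbytes package is importable;
        # the loader can keep the plain BF16 path otherwise.
        return importlib.util.find_spec("bitsandbytes") is not None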

Original t2v tests in the 0.5.1 release version using an RTX 3060 12GB (Ubuntu):
2 seconds (= 2 sections), default settings, sage1.06, 4GB preserved VRAM.

Magcache enabled

First section, includes loading (and quantization) of the model:

| Resolution         | 144x112 | 288x224 | 448x320 | 576x448 | 704x544 |
|--------------------|---------|---------|---------|---------|---------|
| Normal BF16 s/it   | 2.25    | 3.13    | 5.23    | 7.41    | 10.99   |
| Quantized NF4 s/it | 1.06    | 1.94    | 3.89    | 6.32    | 9.90    |
| Speed gain         | 112%    | 61%     | 34%     | 17%     | 11%     |

Second section, reuses the loaded model:

| Resolution         | 144x112 | 288x224 | 448x320 | 576x448 | 704x544 |
|--------------------|---------|---------|---------|---------|---------|
| Normal BF16 s/it   | 1.77    | 2.67    | 4.72    | 6.95    | 10.62   |
| Quantized NF4 s/it | 1.08    | 1.93    | 3.89    | 6.29    | 9.82    |
| Speed gain         | 64%     | 38%     | 21%     | 10%     | 8%      |

Magcache disabled

First section, includes loading (and quantization) of the model:

| Resolution         | 144x112 | 288x224 | 448x320 | 576x448 | 704x544 |
|--------------------|---------|---------|---------|---------|---------|
| Normal BF16 s/it   | 3.21    | 5.05    | 8.27    | 13.80   | 20.79   |
| Quantized NF4 s/it | 1.43    | 3.40    | 6.57    | 12.04   | 18.94   |
| Speed gain         | 124%    | 49%     | 26%     | 15%     | 10%     |

Second section, reuses the loaded model:

| Resolution         | 144x112 | 288x224 | 448x320 | 576x448 | 704x544 |
|--------------------|---------|---------|---------|---------|---------|
| Normal BF16 s/it   | 2.74    | 4.48    | 7.77    | 13.20   | 20.19   |
| Quantized NF4 s/it | 1.41    | 3.31    | 6.57    | 11.94   | 19.00   |
| Speed gain         | 94%     | 35%     | 18%     | 11%     | 6%      |

The 4-bit quantized model seems to stay fully loaded in VRAM, since there is minimal speed difference between the two sections.
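
Not from the original post, but a quick way to confirm this is to print PyTorch's CUDA memory counters after each section and compare:

    import torch

    # How much VRAM PyTorch currently holds; if the quantized model stays
    # resident, these numbers barely change between sections.
    print(f"allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB")
    print(f"reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GiB")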

..

I found a simple way to use normalized 4-bit (NF4) and 8-bit quantization by having the model loader quantize on-the-fly while loading.

I ran some speed tests using 4-bit quantization of the F1 model. The whole quantized model is loaded into VRAM when quantized; about 7GB of free VRAM is needed.

Test results:

F1 model, i2v, default settings, sage1.06, RTX 3060 12GB (Ubuntu).
Input image, prompt, and seed are identical in all tests.

| Resolution         | 144x112 | 288x224 | 448x320 | 576x448 | 704x544 | 832x640 |
|--------------------|---------|---------|---------|---------|---------|---------|
| Normal BF16 s/it   | 2.33    | 3.32    | 5.32    | 7.16    | 10.78   | 15.03   |
| Quantized NF4 s/it | 1.03    | 2.14    | 4.17    | 6.44    | 10.10   | 15.16   |
| Speed gain         | 126%    | 55%     | 28%     | 11%     | 7%      | 0%      |

Suggestion: Add an option to enable on-the-fly quantization of the models.
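
One way such an option could be wired up (a rough sketch only; the build_quantization_config helper and the "nf4"/"int8" setting values are made up for illustration and are not existing options):

    from transformers import BitsAndBytesConfig

    def build_quantization_config(quantization: str):
        # Map a hypothetical user-facing setting to a bitsandbytes config,
        # or return None to keep the current BF16 loading path.
        if quantization == "nf4":
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype="bfloat16",
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
        if quantization == "int8":
            return BitsAndBytesConfig(load_in_8bit=True)
        return None

The returned value could then be passed straight to from_pretrained() as quantization_config, with None meaning "load as before".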

Here is how to enable 4-bit or 8-bit quantization:

Change these lines:

# Create the transformer model
self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
path_to_load,
torch_dtype=torch.bfloat16
).cpu()
# Configure the model
self.transformer.eval()
self.transformer.to(dtype=torch.bfloat16)
self.transformer.requires_grad_(False)

to:


nf4quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
int8quant_config = BitsAndBytesConfig(load_in_8bit=True)

# Create the transformer model
self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
    path_to_load,
    torch_dtype=torch.bfloat16,
    # swap in quantization_config=int8quant_config for INT8 instead of NF4
    quantization_config=nf4quant_config
).cpu()

# Configure the model
self.transformer.eval()
# the explicit dtype cast must stay commented out when quantizing
# self.transformer.to(dtype=torch.bfloat16)
self.transformer.requires_grad_(False)

And add the following line at the top of the file:
from transformers import BitsAndBytesConfig
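
To check that the quantization actually took effect after loading (not part of the change itself; the exact module class may vary between bitsandbytes versions), the replaced linear layers can be counted:

    import bitsandbytes as bnb

    # A quantized model reports many Linear4bit modules; a plain BF16 load reports zero.
    n4bit = sum(isinstance(m, bnb.nn.Linear4bit) for m in self.transformer.modules())
    print(f"Linear4bit modules: {n4bit}")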
