diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 92165640934d..479e5503eed2 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -340,6 +340,9 @@ jobs: - backend: "optimum_quanto" test_location: "quanto" additional_deps: [] + - backend: "nvidia_modelopt" + test_location: "modelopt" + additional_deps: [] runs-on: group: aws-g6e-xlarge-plus container: diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a0ddf8f25654..a97c82796fca 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -188,6 +188,8 @@ title: torchao - local: quantization/quanto title: quanto + - local: quantization/modelopt + title: NVIDIA ModelOpt - title: Model accelerators and hardware isExpanded: false diff --git a/docs/source/en/quantization/modelopt.md b/docs/source/en/quantization/modelopt.md new file mode 100644 index 000000000000..06933d47c221 --- /dev/null +++ b/docs/source/en/quantization/modelopt.md @@ -0,0 +1,141 @@ + + +# NVIDIA ModelOpt + +[NVIDIA-ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a unified library of state-of-the-art model optimization techniques such as quantization, pruning, distillation, and speculative decoding. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed. + +Before you begin, make sure you have nvidia_modelopt installed. + +```bash +pip install -U "nvidia_modelopt[hf]" +``` + +Quantize a model by passing [`NVIDIAModelOptConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. + +The example below only quantizes the weights to FP8. + +```python +import torch +from diffusers import AutoModel, SanaPipeline, NVIDIAModelOptConfig + +model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers" +dtype = torch.bfloat16 + +quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt") +transformer = AutoModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=dtype, +) +pipe = SanaPipeline.from_pretrained( + model_id, + transformer=transformer, + torch_dtype=dtype, +) +pipe.to("cuda") + +print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB") + +prompt = "A cat holding a sign that says hello world" +image = pipe( + prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 +).images[0] +image.save("output.png") +``` + +> **Note:** +> +> The quantization methods in NVIDIA-ModelOpt are designed to reduce the memory footprint of model weights using various QAT (Quantization-Aware Training) and PTQ (Post-Training Quantization) techniques while maintaining model performance. However, the actual performance gain during inference depends on the deployment framework (e.g., TRT-LLM, TensorRT) and the specific hardware configuration. +> +> More details can be found [here](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples). + +## NVIDIAModelOptConfig + +The `NVIDIAModelOptConfig` class accepts the following parameters: +- `quant_type`: A string specifying one of the quantization types listed below. +- `modules_to_not_convert`: A list of full or partial module names for which quantization should not be performed.
For example, to not perform any quantization of the [`SD3Transformer2DModel`]'s pos_embed projection blocks, one would specify: `modules_to_not_convert=["pos_embed.proj.weight"]`. +- `disable_conv_quantization`: A boolean value which, when set to `True`, disables quantization for all convolutional layers in the model. This is useful because channel and block quantization (used with INT4, NF4, and NVFP4) generally don't work well with convolutional layers. If you want to disable quantization for specific convolutional layers, use `modules_to_not_convert` instead. +- `algorithm`: The algorithm used to determine quantization scales; defaults to `"max"`. See the ModelOpt documentation for other algorithms and details. +- `forward_loop`: The forward loop function used to calibrate activations during quantization. If not provided, quantization relies on static scales computed from the weights only. +- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method, which is selected based on `quant_type`. + +## Supported quantization types + +ModelOpt supports weight-only, per-channel, and block quantization for INT8, FP8, INT4, NF4, and NVFP4. These methods reduce the memory footprint of the model weights while maintaining model performance during inference. + +Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory required for the model weights but retains the memory peaks of activation computation. + +The supported quantization methods are as follows: + +| **Quantization Type** | **Supported Schemes** | **Required Kwargs** | **Additional Notes** | +|-----------------------|-----------------------|---------------------|----------------------| +| **INT8** | `int8 weight only`, `int8 channel quantization`, `int8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | | +| **FP8** | `fp8 weight only`, `fp8 channel quantization`, `fp8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | | +| **INT4** | `int4 weight only`, `int4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now` | +| **NF4** | `nf4 weight only`, `nf4 double block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize + scale_channel_quantize + scale_block_quantize` | `channel_quantize = -1 and scale_channel_quantize = -1 are only supported for now` | +| **NVFP4** | `nvfp4 weight only`, `nvfp4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now` | + + +Refer to the [official modelopt documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) for a better understanding of the available quantization methods and the exhaustive list of configuration options. + +## Serializing and Deserializing quantized models + +To serialize a quantized model, first load it with the desired quantization configuration and then save it using the [`~ModelMixin.save_pretrained`] method. ModelOpt patches the Hugging Face checkpointing hooks, so call `enable_huggingface_checkpointing()` before saving or loading quantized weights, as the examples below do; otherwise a warning is raised because the weights are not handled in the ModelOpt format.
+ +```python +import torch +from diffusers import AutoModel, NVIDIAModelOptConfig +from modelopt.torch.opt import enable_huggingface_checkpointing + +enable_huggingface_checkpointing() + +model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers" +quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"} +quant_config_fp8 = NVIDIAModelOptConfig(**quant_config_fp8) +model = AutoModel.from_pretrained( + model_id, + subfolder="transformer", + quantization_config=quant_config_fp8, + torch_dtype=torch.bfloat16, +) +model.save_pretrained('path/to/sana_fp8', safe_serialization=False) +``` + +To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method. + +```python +import torch +from diffusers import AutoModel, NVIDIAModelOptConfig, SanaPipeline +from modelopt.torch.opt import enable_huggingface_checkpointing + +enable_huggingface_checkpointing() + +quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt") +transformer = AutoModel.from_pretrained( + "path/to/sana_fp8", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, +) +pipe = SanaPipeline.from_pretrained( + "Efficient-Large-Model/Sana_600M_1024px_diffusers", + transformer=transformer, + torch_dtype=torch.bfloat16, +) +pipe.to("cuda") +prompt = "A cat holding a sign that says hello world" +image = pipe( + prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 +).images[0] +image.save("output.png") +``` diff --git a/setup.py b/setup.py index 62d984d9b6a0..ba3ad8e2b307 100644 --- a/setup.py +++ b/setup.py @@ -132,6 +132,7 @@ "gguf>=0.10.0", "torchao>=0.7.0", "bitsandbytes>=0.43.3", + "nvidia_modelopt[hf]>=0.33.1", "regex!=2019.12.17", "requests", "tensorboard", @@ -244,6 +245,7 @@ def run(self): extras["gguf"] = deps_list("gguf", "accelerate") extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate") extras["torchao"] = deps_list("torchao", "accelerate") +extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]") if os.name == "nt": # windows extras["flax"] = [] # jax is not supported on windows diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 762ae3846a7d..fa5dd6482c44 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -13,6 +13,7 @@ is_k_diffusion_available, is_librosa_available, is_note_seq_available, + is_nvidia_modelopt_available, is_onnx_available, is_opencv_available, is_optimum_quanto_available, @@ -111,6 +112,18 @@ else: _import_structure["quantizers.quantization_config"].append("QuantoConfig") +try: + if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_nvidia_modelopt_objects + + _import_structure["utils.dummy_nvidia_modelopt_objects"] = [ + name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig") + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -795,6 +808,14 @@ else: from .quantizers.quantization_config import QuantoConfig + try: + if not is_nvidia_modelopt_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_nvidia_modelopt_objects import * + else: + from .quantizers.quantization_config import NVIDIAModelOptConfig + try: if not is_onnx_available(): raise 
OptionalDependencyNotAvailable() diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index a3832cf9b89b..79dc4c50a050 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -39,6 +39,7 @@ "gguf": "gguf>=0.10.0", "torchao": "torchao>=0.7.0", "bitsandbytes": "bitsandbytes>=0.43.3", + "nvidia_modelopt[hf]": "nvidia_modelopt[hf]>=0.33.1", "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index ce214ae7bc17..070bcd0b2151 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -21,9 +21,11 @@ from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer from .gguf import GGUFQuantizer +from .modelopt import NVIDIAModelOptQuantizer from .quantization_config import ( BitsAndBytesConfig, GGUFQuantizationConfig, + NVIDIAModelOptConfig, QuantizationConfigMixin, QuantizationMethod, QuantoConfig, @@ -39,6 +41,7 @@ "gguf": GGUFQuantizer, "quanto": QuantoQuantizer, "torchao": TorchAoHfQuantizer, + "modelopt": NVIDIAModelOptQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -47,6 +50,7 @@ "gguf": GGUFQuantizationConfig, "quanto": QuantoConfig, "torchao": TorchAoConfig, + "modelopt": NVIDIAModelOptConfig, } @@ -137,6 +141,9 @@ def merge_quantization_configs( if isinstance(quantization_config, dict): quantization_config = cls.from_dict(quantization_config) + if isinstance(quantization_config, NVIDIAModelOptConfig): + quantization_config.check_model_patching() + if warning_msg != "": warnings.warn(warning_msg) diff --git a/src/diffusers/quantizers/modelopt/__init__.py b/src/diffusers/quantizers/modelopt/__init__.py new file mode 100644 index 000000000000..ae0951cb30d1 --- /dev/null +++ b/src/diffusers/quantizers/modelopt/__init__.py @@ -0,0 +1 @@ +from .modelopt_quantizer import NVIDIAModelOptQuantizer diff --git a/src/diffusers/quantizers/modelopt/modelopt_quantizer.py b/src/diffusers/quantizers/modelopt/modelopt_quantizer.py new file mode 100644 index 000000000000..534f752321b3 --- /dev/null +++ b/src/diffusers/quantizers/modelopt/modelopt_quantizer.py @@ -0,0 +1,190 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from ...utils import ( + get_module_from_name, + is_accelerate_available, + is_nvidia_modelopt_available, + is_torch_available, + logging, +) +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +if is_torch_available(): + import torch + import torch.nn as nn + +if is_accelerate_available(): + from accelerate.utils import set_module_tensor_to_device + + +logger = logging.get_logger(__name__) + + +class NVIDIAModelOptQuantizer(DiffusersQuantizer): + r""" + Diffusers Quantizer for TensorRT Model Optimizer + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + required_packages = ["nvidia_modelopt"] + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def validate_environment(self, *args, **kwargs): + if not is_nvidia_modelopt_available(): + raise ImportError( + "Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)" + ) + + self.offload = False + + device_map = kwargs.get("device_map", None) + if isinstance(device_map, dict): + if "cpu" in device_map.values() or "disk" in device_map.values(): + if self.pre_quantized: + raise 
ValueError( + "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model " + "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument." + ) + else: + self.offload = True + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + # ModelOpt imports diffusers internally. This is here to prevent circular imports + from modelopt.torch.quantization.utils import is_quantized + + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized: + return True + elif is_quantized(module) and "weight" in tensor_name: + return True + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + *args, + **kwargs, + ): + """ + Create the quantized parameter by calling .calibrate() after setting it to the module. + """ + # ModelOpt imports diffusers internally. This is here to prevent circular imports + import modelopt.torch.quantization as mtq + + dtype = kwargs.get("dtype", torch.float32) + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized: + module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) + else: + set_module_tensor_to_device(model, param_name, target_device, param_value, dtype) + mtq.calibrate( + module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop + ) + mtq.compress(module) + module.weight.requires_grad = False + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if self.quantization_config.quant_type == "FP8": + target_dtype = torch.float8_e4m3fn + return target_dtype + + def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype": + if torch_dtype is None: + logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.") + torch_dtype = torch.float32 + return torch_dtype + + def get_conv_param_names(self, model: "ModelMixin") -> List[str]: + """ + Get parameter names for all convolutional layers in a HuggingFace ModelMixin. Includes Conv1d/2d/3d and + ConvTranspose1d/2d/3d. + """ + conv_types = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ) + + conv_param_names = [] + for name, module in model.named_modules(): + if isinstance(module, conv_types): + for param_name, _ in module.named_parameters(recurse=False): + conv_param_names.append(f"{name}.{param_name}") + + return conv_param_names + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + # ModelOpt imports diffusers internally. 
This is here to prevent circular imports + import modelopt.torch.opt as mto + + if self.pre_quantized: + return + + modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if modules_to_not_convert is None: + modules_to_not_convert = [] + if isinstance(modules_to_not_convert, str): + modules_to_not_convert = [modules_to_not_convert] + modules_to_not_convert.extend(keep_in_fp32_modules) + if self.quantization_config.disable_conv_quantization: + modules_to_not_convert.extend(self.get_conv_param_names(model)) + + for module in modules_to_not_convert: + self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False} + self.quantization_config.modules_to_not_convert = modules_to_not_convert + mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)]) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model, **kwargs): + # ModelOpt imports diffusers internally. This is here to prevent circular imports + from modelopt.torch.opt import ModeloptStateManager + + if self.pre_quantized: + return model + + for _, m in model.named_modules(): + if hasattr(m, ModeloptStateManager._state_key) and m is not model: + ModeloptStateManager.remove_state(m) + + return model + + @property + def is_trainable(self): + return True + + @property + def is_serializable(self): + self.quantization_config.check_model_patching(operation="saving") + return True diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 871faf076e5a..bf857956512c 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -25,10 +25,11 @@ import inspect import json import os +import warnings from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from packaging import version @@ -46,6 +47,7 @@ class QuantizationMethod(str, Enum): GGUF = "gguf" TORCHAO = "torchao" QUANTO = "quanto" + MODELOPT = "modelopt" if is_torchao_available(): @@ -268,7 +270,14 @@ def __init__( if bnb_4bit_quant_storage is None: self.bnb_4bit_quant_storage = torch.uint8 elif isinstance(bnb_4bit_quant_storage, str): - if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]: + if bnb_4bit_quant_storage not in [ + "float16", + "float32", + "int8", + "uint8", + "float64", + "bfloat16", + ]: raise ValueError( "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') " ) @@ -479,7 +488,12 @@ class TorchAoConfig(QuantizationConfigMixin): ``` """ - def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None: + def __init__( + self, + quant_type: str, + modules_to_not_convert: Optional[List[str]] = None, + **kwargs, + ) -> None: self.quant_method = QuantizationMethod.TORCHAO self.quant_type = quant_type self.modules_to_not_convert = modules_to_not_convert @@ -724,3 +738,194 @@ def post_init(self): accepted_weights = ["float8", "int8", "int4", "int2"] if self.weights_dtype not in accepted_weights: raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") + + +@dataclass +class NVIDIAModelOptConfig(QuantizationConfigMixin): + """This is a config class to use nvidia modelopt for quantization. 
+ + Args: + quant_type (`str`): + The type of quantization we want to use, following is how to use: + **weightquant_activationquant ==> FP8_FP8** In the above example we have use FP8 for both weight and + activation quantization. Following are the all the options: + - FP8 + - INT8 + - INT4 + - NF4 + - NVFP4 + modules_to_not_convert (`List[str]`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have some + weight_only (`bool`, *optional*, default to `False`): + If set to `True`, the quantization will be applied only to the weights of the model. + channel_quantize (`int`, *optional*, default to `None`): + The channel quantization axis, useful for quantizing models across different axes. + block_quantize (`int`, *optional*, default to `None`): + The block size, useful to further quantize each channel/axes into blocks. + scale_channel_quantize (`int`, *optional*, default to `None`): + The scale channel quantization axis, useful for quantizing calculated scale across different axes. + scale_block_quantize (`int`, *optional*, default to `None`): + The scale block size, useful for quantizing each scale channel/axes into blocks. + algorithm (`str`, *optional*, default to `"max"`): + The algorithm to use for quantization, currently only supports `"max"`. + forward_loop (`Callable`, *optional*, default to `None`): + The forward loop function to use for calibration during quantization. + modelopt_config (`dict`, *optional*, default to `None`): + The modelopt config, useful for passing custom configs to modelopt. + disable_conv_quantization (`bool`, *optional*, default to `False`): + If set to `True`, the quantization will be disabled for convolutional layers. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters which are to be used for calibration. + """ + + quanttype_to_numbits = { + "FP8": (4, 3), + "INT8": 8, + "INT4": 4, + "NF4": 4, + "NVFP4": (2, 1), + } + quanttype_to_scalingbits = { + "NF4": 8, + "NVFP4": (4, 3), + } + + def __init__( + self, + quant_type: str, + modules_to_not_convert: Optional[List[str]] = None, + weight_only: bool = True, + channel_quantize: Optional[int] = None, + block_quantize: Optional[int] = None, + scale_channel_quantize: Optional[int] = None, + scale_block_quantize: Optional[int] = None, + algorithm: str = "max", + forward_loop: Optional[Callable] = None, + modelopt_config: Optional[dict] = None, + disable_conv_quantization: bool = False, + **kwargs, + ) -> None: + self.quant_method = QuantizationMethod.MODELOPT + self._normalize_quant_type(quant_type) + self.modules_to_not_convert = modules_to_not_convert + self.weight_only = weight_only + self.channel_quantize = channel_quantize + self.block_quantize = block_quantize + self.calib_cfg = { + "method": algorithm, + # add more options here if needed + } + self.forward_loop = forward_loop + self.scale_channel_quantize = scale_channel_quantize + self.scale_block_quantize = scale_block_quantize + self.modelopt_config = self.get_config_from_quant_type() if not modelopt_config else modelopt_config + self.disable_conv_quantization = disable_conv_quantization + + def check_model_patching(self, operation: str = "loading"): + # ModelOpt imports diffusers internally. This is here to prevent circular imports + from modelopt.torch.opt.plugins.huggingface import _PATCHED_CLASSES + + if len(_PATCHED_CLASSES) == 0: + warning_msg = ( + f"Not {operation} weights in modelopt format. This might cause unreliable behavior." 
+ "Please make sure to run the following code before loading/saving model weights:\n\n" + " from modelopt.torch.opt import enable_huggingface_checkpointing\n" + " enable_huggingface_checkpointing()\n" + ) + warnings.warn(warning_msg) + + def _normalize_quant_type(self, quant_type: str) -> str: + """ + Validates and normalizes the quantization type string. + + Splits the quant_type into weight and activation components, verifies them against supported types, and + replaces unsupported values with safe defaults. + + Args: + quant_type (str): The input quantization type string (e.g., 'FP8_INT8'). + + Returns: + str: A valid quantization type string (e.g., 'FP8_INT8' or 'FP8'). + """ + parts = quant_type.split("_") + w_type = parts[0] + act_type = parts[1] if len(parts) > 1 else None + if len(parts) > 2: + logger.warning(f"Quantization type {quant_type} is not supported. Picking FP8_INT8 as default") + w_type = "FP8" + act_type = None + else: + if w_type not in NVIDIAModelOptConfig.quanttype_to_numbits: + logger.warning(f"Weight Quantization type {w_type} is not supported. Picking FP8 as default") + w_type = "FP8" + if act_type is not None and act_type not in NVIDIAModelOptConfig.quanttype_to_numbits: + logger.warning(f"Activation Quantization type {act_type} is not supported. Picking INT8 as default") + act_type = None + self.quant_type = w_type + ("_" + act_type if act_type is not None else "") + + def get_config_from_quant_type(self) -> Dict[str, Any]: + """ + Get the config from the quantization type. + """ + import modelopt.torch.quantization as mtq + + BASE_CONFIG = { + "quant_cfg": { + "*weight_quantizer": {"fake_quant": False}, + "*input_quantizer": {}, + "*output_quantizer": {"enable": False}, + "*q_bmm_quantizer": {}, + "*k_bmm_quantizer": {}, + "*v_bmm_quantizer": {}, + "*softmax_quantizer": {}, + **mtq.config._default_disabled_quantizer_cfg, + }, + "algorithm": self.calib_cfg, + } + + quant_cfg = BASE_CONFIG["quant_cfg"] + if self.weight_only: + for k in quant_cfg: + if "*weight_quantizer" not in k and not quant_cfg[k]: + quant_cfg[k]["enable"] = False + + parts = self.quant_type.split("_") + w_type = parts[0] + act_type = parts[1].replace("A", "") if len(parts) > 1 else None + for k in quant_cfg: + if k not in mtq.config._default_disabled_quantizer_cfg and "enable" not in quant_cfg[k]: + if k == "*input_quantizer": + if act_type is not None: + quant_cfg[k]["num_bits"] = NVIDIAModelOptConfig.quanttype_to_numbits[act_type] + continue + quant_cfg[k]["num_bits"] = NVIDIAModelOptConfig.quanttype_to_numbits[w_type] + + if self.block_quantize is not None and self.channel_quantize is not None: + quant_cfg["*weight_quantizer"]["block_sizes"] = {self.channel_quantize: self.block_quantize} + quant_cfg["*input_quantizer"]["block_sizes"] = { + self.channel_quantize: self.block_quantize, + "type": "dynamic", + } + elif self.channel_quantize is not None: + quant_cfg["*weight_quantizer"]["axis"] = self.channel_quantize + quant_cfg["*input_quantizer"]["axis"] = self.channel_quantize + quant_cfg["*input_quantizer"]["type"] = "dynamic" + + # Only fixed scaling sizes are supported for now in modelopt + if self.scale_channel_quantize is not None and self.scale_block_quantize is not None: + if w_type in NVIDIAModelOptConfig.quanttype_to_scalingbits: + quant_cfg["*weight_quantizer"]["block_sizes"].update( + { + "scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[w_type], + "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize}, + } + ) + if act_type and act_type in 
NVIDIAModelOptConfig.quanttype_to_scalingbits: + quant_cfg["*input_quantizer"]["block_sizes"].update( + { + "scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[act_type], + "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize}, + } + ) + + return BASE_CONFIG diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b27cf981edeb..63932221b207 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -89,6 +89,8 @@ is_matplotlib_available, is_nltk_available, is_note_seq_available, + is_nvidia_modelopt_available, + is_nvidia_modelopt_version, is_onnx_available, is_opencv_available, is_optimum_quanto_available, diff --git a/src/diffusers/utils/dummy_nvidia_modelopt_objects.py b/src/diffusers/utils/dummy_nvidia_modelopt_objects.py new file mode 100644 index 000000000000..046b28223b3d --- /dev/null +++ b/src/diffusers/utils/dummy_nvidia_modelopt_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class NVIDIAModelOptConfig(metaclass=DummyObject): + _backends = ["nvidia_modelopt"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["nvidia_modelopt"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["nvidia_modelopt"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["nvidia_modelopt"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 153be057381d..9399ccd2a7a3 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -226,6 +226,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") _kornia_available, _kornia_version = _is_package_available("kornia") +_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) def is_torch_available(): @@ -364,6 +365,10 @@ def is_optimum_quanto_available(): return _optimum_quanto_available +def is_nvidia_modelopt_available(): + return _nvidia_modelopt_available + + def is_timm_available(): return _timm_available @@ -830,6 +835,21 @@ def is_optimum_quanto_version(operation: str, version: str): return compare_versions(parse(_optimum_quanto_version), operation, version) +def is_nvidia_modelopt_version(operation: str, version: str): + """ + Compares the current Nvidia ModelOpt version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _nvidia_modelopt_available: + return False + return compare_versions(parse(_nvidia_modelopt_version), operation, version) + + def is_xformers_version(operation: str, version: str): """ Compares the current xformers version to a given reference with an operation. 
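As a quick illustration, the availability helpers added above can gate optional ModelOpt code paths. The guard pattern and the `None` fallback below are illustrative assumptions, not part of this PR:

```python
# Hypothetical usage of the helpers introduced in import_utils.py above.
# The fallback is a placeholder; real callers would raise or skip instead.
from diffusers.utils import is_nvidia_modelopt_available, is_nvidia_modelopt_version

if is_nvidia_modelopt_available() and is_nvidia_modelopt_version(">=", "0.33.1"):
    import modelopt.torch.quantization as mtq  # package is present and new enough
else:
    mtq = None  # ModelOpt features are unavailable; check before use
```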
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 6d6a7d6ce4bb..3297bb5fdcd6 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -38,6 +38,7 @@ is_gguf_available, is_kernels_available, is_note_seq_available, + is_nvidia_modelopt_available, is_onnx_available, is_opencv_available, is_optimum_quanto_available, @@ -638,6 +639,18 @@ def decorator(test_case): return decorator +def require_modelopt_version_greater_or_equal(modelopt_version): + def decorator(test_case): + correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse( + version.parse(importlib.metadata.version("modelopt")).base_version + ) >= version.parse(modelopt_version) + return unittest.skipUnless( + correct_nvidia_modelopt_version, f"Test requires modelopt with version greater than {modelopt_version}." + )(test_case) + + return decorator + + def require_kernels_version_greater_or_equal(kernels_version): def decorator(test_case): correct_kernels_version = is_kernels_available() and version.parse( diff --git a/tests/others/test_dependencies.py b/tests/others/test_dependencies.py index a08129a1e9c9..db22f10c4b3c 100644 --- a/tests/others/test_dependencies.py +++ b/tests/others/test_dependencies.py @@ -39,6 +39,8 @@ def test_backend_registration(self): backend = "invisible-watermark" elif backend == "opencv": backend = "opencv-python" + elif backend == "nvidia_modelopt": + backend = "nvidia_modelopt[hf]" assert backend in deps, f"{backend} is not in the deps table!" def test_pipeline_imports(self): diff --git a/tests/quantization/modelopt/__init__.py b/tests/quantization/modelopt/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/modelopt/test_modelopt.py b/tests/quantization/modelopt/test_modelopt.py new file mode 100644 index 000000000000..6b0624a28083 --- /dev/null +++ b/tests/quantization/modelopt/test_modelopt.py @@ -0,0 +1,306 @@ +import gc +import tempfile +import unittest + +from diffusers import NVIDIAModelOptConfig, SD3Transformer2DModel, StableDiffusion3Pipeline +from diffusers.utils import is_nvidia_modelopt_available, is_torch_available +from diffusers.utils.testing_utils import ( + backend_empty_cache, + backend_reset_peak_memory_stats, + enable_full_determinism, + nightly, + numpy_cosine_similarity_distance, + require_accelerate, + require_big_accelerator, + require_modelopt_version_greater_or_equal, + require_torch_cuda_compatibility, + torch_device, +) + + +if is_nvidia_modelopt_available(): + import modelopt.torch.quantization as mtq + +if is_torch_available(): + import torch + + from ..utils import LoRALayer, get_memory_consumption_stat + +enable_full_determinism() + + +@nightly +@require_big_accelerator +@require_accelerate +@require_modelopt_version_greater_or_equal("0.33.1") +class ModelOptBaseTesterMixin: + model_id = "hf-internal-testing/tiny-sd3-pipe" + model_cls = SD3Transformer2DModel + pipeline_cls = StableDiffusion3Pipeline + torch_dtype = torch.bfloat16 + expected_memory_reduction = 0.0 + keep_in_fp32_module = "" + modules_to_not_convert = "" + _test_torch_compile = False + + def setUp(self): + backend_reset_peak_memory_stats(torch_device) + backend_empty_cache(torch_device) + gc.collect() + + def tearDown(self): + backend_reset_peak_memory_stats(torch_device) + backend_empty_cache(torch_device) + gc.collect() + + def get_dummy_init_kwargs(self): + return {"quant_type": "FP8"} + + def get_dummy_model_init_kwargs(self): + return { + 
"pretrained_model_name_or_path": self.model_id, + "torch_dtype": self.torch_dtype, + "quantization_config": NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()), + "subfolder": "transformer", + } + + def test_modelopt_layers(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert mtq.utils.is_quantized(module) + + def test_modelopt_memory_usage(self): + inputs = self.get_dummy_inputs() + inputs = { + k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool) + } + + unquantized_model = self.model_cls.from_pretrained( + self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer" + ) + unquantized_model.to(torch_device) + unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs) + + quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + quantized_model.to(torch_device) + quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs) + + assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction + + def test_keep_modules_in_fp32(self): + _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules + self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module + + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + model.to(torch_device) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + assert module.weight.dtype == torch.float32 + self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules + + def test_modules_to_not_convert(self): + init_kwargs = self.get_dummy_model_init_kwargs() + quantization_config_kwargs = self.get_dummy_init_kwargs() + quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert}) + quantization_config = NVIDIAModelOptConfig(**quantization_config_kwargs) + init_kwargs.update({"quantization_config": quantization_config}) + + model = self.model_cls.from_pretrained(**init_kwargs) + model.to(torch_device) + + for name, module in model.named_modules(): + if name in self.modules_to_not_convert: + assert not mtq.utils.is_quantized(module) + + def test_dtype_assignment(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + + with self.assertRaises(ValueError): + model.to(torch.float16) + + with self.assertRaises(ValueError): + device_0 = f"{torch_device}:0" + model.to(device=device_0, dtype=torch.float16) + + with self.assertRaises(ValueError): + model.float() + + with self.assertRaises(ValueError): + model.half() + + model.to(torch_device) + + def test_serialization(self): + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + inputs = self.get_dummy_inputs() + + model.to(torch_device) + with torch.no_grad(): + model_output = model(**inputs) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + saved_model = self.model_cls.from_pretrained( + tmp_dir, + torch_dtype=torch.bfloat16, + ) + + saved_model.to(torch_device) + with torch.no_grad(): + saved_model_output = saved_model(**inputs) + + assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5) + + def test_torch_compile(self): + if not self._test_torch_compile: + return + + model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()) + compiled_model = 
torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False) + + model.to(torch_device) + with torch.no_grad(): + model_output = model(**self.get_dummy_inputs()).sample + + compiled_model.to(torch_device) + with torch.no_grad(): + compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample + + model_output = model_output.detach().float().cpu().numpy() + compiled_model_output = compiled_model_output.detach().float().cpu().numpy() + + max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten()) + assert max_diff < 1e-3 + + def test_device_map_error(self): + with self.assertRaises(ValueError): + _ = self.model_cls.from_pretrained( + **self.get_dummy_model_init_kwargs(), + device_map={0: "8GB", "cpu": "16GB"}, + ) + + def get_dummy_inputs(self): + batch_size = 1 + seq_len = 16 + height = width = 32 + num_latent_channels = 4 + caption_channels = 8 + + torch.manual_seed(0) + hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to( + torch_device, dtype=torch.bfloat16 + ) + encoder_hidden_states = torch.randn((batch_size, seq_len, caption_channels)).to( + torch_device, dtype=torch.bfloat16 + ) + timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + def test_model_cpu_offload(self): + init_kwargs = self.get_dummy_init_kwargs() + transformer = self.model_cls.from_pretrained( + self.model_id, + quantization_config=NVIDIAModelOptConfig(**init_kwargs), + subfolder="transformer", + torch_dtype=torch.bfloat16, + ) + pipe = self.pipeline_cls.from_pretrained(self.model_id, transformer=transformer, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload(device=torch_device) + _ = pipe("a cat holding a sign that says hello", num_inference_steps=2) + + def test_training(self): + quantization_config = NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()) + quantized_model = self.model_cls.from_pretrained( + self.model_id, + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + for param in quantized_model.parameters(): + param.requires_grad = False + if param.ndim == 1: + param.data = param.data.to(torch.float32) + + for _, module in quantized_model.named_modules(): + if hasattr(module, "to_q"): + module.to_q = LoRALayer(module.to_q, rank=4) + if hasattr(module, "to_k"): + module.to_k = LoRALayer(module.to_k, rank=4) + if hasattr(module, "to_v"): + module.to_v = LoRALayer(module.to_v, rank=4) + + with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): + inputs = self.get_dummy_inputs() + output = quantized_model(**inputs)[0] + output.norm().backward() + + for module in quantized_model.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + + +class SanaTransformerFP8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase): + expected_memory_reduction = 0.6 + + def get_dummy_init_kwargs(self): + return {"quant_type": "FP8"} + + +class SanaTransformerINT8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase): + expected_memory_reduction = 0.6 + _test_torch_compile = True + + def get_dummy_init_kwargs(self): + return {"quant_type": "INT8"} + + +@require_torch_cuda_compatibility(8.0) +class SanaTransformerINT4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase): + expected_memory_reduction = 0.55 + + def 
get_dummy_init_kwargs(self): + return { + "quant_type": "INT4", + "block_quantize": 128, + "channel_quantize": -1, + "disable_conv_quantization": True, + } + + +@require_torch_cuda_compatibility(8.0) +class SanaTransformerNF4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase): + expected_memory_reduction = 0.65 + + def get_dummy_init_kwargs(self): + return { + "quant_type": "NF4", + "block_quantize": 128, + "channel_quantize": -1, + "scale_block_quantize": 8, + "scale_channel_quantize": -1, + "modules_to_not_convert": ["conv"], + } + + +@require_torch_cuda_compatibility(8.0) +class SanaTransformerNVFP4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase): + expected_memory_reduction = 0.65 + + def get_dummy_init_kwargs(self): + return { + "quant_type": "NVFP4", + "block_quantize": 128, + "channel_quantize": -1, + "scale_block_quantize": 8, + "scale_channel_quantize": -1, + "modules_to_not_convert": ["conv"], + }
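
For reference, the configuration exercised by `SanaTransformerNF4WeightsTest` above can be written out as a full loading example. This is a sketch only: the keyword values are copied from the test, the checkpoint is the Sana model used in the documentation examples, and everything else is illustrative:

```python
# Sketch: NF4 double-block quantization using the kwargs from SanaTransformerNF4WeightsTest.
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig

nf4_config = NVIDIAModelOptConfig(
    quant_type="NF4",
    block_quantize=128,               # block size used for weight quantization
    channel_quantize=-1,              # only channel_quantize = -1 is supported for now
    scale_block_quantize=8,           # block size for quantizing the computed scales
    scale_channel_quantize=-1,        # only scale_channel_quantize = -1 is supported for now
    modules_to_not_convert=["conv"],  # leave convolutional layers unquantized
)
transformer = AutoModel.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```

The INT4 and NVFP4 test cases follow the same pattern, swapping `quant_type` and, for INT4, using `disable_conv_quantization=True` instead of `modules_to_not_convert`.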