From 4d1380a2ed06b1d1b3f7110d138d115684b76d6a Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Fri, 16 Jan 2026 00:16:33 +0000
Subject: [PATCH] add local hessian calibration

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/config.py      | 133 +++++++++++
 modelopt/torch/quantization/mode.py        |  26 ++-
 modelopt/torch/quantization/model_calib.py | 259 ++++++++++++++++++++-
 3 files changed, 416 insertions(+), 2 deletions(-)

diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 9836648d0..4ca0ffc61 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -387,6 +387,69 @@
     "algorithm": "max",
 }
 
+NVFP4_WEIGHT_ACT_MSE_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "mse",
+        "step_size": 0.25,
+        "start_multiplier": 0.25,
+        "stop_multiplier": 2.0,
+    },
+}
+
+NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "mse",
+        "fp8_scale_sweep": True,
+    },
+}
+
+
+NVFP4_LOCAL_HESSIAN_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {
+            "num_bits": (2, 1),
+            "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
+            "axis": None,
+            "enable": True,
+        },
+        "*input_quantizer": {
+            "enable": False,
+        },
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": {
+        "method": "local_hessian",
+        "fp8_scale_sweep": True,
+    },
+}
+
 NVFP4_AWQ_LITE_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {
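
Note: the three configs above only select the quantizer layout and the calibration algorithm; they are consumed through the usual ``mtq.quantize`` entry point. A minimal usage sketch, assuming the standard ModelOpt calling convention, with ``model`` and ``calib_dataloader`` as user-supplied placeholders:

    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.config import NVFP4_LOCAL_HESSIAN_CFG

    def forward_loop(model):
        # Push a few calibration batches through the model; the calibration
        # hooks installed by mtq.quantize() record the activations.
        for batch in calib_dataloader:
            model(batch)

    # "algorithm": {"method": "local_hessian", ...} routes calibration to
    # local_hessian_calibrate() via the mode descriptor added in mode.py below.
    model = mtq.quantize(model, NVFP4_LOCAL_HESSIAN_CFG, forward_loop)
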
@@ -1059,6 +1122,76 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
     )
 
 
+class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
+    """Configuration for local Hessian-weighted MSE calibration.
+
+    This algorithm uses activation information to optimize per-block scales for weight
+    quantization. It minimizes the output reconstruction error by weighting the loss
+    with the local Hessian matrix computed from input activations.
+
+    The local Hessian loss for each block is ``dw @ H @ dw.T``, where:
+
+    - ``dw = weight - quantized_weight`` (weight reconstruction error per block)
+    - ``H = X @ X.T`` is the local Hessian computed from input activations X
+
+    This method is particularly effective for NVFP4 weight-only quantization, where
+    activation information helps select better per-block scales.
+
+    """
+
+    method: Literal["local_hessian"] = ModeloptField("local_hessian")
+
+    step_size: float | None = ModeloptField(
+        default=0.1,
+        gt=0.0,
+        title="Step size for amax search.",
+        description="Step size between amax candidates. "
+        "The number of candidates is computed as "
+        "ceil((stop_multiplier - start_multiplier) / step_size) + 1.",
+    )
+
+    start_multiplier: float | None = ModeloptField(
+        default=0.25,
+        gt=0.0,
+        title="Starting multiplier for amax search.",
+        description="Starting multiplier for the amax search range (multiplies the initial amax).",
+    )
+
+    stop_multiplier: float | None = ModeloptField(
+        default=4.0,
+        gt=0.0,
+        title="Ending multiplier for amax search.",
+        description="Ending multiplier for the amax search range (multiplies the initial amax).",
+    )
+
+    fp8_scale_sweep: bool | None = ModeloptField(
+        default=True,
+        title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
+        description="If True, sweep over all 128 possible FP8 E4M3 scale values "
+        "for NVFP4 per-block quantization instead of using multipliers. "
+        "This is the recommended setting for NVFP4 quantization.",
+    )
+
+    block_size: int | None = ModeloptField(
+        default=16,
+        gt=0,
+        title="Block size for local Hessian computation.",
+        description="The block size used for computing the local Hessian matrix. "
+        "This should match the block size used in the quantization config. "
+        "Default is 16 for NVFP4.",
+    )
+
+    distributed_sync: bool | None = ModeloptField(
+        default=True,
+        title="Whether to sync the amax across the distributed processes.",
+        description="If True, the amax will be synced across the distributed processes.",
+    )
+
+    debug: bool | None = ModeloptField(
+        default=False,
+        title="Debug mode.",
+        description="If True, the module's local Hessian metadata will be kept as a module attribute.",
+    )
+
+
 class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
     """The config for ``smoothquant`` algorithm (SmoothQuant).
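
Note: as a sanity check on the formula in ``step_size``'s description, the class defaults (``step_size=0.1``, ``start_multiplier=0.25``, ``stop_multiplier=4.0``) yield ``ceil((4.0 - 0.25) / 0.1) + 1 = 39`` amax candidates, and the ``NVFP4_WEIGHT_ACT_MSE_CFG`` values above (``0.25``, ``0.25``, ``2.0``) yield ``ceil(1.75 / 0.25) + 1 = 8``. A standalone check of just the documented formula (not the internal ``MseCalibrator`` code):

    import math

    def num_candidates(step_size: float, start_multiplier: float, stop_multiplier: float) -> int:
        # Formula quoted from LocalHessianCalibConfig.step_size's description.
        return math.ceil((stop_multiplier - start_multiplier) / step_size) + 1

    print(num_candidates(0.1, 0.25, 4.0))   # 39 (LocalHessianCalibConfig defaults)
    print(num_candidates(0.25, 0.25, 2.0))  # 8  (NVFP4_WEIGHT_ACT_MSE_CFG)
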
+ """ + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return LocalHessianCalibConfig + + _calib_func = local_hessian_calibrate + + @CalibrateModeRegistry.register_mode class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor): """Mode for smoothquant calibration algorithm.""" diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 0077d8666..f46703e8b 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -17,6 +17,7 @@ import math import warnings +from collections.abc import Callable from functools import partial import torch @@ -44,7 +45,7 @@ weight_attr_names, ) -__all__ = ["awq", "max_calibrate", "smoothquant", "svdquant"] +__all__ = ["awq", "local_hessian_calibrate", "max_calibrate", "smoothquant", "svdquant"] def weight_only_quantize(model: nn.Module): @@ -344,6 +345,262 @@ def mse_calibrate( # TODO: Sync amax across distributed processes +@torch.no_grad() +def local_hessian_calibrate( + model: nn.Module, + forward_loop: ForwardLoop | None = None, + distributed_sync: bool = True, + step_size: float = 0.1, + start_multiplier: float = 0.25, + stop_multiplier: float = 4.0, + fp8_scale_sweep: bool = True, + block_size: int = 16, + debug: bool = False, +): + """Calibrate the model using local Hessian-weighted MSE search. + + This calibration method collects input activations during forward pass, computes + per-block local Hessian matrices (H = X @ X.T), and uses them to weight the + MSE loss for scale selection. This minimizes output reconstruction error rather + than weight reconstruction error. + + Args: + model: Model to be calibrated. + forward_loop: A callable which takes the model as argument and + forwards calibration data through the model. Required for this algorithm. + distributed_sync: Whether to sync amax across distributed processes. + step_size: Step size for amax search (default: 0.1). + start_multiplier: Starting multiplier for amax search (default: 0.25). + stop_multiplier: Ending multiplier for amax search (default: 4.0). + fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values + for NVFP4 per-block quantization (default: True). + block_size: Block size for local Hessian computation (default: 16). + debug: If True, keep the local Hessian metadata on modules. + + See :class:`LocalHessianCalibConfig ` + for details on the configuration options. 
+ """ + if forward_loop is None: + warnings.warn("forward_loop must be provided for local_hessian; skipping local_hessian") + return + + class LocalHessianHelper: + """Helper class to collect activations and compute local Hessian per module.""" + + cache_mode: bool = False + + def __init__(self, module, name): + self.name = name + self.module = module + self.weight_shape = module.weight.shape # (cout, cin) + self.cout, self.cin = self.weight_shape + self.block_size = block_size + self.num_blocks_per_cin = self.cin // block_size + self.is_enabled = True + + # Accumulated Hessian per block: (cin // block_size, block_size, block_size) + self.hessian_per_block = torch.zeros( + self.num_blocks_per_cin, + block_size, + block_size, + dtype=torch.float32, + device=module.weight.device, + ) + self.num_samples = 0 + + def setup(self): + """Set up the forward hook to collect activations.""" + module = self.module + bind_forward_method(module, forward, "_forward_no_local_hessian") + + # Check if cin is divisible by block_size + if self.cin % self.block_size != 0: + warnings.warn( + f"Module {self.name}: input features ({self.cin}) not divisible by " + f"block_size ({self.block_size}). Skipping local Hessian for this module." + ) + self.is_enabled = False + + def cleanup(self): + """Clean up the forward hook.""" + unpatch_forward_method(self.module, "_forward_no_local_hessian") + if not debug: + if hasattr(self.module, "local_hessian"): + delattr(self.module, "local_hessian") + + def accumulate_hessian(self, input_tensor: torch.Tensor): + """Accumulate local Hessian from input activations. + + Args: + input_tensor: Input tensor of shape (..., cin) + """ + if not self.is_enabled: + return + + # Flatten to (num_tokens, cin) + x = input_tensor.reshape(-1, self.cin).T # (cin, num_tokens) + x = x.reshape(self.num_blocks_per_cin, self.block_size, -1) # (num_blocks, bs, n) + + # Compute H = X @ X.T for each block and accumulate + hessian_batch = (x @ x.transpose(-1, -2)).to(torch.float32) + self.hessian_per_block += hessian_batch + self.num_samples += input_tensor.numel() // self.cin + + def get_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: + """Get the local Hessian error function for MSE calibration.""" + cout = self.cout + bs = self.block_size + # Normalize hessian by number of samples + hessian = self.hessian_per_block / max(self.num_samples, 1) + + def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: + """Compute local Hessian-weighted error.""" + original_shape = x.shape + dw = (x - xq).view(-1, 1, bs) # (num_blocks, 1, block_size) + # Repeat hessian for each output channel + hessian_expanded = hessian.repeat( + cout, 1, 1 + ) # (num_blocks, block_size, block_size) + # Per-block loss: (num_blocks,) + block_loss = (dw @ hessian_expanded @ dw.transpose(-1, -2)).squeeze(-1).squeeze(-1) + error = block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape) + return error + + return local_hessian_error + + def forward(self, input, *args, **kwargs): + """Custom forward that collects activations in cache mode.""" + if LocalHessianHelper.cache_mode and self.local_hessian.is_enabled: + # Get local tensor from DTensor if applicable + input_local = input.to_local() if hasattr(input, "to_local") else input + self.local_hessian.accumulate_hessian(input_local) + + # Forward without quantization during caching + if LocalHessianHelper.cache_mode: + self.weight_quantizer.disable() + out = self._forward_no_local_hessian(input, *args, **kwargs) + 
+
+    # Setup helpers for all quantized linear modules
+    name_to_module = dict(model.named_modules())
+    weight_quantizers_info = []
+
+    for name, module in name_to_module.items():
+        if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
+            with enable_weight_access_and_writeback(module, model, name_to_module):
+                module.local_hessian = LocalHessianHelper(module, name)
+                module.local_hessian.setup()
+            if module.local_hessian.is_enabled:
+                weight_quantizers_info.append((name, module))
+
+    # Cache activations by running forward loop
+    LocalHessianHelper.cache_mode = True
+    print_rank_0("local_hessian: Caching activations and computing local Hessian...")
+    forward_loop(model)
+
+    # TODO(fridah-nv): Sync Hessian across distributed processes if needed
+
+    # Get initial amax using max calibration on weights
+    print_rank_0("local_hessian: Computing initial amax with max calibration...")
+    for name, module in weight_quantizers_info:
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            max_calibrate(module, lambda m: m.weight_quantizer(m.weight), distributed_sync)
+
+    # Replace calibrators with MseCalibrator using the local Hessian error function
+    print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...")
+    for name, module in weight_quantizers_info:
+        weight_quantizer = module.weight_quantizer
+        helper = module.local_hessian
+
+        if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
+            continue
+
+        initial_amax = weight_quantizer._amax.clone().detach()
+
+        def quant_func(x, amax, quantizer=weight_quantizer):
+            original_amax = quantizer._amax.clone() if hasattr(quantizer, "_amax") else None
+            quantizer._amax = amax
+
+            with (
+                enable_quant(quantizer),
+                disable_calib(quantizer),
+                enable_fake_quant(quantizer),
+            ):
+                if hasattr(quantizer, "_original_shape"):
+                    x = quantizer._reset_to_original_shape(x)
+                xq = quantizer(x)
+                if hasattr(quantizer, "_block_reshape_size"):
+                    xq = xq.reshape(quantizer._block_reshape_size)
+
+            if original_amax is not None:
+                quantizer._amax = original_amax
+            else:
+                delattr(quantizer, "_amax")
+
+            return xq
+
+        is_nvfp4_per_block = (
+            fp8_scale_sweep
+            and weight_quantizer.is_static_block_quant
+            and weight_quantizer._num_bits == (2, 1)
+            and weight_quantizer._block_sizes is not None
+            and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
+        )
+
+        error_func = helper.get_error_func()
+
+        weight_quantizer._calibrator = MseCalibrator(
+            amax=initial_amax,
+            axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
+            step_size=step_size,
+            start_multiplier=start_multiplier,
+            stop_multiplier=stop_multiplier,
+            quant_func=quant_func,
+            error_func=error_func,
+            fp8_scale_sweep=is_nvfp4_per_block,
+        )
+
+    # Calibrate weights with local Hessian MSE
+    for name, module in weight_quantizers_info:
+        weight_quantizer = module.weight_quantizer
+        if weight_quantizer._calibrator is None:
+            continue
+
+        weight_quantizer.disable_quant()
+        weight_quantizer.enable_calib()
+
+        with enable_weight_access_and_writeback(module, model, name_to_module):
+            weight = module.weight
+            weight_quantizer(weight)
+
+    # Compute optimal amax and load it
+    for name, module in weight_quantizers_info:
+        weight_quantizer = module.weight_quantizer
+        if weight_quantizer._calibrator is None:
+            continue
+
+        cal = weight_quantizer._calibrator
+        if cal.compute_amax() is not None:
+            weight_quantizer.load_calib_amax()
+
+        weight_quantizer.enable_quant()
+        weight_quantizer.disable_calib()
+
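
Note: on the ``fp8_scale_sweep`` path, NVFP4 per-block scales are themselves stored as FP8 E4M3 numbers, so instead of stepping through a multiplier grid the calibrator can enumerate every representable scale encoding directly — the "128 possible FP8 E4M3 scale values" from the config description. A sketch of that enumeration (assumes ``torch.float8_e4m3fn``, available in PyTorch >= 2.1; exactly how ``MseCalibrator`` filters the encodings, e.g. the 0x7F pattern that decodes to NaN in the e4m3fn variant, is internal to that class):

    import torch

    codes = torch.arange(128, dtype=torch.uint8)       # non-negative bit patterns 0x00..0x7F
    scales = codes.view(torch.float8_e4m3fn).float()   # decode to their FP8 E4M3 values
    finite = scales[torch.isfinite(scales)]            # drop 0x7F (NaN in e4m3fn)
    print(len(scales), len(finite))                    # 128 127
    print(finite.min().item(), finite.max().item())    # 0.0 448.0
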
+    # Cleanup and free memory
+    LocalHessianHelper.cache_mode = False
+    for name, module in weight_quantizers_info:
+        module.local_hessian.cleanup()
+        if hasattr(module.weight_quantizer, "_calibrator"):
+            cal = module.weight_quantizer._calibrator
+            if hasattr(cal, "clear"):
+                cal.clear()
+
+    print_rank_0("local_hessian: Calibration complete.")
+
+
 def enable_stats_collection(model: nn.Module):
     """Enable stats collection for all quantizers in the model."""
     for name, module in model.named_modules():