133 changes: 133 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -387,6 +387,69 @@
"algorithm": "max",
}

NVFP4_WEIGHT_ACT_MSE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"step_size": 0.25,
"start_multiplier": 0.25,
"stop_multiplier": 2.0,
},
}

NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"fp8_scale_sweep": True,
},
}


NVFP4_LOCAL_HESSIAN_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "local_hessian",
"fp8_scale_sweep": True,
},
}

NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
@@ -1059,6 +1122,76 @@ class MseCalibConfig(QuantizeAlgorithmConfig):
)


class LocalHessianCalibConfig(QuantizeAlgorithmConfig):
"""Configuration for local Hessian-weighted MSE calibration.

This algorithm uses activation information to optimize per-block scales for weight
quantization. It minimizes the output reconstruction error by weighting the loss
with the local Hessian matrix computed from input activations.

The local Hessian loss for each block is ``dw @ H @ dw.T``, where:

- ``dw = weight - quantized_weight`` is the per-block weight reconstruction error
- ``H = X @ X.T`` is the local Hessian computed from the input activations ``X``

This method is particularly effective for NVFP4 weight-only quantization where
activation information helps select better per-block scales.

"""

method: Literal["local_hessian"] = ModeloptField("local_hessian")

step_size: float | None = ModeloptField(
default=0.1,
gt=0.0,
title="Step size for amax search.",
description="Step size between amax candidates. The number of candidates is computed as "
"ceil((stop_multiplier - start_multiplier) / step_size) + 1.",
)

start_multiplier: float | None = ModeloptField(
default=0.25,
gt=0.0,
title="Starting multiplier for amax search.",
description="Starting multiplier for amax search range (multiplies initial amax).",
)

stop_multiplier: float | None = ModeloptField(
default=4.0,
gt=0.0,
title="Ending multiplier for amax search.",
description="Ending multiplier for amax search range (multiplies initial amax).",
)

fp8_scale_sweep: bool | None = ModeloptField(
default=True,
title="Enable FP8 scale sweep for NVFP4 per-block quantization.",
description="If True, sweep over all 128 possible FP8 E4M3 scale values "
"for NVFP4 per-block quantization instead of using multipliers. "
"This is the recommended setting for NVFP4 quantization.",
)

block_size: int | None = ModeloptField(
default=16,
gt=0,
title="Block size for local Hessian computation.",
description="The block size used for computing the local Hessian matrix. "
"This should match the block size used in the quantization config. "
"Default is 16 for NVFP4.",
)

distributed_sync: bool | None = ModeloptField(
default=True,
title="Whether to sync the amax across the distributed processes.",
description="If True, the amax will be synced across the distributed processes.",
)

debug: bool | None = ModeloptField(
default=False,
title="Debug mode.",
description="If True, module's local Hessian metadata will be kept as a module attribute.",
)


class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
"""The config for ``smoothquant`` algorithm (SmoothQuant).

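For context, a hedged usage sketch (not part of the diff) of how the new weight-only config would be applied through the standard mtq.quantize entry point. The toy model, the forward_loop, and the assumption that NVFP4_LOCAL_HESSIAN_CFG is re-exported at the package level (as the existing NVFP4 configs are) are illustrative, not confirmed by this PR.

import torch
import modelopt.torch.quantization as mtq

# Toy model for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

def forward_loop(m):
    # Feed a few calibration batches so the calibrator can observe
    # input activations while searching for per-block weight scales.
    for _ in range(8):
        m(torch.randn(4, 64))

# Weight-only NVFP4 with local-Hessian scale search; assumes the new
# config is exposed on the mtq namespace like the existing NVFP4 configs.
model = mtq.quantize(model, mtq.NVFP4_LOCAL_HESSIAN_CFG, forward_loop)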
26 changes: 25 additions & 1 deletion modelopt/torch/quantization/mode.py
@@ -37,6 +37,7 @@
AWQFullCalibConfig,
AWQLiteCalibConfig,
CompressConfig,
LocalHessianCalibConfig,
MaxCalibConfig,
MseCalibConfig,
QuantizeAlgoCfgType,
@@ -55,7 +56,14 @@
restore_svdquant_model,
update_quantize_metadata,
)
from .model_calib import awq, max_calibrate, mse_calibrate, smoothquant, svdquant
from .model_calib import (
awq,
local_hessian_calibrate,
max_calibrate,
mse_calibrate,
smoothquant,
svdquant,
)

__all__ = ["BaseCalibrateModeDescriptor"]

@@ -376,6 +384,22 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
_calib_func = mse_calibrate


@CalibrateModeRegistry.register_mode
class LocalHessianModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for local Hessian-weighted MSE calibration algorithm.

This algorithm uses activation information to optimize per-block scales for weight
quantization by minimizing output reconstruction error instead of weight reconstruction error.
"""

@property
def config_class(self) -> type[QuantizeAlgorithmConfig]:
"""Specifies the config class for the mode."""
return LocalHessianCalibConfig

_calib_func = local_hessian_calibrate


@CalibrateModeRegistry.register_mode
class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
"""Mode for smoothquant calibration algorithm."""
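To make the objective concrete, a minimal sketch of the per-block loss the LocalHessianCalibConfig docstring describes, ``dw @ H @ dw.T`` with ``H = X @ X.T``. This is not the PR's implementation, and the names (local_hessian_block_loss, w_block, etc.) are illustrative; a calibrator would evaluate this loss for each candidate scale and keep the lowest-loss one.

import torch

def local_hessian_block_loss(w_block, wq_block, h_block):
    # Per-block loss from the docstring: dw @ H @ dw.T, where dw is the
    # weight reconstruction error and H is the local Hessian for the block.
    dw = (w_block - wq_block).reshape(1, -1)  # shape (1, block_size)
    return (dw @ h_block @ dw.T).item()

# Toy numbers with block_size=16, matching the NVFP4 default.
block_size, n_samples = 16, 32
x = torch.randn(block_size, n_samples)  # activations X for one block
h = x @ x.T                             # H = X @ X.T, shape (block_size, block_size)
w = torch.randn(block_size)
wq = torch.round(w * 4) / 4             # crude stand-in for NVFP4 quantization
print(local_hessian_block_loss(w, wq, h))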