From ce8c562f0978f0a12064e2e06a6c3c6124cb12aa Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Wed, 28 Jan 2026 22:09:59 +0000
Subject: [PATCH 1/2] fix FP8 sweep, do not requantize the scales

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/model_calib.py       | 12 ++++++++++--
 .../quantization/nn/modules/tensor_quantizer.py  |  7 ++++++-
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 0077d8666..89e69167b 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -245,10 +245,12 @@ def mse_calibrate(
         step_size: Step size for amax search (default: 0.1).
         start_multiplier: Starting multiplier for amax search (default: 0.25).
         stop_multiplier: Ending multiplier for amax search (default: 4.0).
-        fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
+        fp8_scale_sweep: If True, sweep over all 126 valid FP8 E4M3 scale values
             for NVFP4 per-block quantization instead of using multipliers.
             This is specifically designed for optimizing the FP8-quantized
-            per-block scales in NVFP4 format (default: False).
+            per-block scales in NVFP4 format. When enabled, sets
+            block_sizes["skip_fp8_scale_quant"] = True on NVFP4 quantizers
+            to skip dynamic FP8 scale quantization during inference (default: False).
 
     See :class:`MseCalibConfig` for
     details on the remaining arguments.
@@ -279,6 +281,12 @@ def mse_calibrate(
                     "block quantization. fp8_scale_sweep will be ignored for this quantizer."
                 )
 
+            # Skip dynamic FP8 scale quantization as scales are pre-optimized via FP8 sweep
+            if fp8_scale_sweep and is_nvfp4_static:
+                if module._block_sizes is None:
+                    module._block_sizes = {}
+                module._block_sizes["skip_fp8_scale_quant"] = True
+
             # Create MSE calibrator with quant_func
             module._calibrator = MseCalibrator(
                 amax=initial_amax,
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index e004cf0e7..a762761b9 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -764,11 +764,16 @@ def _fake_quantize(self, inputs):
                 self._pass_through_bwd,
             )
         elif self._num_bits == (2, 1) and self.is_static_block_quant:
+            skip_scale_quant = (
+                self._block_sizes.get("skip_fp8_scale_quant", False)
+                if self._block_sizes is not None
+                else False
+            )
             outputs = static_blockwise_fp4_fake_quant(
                 inputs,
                 None,  # scale
                 None,  # scale_fp8_quant_amax
-                False,  # skip_scale_quant
+                skip_scale_quant,  # skip_scale_quant
                 inputs.dtype,  # out_dtype
                 self._pass_through_bwd,  # pass_through_bwd
                 amax,  # amax
From 1caa24fc2b47bf404dc39c73de1a7191a71f5e54 Mon Sep 17 00:00:00 2001
From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Date: Wed, 28 Jan 2026 22:28:13 +0000
Subject: [PATCH 2/2] default skip_scale_quant to True

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
---
 modelopt/torch/quantization/model_calib.py             | 10 +---------
 .../torch/quantization/nn/modules/tensor_quantizer.py  |  7 +------
 2 files changed, 2 insertions(+), 15 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 89e69167b..7171a2bf7 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -248,9 +248,7 @@ def mse_calibrate(
         fp8_scale_sweep: If True, sweep over all 126 valid FP8 E4M3 scale values
             for NVFP4 per-block quantization instead of using multipliers.
             This is specifically designed for optimizing the FP8-quantized
-            per-block scales in NVFP4 format. When enabled, sets
-            block_sizes["skip_fp8_scale_quant"] = True on NVFP4 quantizers
-            to skip dynamic FP8 scale quantization during inference (default: False).
+            per-block scales in NVFP4 format (default: False).
 
     See :class:`MseCalibConfig` for
     details on the remaining arguments.
@@ -281,12 +279,6 @@ def mse_calibrate(
                     "block quantization. fp8_scale_sweep will be ignored for this quantizer."
                 )
 
-            # Skip dynamic FP8 scale quantization as scales are pre-optimized via FP8 sweep
-            if fp8_scale_sweep and is_nvfp4_static:
-                if module._block_sizes is None:
-                    module._block_sizes = {}
-                module._block_sizes["skip_fp8_scale_quant"] = True
-
             # Create MSE calibrator with quant_func
             module._calibrator = MseCalibrator(
                 amax=initial_amax,
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index a762761b9..145fc5217 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -764,16 +764,11 @@ def _fake_quantize(self, inputs):
                 self._pass_through_bwd,
             )
         elif self._num_bits == (2, 1) and self.is_static_block_quant:
-            skip_scale_quant = (
-                self._block_sizes.get("skip_fp8_scale_quant", False)
-                if self._block_sizes is not None
-                else False
-            )
             outputs = static_blockwise_fp4_fake_quant(
                 inputs,
                 None,  # scale
                 None,  # scale_fp8_quant_amax
-                skip_scale_quant,  # skip_scale_quant
+                True,  # skip_scale_quant
                 inputs.dtype,  # out_dtype
                 self._pass_through_bwd,  # pass_through_bwd
                 amax,  # amax
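
For context on the "126 valid FP8 E4M3 scale values" in the docstring: E4M3 has 256 bit patterns, and of the 128 non-negative ones, dropping +0 and the NaN encoding (0x7f) leaves 126 strictly positive finite values usable as scales. The sketch below illustrates the sweep idea on a single block; fake_quant_e2m1 and sweep_block_scale are hypothetical helpers written for illustration, not ModelOpt APIs, and the real batched search lives in MseCalibrator / model_calib.py.

import torch

# All 256 E4M3 bit patterns, reinterpreted as float8 (needs PyTorch >= 2.1).
codes = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn).float()

# Keep strictly positive, finite values: of the 128 non-negative patterns,
# dropping +0 and the NaN encoding (0x7f) leaves 126 candidate scales.
candidates = codes[(codes > 0) & torch.isfinite(codes)]
assert candidates.numel() == 126

# Representable FP4 (E2M1) magnitudes; NVFP4 packs 16 such values per block
# together with one FP8 per-block scale.
E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
GRID = torch.cat([-E2M1.flip(0), E2M1])


def fake_quant_e2m1(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Snap x / scale to the nearest representable E2M1 value, then rescale.
    # (Simple nearest-value rounding; no round-to-even tie-breaking.)
    idx = ((x / scale).unsqueeze(-1) - GRID).abs().argmin(dim=-1)
    return GRID[idx] * scale


def sweep_block_scale(block: torch.Tensor) -> torch.Tensor:
    # Exhaustive MSE search over all 126 candidate scales for one block.
    errs = torch.stack(
        [((fake_quant_e2m1(block, s) - block) ** 2).mean() for s in candidates]
    )
    return candidates[errs.argmin()]


# Example: choose the best FP8 scale for one random 16-element block.
best_scale = sweep_block_scale(torch.randn(16))

Because every winning scale is by construction already an exact E4M3 value, re-quantizing it to FP8 inside the fake-quant kernel adds no information, which appears to be why this series passes skip_scale_quant=True to static_blockwise_fp4_fake_quant for the static NVFP4 block path.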