diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 0077d8666..7171a2bf7 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -245,10 +245,10 @@ def mse_calibrate(
         step_size: Step size for amax search (default: 0.1).
         start_multiplier: Starting multiplier for amax search (default: 0.25).
         stop_multiplier: Ending multiplier for amax search (default: 4.0).
-        fp8_scale_sweep: If True, sweep over all 128 possible FP8 E4M3 scale values
+        fp8_scale_sweep: If True, sweep over all 126 valid FP8 E4M3 scale values
             for NVFP4 per-block quantization instead of using multipliers.
             This is specifically designed for optimizing the FP8-quantized
-            per-block scales in NVFP4 format (default: False).
+            per-block scales in NVFP4 format.

     See :class:`MseCalibConfig` for details on the remaining arguments.

diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index e004cf0e7..145fc5217 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -768,7 +768,7 @@ def _fake_quantize(self, inputs):
             inputs,
             None,  # scale
             None,  # scale_fp8_quant_amax
-            False,  # skip_scale_quant
+            True,  # skip_scale_quant
             inputs.dtype,  # out_dtype
             self._pass_through_bwd,  # pass_through_bwd
             amax,  # amax
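Note on the first hunk: the count of 126 follows from the FP8 E4M3 encoding space. Of the 128 non-negative byte patterns, one encodes zero and one encodes NaN (the E4M3FN variant has no infinities), leaving 126 positive, finite scale candidates. A minimal sketch of enumerating them, not part of this patch and assuming PyTorch >= 2.1 with torch.float8_e4m3fn support:

    import torch

    def e4m3_scale_candidates() -> torch.Tensor:
        # All 256 byte patterns, reinterpreted bitwise as FP8 E4M3.
        codes = torch.arange(256, dtype=torch.int32).to(torch.uint8)
        values = codes.view(torch.float8_e4m3fn).float()
        # Drop zeros, NaNs, and negatives: the 128 non-negative encodings
        # minus zero and NaN leave 126 usable positive scales.
        values = values[torch.isfinite(values) & (values > 0)]
        return values.sort().values

    assert e4m3_scale_candidates().numel() == 126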
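The sweep itself is a per-block MSE search over those candidates, in line with the surrounding mse_calibrate docstring. A hypothetical sketch of that criterion, where quantize_fn stands in for the real NVFP4 fake-quant kernel and is an assumption rather than modelopt's actual API:

    import torch
    from typing import Callable

    def best_scale_by_mse(
        block: torch.Tensor,
        candidates: torch.Tensor,
        quantize_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        # Try every candidate scale and keep the one that minimizes the
        # reconstruction error against the unquantized block.
        errors = torch.stack(
            [((block - quantize_fn(block, s)) ** 2).mean() for s in candidates]
        )
        return candidates[torch.argmin(errors)]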