diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index f1bd67327..625e62f97 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -331,6 +331,9 @@ def save_pretrained( "exclude_modules": ["lm_head"], }, } + if "NVFP4" in quantization: + # for vllm, the group size is required + hf_quant_config["quantization"]["group_size"] = 16 with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(hf_quant_config, f, indent=4)