From 66ec3cb0e7fbec3d7efef21519a423ec24f24626 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Wed, 21 Jan 2026 23:37:34 +0000 Subject: [PATCH 01/13] Added support for HF modelopt state reload for vllm fakequant Signed-off-by: Kinjal Patel --- examples/llm_ptq/hf_ptq.py | 13 +- examples/vllm_serve/README.md | 54 ++- examples/vllm_serve/fakequant_worker.py | 361 +++++++----------- examples/vllm_serve/vllm_reload_utils.py | 237 ++++++++++++ examples/vllm_serve/vllm_serve_fakequant.py | 4 +- .../torch/export/plugins/vllm_fakequant_hf.py | 23 +- .../export/plugins/vllm_fakequant_megatron.py | 62 ++- modelopt/torch/quantization/conversion.py | 2 +- .../quantization/nn/modules/quant_module.py | 18 +- 9 files changed, 523 insertions(+), 251 deletions(-) create mode 100644 examples/vllm_serve/vllm_reload_utils.py diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 7c91ca97f..3023161df 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -50,6 +50,7 @@ import modelopt.torch.sparsity as mts from modelopt.torch.export import ( export_hf_checkpoint, + export_hf_vllm_fq_checkpoint, export_tensorrt_llm_checkpoint, get_model_type, ) @@ -622,8 +623,10 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." ) - - export_hf_checkpoint( + export_fn = ( + export_hf_vllm_fq_checkpoint if args.export_vllm_fq else export_hf_checkpoint + ) + export_fn( full_model, export_dir=export_path, ) @@ -1080,6 +1083,12 @@ def parse_args() -> argparse.Namespace: "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." ), ) + parser.add_argument( + "--export_vllm_fq", + help="Export vLLM fakequant checkpoint.", + default=False, + action="store_true", + ) return parser.parse_args() diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index ff0c4eea3..64310fef4 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -23,9 +23,11 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`, |-----------------|--------------------------------------------------|---------------------| | QUANT_DATASET | Dataset name for calibration | cnn_dailymail | | QUANT_CALIB_SIZE| Number of samples used for calibration | 512 | -| QUANT_CFG | Quantization format | NVFP4_DEFAULT_CFG | -| KV_QUANT_CFG | Quantization format for KV Cache | None | -| AMAX_FILE_PATH | Optional path to amax file (for loading amax) | None | +| QUANT_CFG | Quantization config | None | +| KV_QUANT_CFG | KV-cache quantization config | None | +| QUANT_FILE_PATH | Optional path to exported quantizer state dict `quantizer_state.pth` | None | +| MODELOPT_STATE_PATH | Optional path to exported `modelopt_state.pth` (restores ModelOpt mode + weights) | None | +| CALIB_BATCH_SIZE | Calibration batch size | 1 | Set these variables in your shell or Docker environment as needed to customize calibration. @@ -60,17 +62,49 @@ Overwrite the calibrated amax value with prepared values from either QAT/PTQ. Step 1: export the model with bf16 weights and amax values. To export the model: -- For HF model use `modelopt.torch.export.export_hf_vllm_fq_checkpoint` function. -- For MCore model use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq` function. +- For **HF** models, you can use `modelopt.torch.export.export_hf_vllm_fq_checkpoint`: -Step 2: configure from exported model using AMAX_FILE_PATH environment variable in step 1. 
For example: + ```python + import torch + from modelopt.torch.export import export_hf_vllm_fq_checkpoint + + with torch.inference_mode(): + export_hf_vllm_fq_checkpoint( + model, # The quantized model. + export_dir, # The directory where the exported files will be stored. + ) + ``` + Or run the example script `examples/llm_ptq/hf_ptq.py` with the `--export_vllm_fq` **flag** to export a vLLM-fakequant-compatible ModelOpt state (it generates `vllm_fq_modelopt_state.pth`, which you can use via `MODELOPT_STATE_PATH`). + +- For **MCore** models, use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq`: + + ```python + from modelopt.torch.export import export_mcore_gpt_to_hf_vllm_fq + export_mcore_gpt_to_hf_vllm_fq( + unwrapped_model, # Quantized MCore model + args.pretrained_model_name, # HF model id/path (for config/tokenizer) + export_dir=args.export_dir, # Directory where exported files will be stored + ) + + ``` + This generates `quantizer_state.pth`, which contains quantizer tensors for vLLM reload via `QUANT_FILE_PATH`. + +Step 2: use the exported artifacts when serving: + +- **HF export**: pass the exported `vllm_fq_modelopt_state.pth` via `MODELOPT_STATE_PATH` + +```bash +# HF +MODELOPT_STATE_PATH= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 +``` + +- **MCore export**: pass the exported `quantizer_state.pth` via `QUANT_FILE_PATH` and set `QUANT_CFG` to match the MCore quantization recipe ```bash -AMAX_FILE_PATH= QUANT_CFG= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 +# MCore +QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 ``` ## Known Problems -1. AWQ is not yet supported in vLLM. -2. QAT checkpoint export doesn't have KV Cache quantization enabled. KV Cache fake quantization works for PTQ. -3. Mixed precision checkpoint doesn't work currently. +1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 772c6fe66..81ea0379b 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -15,13 +15,12 @@ import dataclasses import os -import re import warnings -from collections import defaultdict from contextlib import contextmanager from typing import Any import torch +from vllm_reload_utils import convert_dict_to_vllm, convert_modelopt_state_to_vllm from tqdm import tqdm from transformers import AutoTokenizer from vllm.sampling_params import SamplingParams @@ -29,102 +28,10 @@ from vllm.v1.worker.gpu_worker import Worker as BaseWorker import modelopt.torch.quantization as mtq +from modelopt.torch.opt.conversion import restore_from_modelopt_state from modelopt.torch.utils.dataset_utils import get_dataset_dataloader -def convert_amax_hf2vllm( - hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False -) -> dict[str, torch.Tensor]: - """ - Convert amax values from HuggingFace format to vLLM format. 
- - This function merges: - - q_proj, k_proj, v_proj amax values into qkv_proj (taking max) - - gate_proj, up_proj amax values into gate_up_proj (taking max) - - Args: - hf_state_dict: HuggingFace state dict containing amax values - - Returns: - vLLM format state dict with merged amax values - """ - vllm_state_dict = {} - - # Group keys by their base pattern (without the specific projection name) - merge_groups = defaultdict(list) - - for key, value in hf_state_dict.items(): - if "_amax" not in key: - # Copy non-amax keys as-is - vllm_state_dict[key] = value - continue - - # Check if this is a q/k/v projection that needs merging - qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key) - if qkv_match: - base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is an expert gate/up projection - # Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and - # model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax - # Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax - expert_gate_up_match = ( - "mixer" not in key - and fuse_experts - and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key) - ) - if expert_gate_up_match: - base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is a non-expert gate/up projection that needs merging - gate_up_match = ( - "mixer" not in key - and "experts" not in key - and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key) - ) - if gate_up_match: - base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is an expert down_proj - # Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax - # Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax - expert_down_match = ( - "mixer" not in key - and fuse_experts - and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key) - ) - if expert_down_match: - base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) - merge_groups[base_pattern].append((key, value)) - continue - - # Copy other amax keys as-is (like o_proj, down_proj) - vllm_state_dict[key] = value - - # Merge grouped amax values by taking the maximum - for merged_key, key_value_pairs in merge_groups.items(): - if len(key_value_pairs) > 1: - # Take the maximum across all values for this merged key - values = [value for _, value in key_value_pairs] - merged_value = torch.stack(values).max(dim=0)[0] - vllm_state_dict[merged_key] = merged_value - print(f"Merged {len(key_value_pairs)} keys into {merged_key}") - for orig_key, _ in key_value_pairs: - print(f" - {orig_key}") - else: - # Single key, just rename it - _, value = key_value_pairs[0] - vllm_state_dict[merged_key] = value - - return vllm_state_dict - - @contextmanager def disable_compilation(model): do_not_compile = True @@ -151,7 +58,9 @@ def disable_compilation(model): "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)), "quant_cfg": os.environ.get("QUANT_CFG", None), "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None), - "amax_file_path": os.environ.get("AMAX_FILE_PATH", None), + "quant_file_path": os.environ.get("QUANT_FILE_PATH", None), + "modelopt_state_path": os.environ.get("MODELOPT_STATE_PATH", None), + "calib_batch_size": int(os.environ.get("CALIB_BATCH_SIZE", 
1)), } @@ -194,137 +103,151 @@ def _fakequant_run_prolog_worker(self) -> None: if tokenizer.pad_token != "" or tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if quant_config["amax_file_path"]: - print("Will load amax, so only do a single sample calibration") - quant_config["calib_size"] = 1 - - calib_dataloader = get_dataset_dataloader( - dataset_name=quant_config["dataset"], - tokenizer=tokenizer, - batch_size=1, - num_samples=quant_config["calib_size"], - device=self.device, - ) - - def calibrate_loop(model: Any = None) -> None: - for batch_idx, batch in tqdm(enumerate(calib_dataloader)): - input_ids = batch["input_ids"][0] - - # Convert tensor to list of integers for vLLM compatibility - if torch.is_tensor(input_ids): - input_ids_list = input_ids.cpu().tolist() - else: - input_ids_list = list(input_ids) - - num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) - empty_block_ids = tuple([] for _ in range(num_groups)) - - req_id = f"req-{batch_idx}" - # Pass all possible parameters - the helper will filter based on vLLM version - new_req = _create_new_data_cls( - NewRequestData, - req_id=req_id, - prompt_token_ids=input_ids_list, - # Old API parameters - mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated - mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated - mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated - # New API parameter - mm_features=[], - sampling_params=SamplingParams(max_tokens=1), - pooling_params=None, - block_ids=empty_block_ids, - num_computed_tokens=0, - lora_request=None, - ) - - scheduler_output = _create_new_data_cls( - SchedulerOutput, - scheduled_new_reqs=[new_req], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={req_id: len(input_ids_list)}, - total_num_scheduled_tokens=len(input_ids_list), - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[0] * num_groups, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - kv_connector_metadata=None, - # Old API parameters - structured_output_request_ids={}, # TODO: remove this when vllm <= 0.11 is outdated - grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated - ) - output = self.execute_model(scheduler_output) - if hasattr(self, "sample_tokens"): - if output is None: # TODO: make this default when vllm <= 0.11 is outdated - self.sample_tokens(None) - - quant_cfg = {} if quant_config["quant_cfg"] is None else getattr(mtq, quant_config["quant_cfg"]) - quant_kv_cfg = ( - {} if quant_config["kv_quant_cfg"] is None else getattr(mtq, quant_config["kv_quant_cfg"]) - ) - model = self.model_runner.model - if hasattr(model, "unwrap"): - model = model.unwrap() + if quant_config["modelopt_state_path"]: + print(f"Loading modelopt state from {quant_config['modelopt_state_path']}") + modelopt_state = torch.load(quant_config["modelopt_state_path"], weights_only=False) + modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) + modelopt_state = convert_modelopt_state_to_vllm(modelopt_state) + restore_from_modelopt_state(model, modelopt_state) - # Check if model has MLA and update KV config accordingly - if quant_kv_cfg: - quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) + if modelopt_weights is not None: + modelopt_weights = convert_dict_to_vllm(modelopt_weights) + mtq.utils.set_quantizer_state_dict(model, modelopt_weights) - if quant_kv_cfg: - quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - quant_cfg, 
quant_kv_cfg["quant_cfg"] + else: + if quant_config["quant_file_path"]: + print("Will load quant, so only do a single sample calibration") + quant_config["calib_size"] = 1 + calib_dataloader = get_dataset_dataloader( + dataset_name=quant_config["dataset"], + tokenizer=tokenizer, + batch_size=quant_config["calib_batch_size"], + num_samples=quant_config["calib_size"], + device=self.device, ) - with disable_compilation(model): - print("quantizing model...") - mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - - amax_file_path = quant_config["amax_file_path"] - if amax_file_path: - print(f"Loading amax values from {amax_file_path}") - saved_amax_dict = torch.load(amax_file_path) - # convert amax keys to vLLM format - if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): - saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict) - saved_amax_dict = { - key.replace("quantizer_amax", "quantizer._amax"): value - for key, value in saved_amax_dict.items() - if key.endswith("quantizer_amax") - } - saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True) - - current_state_dict = model.state_dict() - # Count amax keys in checkpoint and model - checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")] - model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")] - for key in checkpoint_amax_keys: - if key not in model_amax_keys: - print(f"Key {key} not found in model state dict, but exists in checkpoint") - for key in model_amax_keys: - if key not in checkpoint_amax_keys: - raise ValueError( - f"Key {key} not found in checkpoint state dict, but exists in model" + def calibrate_loop(model: Any = None) -> None: + for batch_idx, batch in tqdm(enumerate(calib_dataloader)): + input_ids = batch["input_ids"][0] + + # Convert tensor to list of integers for vLLM compatibility + if torch.is_tensor(input_ids): + input_ids_list = input_ids.cpu().tolist() + else: + input_ids_list = list(input_ids) + + num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) + empty_block_ids = tuple([] for _ in range(num_groups)) + + req_id = f"req-{batch_idx}" + # Pass all possible parameters - the helper will filter based on vLLM version + new_req = _create_new_data_cls( + NewRequestData, + req_id=req_id, + prompt_token_ids=input_ids_list, + # Old API parameters + mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated + mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated + mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated + # New API parameter + mm_features=[], + sampling_params=SamplingParams(max_tokens=1), + pooling_params=None, + block_ids=empty_block_ids, + num_computed_tokens=0, + lora_request=None, ) - checkpoint_amax_count = len(checkpoint_amax_keys) - model_amax_count = len(model_amax_keys) + scheduler_output = _create_new_data_cls( + SchedulerOutput, + scheduled_new_reqs=[new_req], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={req_id: len(input_ids_list)}, + total_num_scheduled_tokens=len(input_ids_list), + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0] * num_groups, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + kv_connector_metadata=None, + # Old API parameters + structured_output_request_ids={}, # TODO: remove this when vllm <= 0.11 is outdated + grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated + ) + output = self.execute_model(scheduler_output) + if 
hasattr(self, "sample_tokens"): + if output is None: # TODO: make this default when vllm <= 0.11 is outdated + self.sample_tokens(None) + + quant_cfg = getattr(mtq, quant_config["quant_cfg"]) if quant_config["quant_cfg"] else {} + quant_kv_cfg = ( + getattr(mtq, quant_config["kv_quant_cfg"]) if quant_config["kv_quant_cfg"] else {} + ) + + if hasattr(model, "unwrap"): + model = model.unwrap() - # Ensure counts match - if checkpoint_amax_count != model_amax_count: - warnings.warn( - f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} " - f"amax keys but model has {model_amax_count} amax keys. This can happen if the model is using PP." + # Check if model has MLA and update KV config accordingly + if quant_kv_cfg: + quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) + + if quant_kv_cfg: + quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + quant_cfg, quant_kv_cfg["quant_cfg"] ) - # Update amax values - for key, value in saved_amax_dict.items(): - if key in current_state_dict: - current_state_dict[key] = value.to(current_state_dict[key].device) + with disable_compilation(model): + print("quantizing model...") + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + + quantizer_file_path = quant_config["quant_file_path"] + if quantizer_file_path: + print(f"Loading quantizer values from {quantizer_file_path}") + saved_quant_dict = torch.load(quantizer_file_path) + # convert quant keys to vLLM format + if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): + saved_quant_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict( + saved_quant_dict + ) + saved_quant_dict = { + key.replace("quantizer_", "quantizer._"): value + for key, value in saved_quant_dict.items() + if key.endswith("quantizer_") + } + saved_quant_dict = convert_dict_to_vllm(saved_quant_dict) + + current_state_dict = model.state_dict() + # Count quant keys in checkpoint and model + checkpoint_quant_keys = [key for key in saved_quant_dict if "quantizer" in key] + model_quant_keys = [key for key in current_state_dict if "quantizer" in key] + for key in checkpoint_quant_keys: + if key not in model_quant_keys: + print(f"Key {key} not found in model state dict, but exists in checkpoint") + for key in model_quant_keys: + if key not in checkpoint_quant_keys: + raise ValueError( + f"Key {key} not found in checkpoint state dict, but exists in model" + ) + + checkpoint_quant_count = len(checkpoint_quant_keys) + model_quant_count = len(model_quant_keys) + + # Ensure counts match + if checkpoint_quant_count != model_quant_count: + warnings.warn( + f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_quant_count} " + f"quant keys but model has {model_quant_count} quantizer state keys. " + f"This can happen if the model is using PP." 
+ ) + + # Update quant values + for key, value in saved_quant_dict.items(): + if key in current_state_dict: + current_state_dict[key] = value.to(current_state_dict[key].device) - model.load_state_dict(current_state_dict) - torch.distributed.barrier() + model.load_state_dict(current_state_dict) + torch.distributed.barrier() if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: mtq.print_quant_summary(model) @@ -345,6 +268,10 @@ def determine_available_memory(self) -> int: return super().determine_available_memory() def compile_or_warm_up_model(self) -> None: - if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: + if ( + quant_config["quant_cfg"] + or quant_config["kv_quant_cfg"] + or quant_config["modelopt_state_path"] + ): _fakequant_run_prolog_worker(self) super().compile_or_warm_up_model() diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py new file mode 100644 index 000000000..9bf8e6eb4 --- /dev/null +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import torch + +from collections import defaultdict +from typing import Any, Callable + + +def _values_equal(v1: Any, v2: Any) -> bool: + """Compare values, handling dicts with tensors.""" + if isinstance(v1, dict) and isinstance(v2, dict): + if v1.keys() != v2.keys(): + return False + return all( + torch.equal(v1[k], v2[k]) if isinstance(v1[k], torch.Tensor) else v1[k] == v2[k] + for k in v1.keys() + ) + elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): + return torch.equal(v1, v2) + return v1 == v2 + + +def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, Any]: + """ + Transform a single key from HuggingFace format to vLLM format. + + Returns: + Tuple of (action, new_key_or_group, value) where action is one of: + - "copy": Copy value to new_key directly + - "group": Add to merge group identified by new_key + - "skip": Skip this key entirely + """ + if "quantizer" not in key: + return ("copy", key, value) + + # Skip softmax_quantizer (not needed in vLLM) + if "softmax_quantizer" in key: + return ("skip", None, None) + + # Skip lm_head quantizers (not needed in vLLM) + if key.startswith("lm_head.") and "quantizer" in key: + return ("skip", None, None) + + # Check if this is a q/k/v projection that needs merging + qkv_match = re.search(r"(.*\.)([qkv])_proj\.([^.]+_quantizer)(\..+)?$", key) + if qkv_match: + suffix = qkv_match.group(4) or "" + group_key = qkv_match.group(1) + "qkv_proj." 
+ qkv_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is an expert gate/up projection + if "mixer" not in key: + expert_gate_up_match = re.search( + r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key + ) + if expert_gate_up_match: + suffix = expert_gate_up_match.group(4) or "" + group_key = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is a non-expert gate/up projection that needs merging + if "mixer" not in key and "experts" not in key: + gate_up_match = re.search(r"(.*\.)(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key) + if gate_up_match: + suffix = gate_up_match.group(4) or "" + group_key = gate_up_match.group(1) + "gate_up_proj." + gate_up_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is an expert down_proj + if "mixer" not in key: + expert_down_match = re.search( + r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer)(\..+)?$", key + ) + if expert_down_match: + suffix = expert_down_match.group(3) or "" + group_key = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) + suffix + return ("group", group_key, value) + + # Transform bmm_quantizer keys: self_attn.q/k/v_bmm_quantizer -> self_attn.attn.q/k/v_bmm_quantizer + bmm_match = re.search(r"(.*\.self_attn)\.([qkv]_bmm_quantizer.*)$", key) + if bmm_match: + new_key = bmm_match.group(1) + ".attn." + bmm_match.group(2) + # Debug: show device of amax values + if isinstance(value, dict): + for k, v in value.items(): + if isinstance(v, torch.Tensor): + print(f"Renamed {key} -> {new_key}, {k} device: {v.device}") + elif isinstance(value, torch.Tensor): + print(f"Renamed {key} -> {new_key}, device: {value.device}") + else: + print(f"Renamed {key} -> {new_key}") + return ("copy", new_key, value) + + # Copy other quantizer keys as-is (like o_proj, down_proj) + return ("copy", key, value) + + +def _group_keys_for_vllm( + state_dict: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, list[tuple[str, Any]]]]: + """ + Process state dict and group keys that need merging. + + Returns: + Tuple of (direct_copy_dict, merge_groups) + """ + vllm_state_dict = {} + merge_groups = defaultdict(list) + + for key, value in state_dict.items(): + action, new_key, new_value = _convert_key_for_vllm(key, value) + + if action == "copy": + vllm_state_dict[new_key] = new_value + elif action == "group": + merge_groups[new_key].append((key, new_value)) + # action == "skip" does nothing + + return vllm_state_dict, merge_groups + + +def _merge_values_by_max_or_concat( + merged_key: str, key_value_pairs: list[tuple[str, Any]] +) -> Any: + """ + Merge values by taking max for amax, concatenating for others. + Used for quantizer state weights (tensor values). 
+ """ + values = [value for _, value in key_value_pairs] + + # Check if values are dicts (OrderedDict) containing tensors + if isinstance(values[0], dict): + merged_value = {} + for dict_key in values[0].keys(): + tensors = [v[dict_key] for v in values] + if "_amax" in dict_key: + merged_value[dict_key] = torch.stack(tensors).max(dim=0)[0] + else: + merged_value[dict_key] = torch.cat(tensors, dim=0) + return merged_value + else: + # Values are tensors directly + if "_amax" in merged_key: + merged_value = torch.stack(values).max(dim=0)[0] + else: + merged_value = torch.cat(values, dim=0) + return merged_value + + +def _merge_values_require_identical( + merged_key: str, key_value_pairs: list[tuple[str, Any]] +) -> Any: + """ + Merge values by requiring all values to be identical. + Used for quantizer state (config/metadata). + """ + keys = [k for k, _ in key_value_pairs] + values = [v for _, v in key_value_pairs] + first_value = values[0] + + for i, val in enumerate(values[1:], start=1): + if not _values_equal(val, first_value): + raise ValueError( + f"Cannot merge keys into '{merged_key}': values differ.\n" + f" '{keys[0]}' has value: {first_value}\n" + f" '{keys[i]}' has value: {val}" + ) + return first_value + + +def convert_dict_to_vllm( + state_dict: dict[str, Any], + merge_mode: str = "max_or_concat" +) -> dict[str, Any]: + """ + Common implementation for converting quantizer state from HF to vLLM format. + + Args: + state_dict: Input state dict + fuse_experts: Whether to fuse expert projections + merge_mode: Mode to merge grouped values, "max_or_concat" or "require_identical" + """ + vllm_state_dict, merge_groups = _group_keys_for_vllm(state_dict) + + merge_fn = _merge_values_require_identical if merge_mode == "require_identical" else _merge_values_by_max_or_concat + + # Merge grouped values + for merged_key, key_value_pairs in merge_groups.items(): + if len(key_value_pairs) > 1: + merged_value = merge_fn(merged_key, key_value_pairs) + vllm_state_dict[merged_key] = merged_value + else: + # Single key, just rename it + _, value = key_value_pairs[0] + vllm_state_dict[merged_key] = value + + return vllm_state_dict + + +def convert_modelopt_state_to_vllm(modelopt_state: dict[str, Any]) -> dict[str, Any]: + """ + Convert modelopt state from HuggingFace format to vLLM compatible format. + + This function converts the quantizer state from HuggingFace format to vLLM compatible format. 
+ + Args: + modelopt_state: HuggingFace modelopt state dict + + Returns: + vLLM compatible modelopt state dict + """ + modelopt_state_dict = modelopt_state.pop("modelopt_state_dict", []) + for idx, current_mode in enumerate(modelopt_state_dict): + current_mode_metadata = current_mode[1].pop("metadata", {}) + current_mode_quant_state = current_mode_metadata.pop("quantizer_state", {}) + if current_mode_quant_state: + current_mode_metadata["quantizer_state"] = convert_dict_to_vllm(current_mode_quant_state, merge_mode="require_identical") + else: + current_mode_metadata.pop("quantizer_state", None) + current_mode[1]['metadata'] = current_mode_metadata + modelopt_state_dict[idx] = (current_mode[0], current_mode[1]) + modelopt_state["modelopt_state_dict"] = modelopt_state_dict + return modelopt_state diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index 25483f2be..c32593005 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -74,8 +74,10 @@ "QUANT_DATASET", "QUANT_CALIB_SIZE", "QUANT_CFG", - "AMAX_FILE_PATH", + "QUANT_FILE_PATH", "KV_QUANT_CFG", + "MODELOPT_STATE_PATH", + "CALIB_BATCH_SIZE", } RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 54987b40c..03b191346 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -19,7 +19,8 @@ import torch import torch.nn as nn -from modelopt.torch.export.layer_utils import is_quantlinear +import modelopt.torch.opt as mto +from modelopt.torch.export.layer_utils import is_attention, is_quantlinear from modelopt.torch.quantization.utils import get_quantizer_state_dict __all__ = ["export_hf_vllm_fq_checkpoint"] @@ -44,12 +45,11 @@ def export_hf_vllm_fq_checkpoint( export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) - amax_dict = { - name + "._amax": param["_amax"].detach().clone().cpu() - for name, param in get_quantizer_state_dict(model).items() - if "_amax" in param - } + quantizer_state_dict = get_quantizer_state_dict(model) + modelopt_state = mto.modelopt_state(model) + modelopt_state["modelopt_state_weights"] = quantizer_state_dict + torch.save(modelopt_state, f"{export_dir}/vllm_fq_modelopt_state.pth") # remove quantizer from model for _, module in model.named_modules(): if is_quantlinear(module): @@ -57,6 +57,15 @@ def export_hf_vllm_fq_checkpoint( if hasattr(module, attr): delattr(module, attr) module.export() - torch.save(amax_dict, f"{export_dir}/quant_amax.pth") + if is_attention(module): + for attr in [ + "q_bmm_quantizer", + "k_bmm_quantizer", + "v_bmm_quantizer", + "softmax_quantizer", + ]: + if hasattr(module, attr): + delattr(module, attr) + # Save model model.save_pretrained(export_dir, state_dict=model.state_dict(), save_modelopt_state=False) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 95b194c3f..0549e2e9d 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -22,6 +22,7 @@ from modelopt.torch.export.model_config import QUANTIZATION_NONE from modelopt.torch.export.unified_export_megatron import GPTModelExporter +from modelopt.torch.quantization.utils import get_quantizer_state_dict __all__ = ["export_mcore_gpt_to_hf_vllm_fq"] @@ 
-38,8 +39,8 @@ def gather_mcore_vllm_fq_quantized_state_dict( Returns: The state dictionary of the module without quantized state. """ - amax_state_dict = { - k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax") + quantizer_state_dict = { + k: v.detach().clone().cpu() for k, v in state_dict.items() if "quantizer" in k } # Gather all amax dicts to rank 0 @@ -48,20 +49,19 @@ def gather_mcore_vllm_fq_quantized_state_dict( if rank == 0: # Rank 0 will collect all amax values - all_amax_dicts = [None] * world_size - torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0) + all_quantizer_state_dicts = [None] * world_size + torch.distributed.gather_object(quantizer_state_dict, all_quantizer_state_dicts, dst=0) - # Merge all amax dicts into one - merged_amax_dict = {} - for amax_dict in all_amax_dicts: - if amax_dict is not None: - merged_amax_dict.update(amax_dict) + # Merge all quantizer state dicts into one + merged_quantizer_state_dict = {} + for quantizer_state_dict in all_quantizer_state_dicts: + if quantizer_state_dict is not None: + merged_quantizer_state_dict.update(quantizer_state_dict) - print(f"Total amax entries from all ranks: {len(merged_amax_dict.keys())}") - torch.save(merged_amax_dict, save_directory + "/quant_amax.pth") + torch.save(merged_quantizer_state_dict, save_directory + "/quantizer_state.pth") else: # Other ranks just send their amax values - torch.distributed.gather_object(amax_state_dict, None, dst=0) + torch.distributed.gather_object(quantizer_state_dict, None, dst=0) torch.distributed.barrier() @@ -76,6 +76,13 @@ def save_pretrained( ): os.makedirs(save_directory, exist_ok=True) gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory) + + # NOTE: `self.state_dict` is an OrderedDict; mutating it while iterating + # over its keys raises "OrderedDict mutated during iteration". + keys_to_remove = [k for k in self.state_dict if "quantizer" in k] + for k in keys_to_remove: + self.state_dict.pop(k, None) + assert not (self.is_multimodal and pretrained_model_name_or_path is not None), ( "Exporting weights in bf16 and amax values is not supported for multimodal models " "when pretrained_model_name_or_path is not None" @@ -88,6 +95,37 @@ def save_pretrained( def _get_quantization_format(self, module: torch.nn.Module): return QUANTIZATION_NONE + def _get_quantized_state( + self, + module: torch.nn.Module, + dtype: torch.dtype = torch.float16, + ) -> tuple[dict[str, torch.Tensor], str, int]: + """Return a state_dict, quantization format, and block_size of the module. + + Args: + module: The target module to perform real quantization. + dtype: The default data type. + + Returns: + Tuple: state_dict, quantization format, and block_size of the module. + """ + name_to_value = {} + qformat: str = self._get_quantization_format(module) + block_size = 0 + + if hasattr(module, "weight") and module.weight is not None: + weight = module.weight.to(dtype).cpu() + name_to_value["weight"] = weight + else: + return name_to_value, qformat, block_size + + if hasattr(module, "bias") and module.bias is not None: + name_to_value["bias"] = module.bias.to(dtype).cpu() + for name, param in get_quantizer_state_dict(module).items(): + for key, value in param.items(): + name_to_value[name + "." 
+ key] = value.to(dtype).cpu() + return name_to_value, qformat, block_size + def export_mcore_gpt_to_hf_vllm_fq( model: torch.nn.Module, diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index c93ea546f..fce232eb6 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -130,7 +130,7 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: for name, module in model.named_modules(): if isinstance(module, QuantModule): name = get_unwrapped_name(name, model) - module.modelopt_post_restore(name) + module.modelopt_post_restore(name, model=model) return model diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 12aaee3f8..d1756dc27 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -37,7 +37,7 @@ class QuantModule(DynamicModule): """A base class for quantized modules.""" - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", model: "torch.nn.Module | None" = None): """Post-restore to correctly configure the TensorQuantizer states. TensorQuantizer states are restored to their shape before saving. Now we need to further configure them. @@ -46,6 +46,10 @@ def modelopt_post_restore(self, prefix: str = ""): 2. For sharded modules the restored states of TensorQuantizer could be incorrect. This is because parallelism such as TP might have been changed between saving and resoring. So we need to re-calculate the state shapes. Hence such modules should override this and implement their own logic. + + Args: + prefix: The module name prefix for error messages. + model: Optional main model to search for device if not found in this module. """ # Get a parameter or buffer that does not belong to a TensorQuantizer non_tq_param_or_buffer = None @@ -55,6 +59,18 @@ def modelopt_post_restore(self, prefix: str = ""): non_tq_param_or_buffer = param_or_buffer break + # If not found (e.g., container modules like vLLM's attn that only have child quantizers), + # traverse up to parent's parent to find a module with parameters + if model is not None: + parts = prefix.split(".") + parent_prefix = ".".join(parts[: len(parts) - 1]) + parent_module = model.get_submodule(parent_prefix) + # Look for any parameter in parent module (not just state_dict) + for param in parent_module.parameters(): + # Skip if param belongs to a TensorQuantizer + non_tq_param_or_buffer = param + break + if non_tq_param_or_buffer is None: warnings.warn( f"Could not identify the device for TensorQuantizer states of {prefix}. " From 00444292adbccbacdd5555f910cb2b64e4e4197f Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Wed, 21 Jan 2026 23:39:44 +0000 Subject: [PATCH 02/13] changelog update Signed-off-by: Kinjal Patel --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e615627b2..10575a18e 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint. - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md `_ for more details on its usage. 
- Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow. +- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ From 2130d373ad67072dc01141642ad69803c31d442f Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 22 Jan 2026 01:16:37 +0000 Subject: [PATCH 03/13] minor Signed-off-by: Kinjal Patel --- examples/vllm_serve/README.md | 2 + examples/vllm_serve/fakequant_worker.py | 2 +- examples/vllm_serve/vllm_reload_utils.py | 57 +++++++++++++----------- 3 files changed, 33 insertions(+), 28 deletions(-) diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 64310fef4..74a1f2510 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -74,6 +74,7 @@ Step 1: export the model with bf16 weights and amax values. To export the model: export_dir, # The directory where the exported files will be stored. ) ``` + Or run the example script `examples/llm_ptq/hf_ptq.py` with the `--export_vllm_fq` **flag** to export a vLLM-fakequant-compatible ModelOpt state (it generates `vllm_fq_modelopt_state.pth`, which you can use via `MODELOPT_STATE_PATH`). - For **MCore** models, use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq`: @@ -87,6 +88,7 @@ Step 1: export the model with bf16 weights and amax values. To export the model: ) ``` + This generates `quantizer_state.pth`, which contains quantizer tensors for vLLM reload via `QUANT_FILE_PATH`. Step 2: use the exported artifacts when serving: diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 81ea0379b..a5f1b0332 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -20,12 +20,12 @@ from typing import Any import torch -from vllm_reload_utils import convert_dict_to_vllm, convert_modelopt_state_to_vllm from tqdm import tqdm from transformers import AutoTokenizer from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.gpu_worker import Worker as BaseWorker +from vllm_reload_utils import convert_dict_to_vllm, convert_modelopt_state_to_vllm import modelopt.torch.quantization as mtq from modelopt.torch.opt.conversion import restore_from_modelopt_state diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py index 9bf8e6eb4..bbd77be5b 100644 --- a/examples/vllm_serve/vllm_reload_utils.py +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -14,10 +14,10 @@ # limitations under the License. import re -import torch - from collections import defaultdict -from typing import Any, Callable +from typing import Any + +import torch def _values_equal(v1: Any, v2: Any) -> bool: @@ -27,14 +27,14 @@ def _values_equal(v1: Any, v2: Any) -> bool: return False return all( torch.equal(v1[k], v2[k]) if isinstance(v1[k], torch.Tensor) else v1[k] == v2[k] - for k in v1.keys() + for k in v1 ) elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): return torch.equal(v1, v2) return v1 == v2 -def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, Any]: +def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: """ Transform a single key from HuggingFace format to vLLM format. 
@@ -47,12 +47,8 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, if "quantizer" not in key: return ("copy", key, value) - # Skip softmax_quantizer (not needed in vLLM) - if "softmax_quantizer" in key: - return ("skip", None, None) - - # Skip lm_head quantizers (not needed in vLLM) - if key.startswith("lm_head.") and "quantizer" in key: + # Skip softmax_quantizer and lm_head quantizers(not needed in vLLM) + if "softmax_quantizer" in key or (key.startswith("lm_head.") and "quantizer" in key): return ("skip", None, None) # Check if this is a q/k/v projection that needs merging @@ -69,7 +65,9 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, ) if expert_gate_up_match: suffix = expert_gate_up_match.group(4) or "" - group_key = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + suffix + group_key = ( + expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + suffix + ) return ("group", group_key, value) # Check if this is a non-expert gate/up projection that needs merging @@ -110,8 +108,8 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str | None, str | None, def _group_keys_for_vllm( - state_dict: dict[str, Any] -) -> tuple[dict[str, Any], dict[str, list[tuple[str, Any]]]]: + state_dict: dict[str, Any], +) -> tuple[dict[str, Any], defaultdict[str, list[tuple[str, Any]]]]: """ Process state dict and group keys that need merging. @@ -123,7 +121,11 @@ def _group_keys_for_vllm( for key, value in state_dict.items(): action, new_key, new_value = _convert_key_for_vllm(key, value) - + if new_key is None or new_value is None: + assert action == "skip", ( + f"Expected action to be 'skip' for key {key}, value {value}, got {action}" + ) + continue if action == "copy": vllm_state_dict[new_key] = new_value elif action == "group": @@ -133,9 +135,7 @@ def _group_keys_for_vllm( return vllm_state_dict, merge_groups -def _merge_values_by_max_or_concat( - merged_key: str, key_value_pairs: list[tuple[str, Any]] -) -> Any: +def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any: """ Merge values by taking max for amax, concatenating for others. Used for quantizer state weights (tensor values). @@ -145,7 +145,7 @@ def _merge_values_by_max_or_concat( # Check if values are dicts (OrderedDict) containing tensors if isinstance(values[0], dict): merged_value = {} - for dict_key in values[0].keys(): + for dict_key in values[0]: tensors = [v[dict_key] for v in values] if "_amax" in dict_key: merged_value[dict_key] = torch.stack(tensors).max(dim=0)[0] @@ -161,9 +161,7 @@ def _merge_values_by_max_or_concat( return merged_value -def _merge_values_require_identical( - merged_key: str, key_value_pairs: list[tuple[str, Any]] -) -> Any: +def _merge_values_require_identical(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any: """ Merge values by requiring all values to be identical. Used for quantizer state (config/metadata). @@ -183,8 +181,7 @@ def _merge_values_require_identical( def convert_dict_to_vllm( - state_dict: dict[str, Any], - merge_mode: str = "max_or_concat" + state_dict: dict[str, Any], merge_mode: str = "max_or_concat" ) -> dict[str, Any]: """ Common implementation for converting quantizer state from HF to vLLM format. 
@@ -196,7 +193,11 @@ def convert_dict_to_vllm( """ vllm_state_dict, merge_groups = _group_keys_for_vllm(state_dict) - merge_fn = _merge_values_require_identical if merge_mode == "require_identical" else _merge_values_by_max_or_concat + merge_fn = ( + _merge_values_require_identical + if merge_mode == "require_identical" + else _merge_values_by_max_or_concat + ) # Merge grouped values for merged_key, key_value_pairs in merge_groups.items(): @@ -228,10 +229,12 @@ def convert_modelopt_state_to_vllm(modelopt_state: dict[str, Any]) -> dict[str, current_mode_metadata = current_mode[1].pop("metadata", {}) current_mode_quant_state = current_mode_metadata.pop("quantizer_state", {}) if current_mode_quant_state: - current_mode_metadata["quantizer_state"] = convert_dict_to_vllm(current_mode_quant_state, merge_mode="require_identical") + current_mode_metadata["quantizer_state"] = convert_dict_to_vllm( + current_mode_quant_state, merge_mode="require_identical" + ) else: current_mode_metadata.pop("quantizer_state", None) - current_mode[1]['metadata'] = current_mode_metadata + current_mode[1]["metadata"] = current_mode_metadata modelopt_state_dict[idx] = (current_mode[0], current_mode[1]) modelopt_state["modelopt_state_dict"] = modelopt_state_dict return modelopt_state From eabd8ca5a81e79c8a6e97e834ea13fc3a086aec3 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 22 Jan 2026 21:58:00 +0000 Subject: [PATCH 04/13] updated for TP>1 Signed-off-by: Kinjal Patel --- examples/vllm_serve/README.md | 4 +- examples/vllm_serve/fakequant_worker.py | 23 +++++++++-- examples/vllm_serve/vllm_reload_utils.py | 50 +++++++++++++++++++----- 3 files changed, 63 insertions(+), 14 deletions(-) diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index 74a1f2510..60002b747 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -26,7 +26,7 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`, | QUANT_CFG | Quantization config | None | | KV_QUANT_CFG | KV-cache quantization config | None | | QUANT_FILE_PATH | Optional path to exported quantizer state dict `quantizer_state.pth` | None | -| MODELOPT_STATE_PATH | Optional path to exported `modelopt_state.pth` (restores ModelOpt mode + weights) | None | +| MODELOPT_STATE_PATH | Optional path to exported `vllm_fq_modelopt_state.pth` (restores quantizer state and parameters) | None | | CALIB_BATCH_SIZE | Calibration batch size | 1 | Set these variables in your shell or Docker environment as needed to customize calibration. @@ -110,3 +110,5 @@ QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fa ## Known Problems 1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won’t align). +2. AWQ reload is not supported yet +3. KV cache quantization export and reload is not supported in MCore yet. 
diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index a5f1b0332..fe822c877 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -25,7 +25,11 @@ from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.gpu_worker import Worker as BaseWorker -from vllm_reload_utils import convert_dict_to_vllm, convert_modelopt_state_to_vllm +from vllm_reload_utils import ( + convert_dict_to_vllm, + convert_modelopt_state_to_vllm, + process_state_dict_for_tp, +) import modelopt.torch.quantization as mtq from modelopt.torch.opt.conversion import restore_from_modelopt_state @@ -106,7 +110,11 @@ def _fakequant_run_prolog_worker(self) -> None: model = self.model_runner.model if quant_config["modelopt_state_path"]: print(f"Loading modelopt state from {quant_config['modelopt_state_path']}") - modelopt_state = torch.load(quant_config["modelopt_state_path"], weights_only=False) + # Load on CPU to avoid failures when the checkpoint was saved from a different + # GPU mapping + modelopt_state = torch.load( + quant_config["modelopt_state_path"], weights_only=False, map_location="cpu" + ) modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) modelopt_state = convert_modelopt_state_to_vllm(modelopt_state) restore_from_modelopt_state(model, modelopt_state) @@ -203,8 +211,11 @@ def calibrate_loop(model: Any = None) -> None: quantizer_file_path = quant_config["quant_file_path"] if quantizer_file_path: + self.model_runner._dummy_run(1) print(f"Loading quantizer values from {quantizer_file_path}") - saved_quant_dict = torch.load(quantizer_file_path) + # Load on CPU to avoid failures when the checkpoint was saved from a different + # GPU mapping + saved_quant_dict = torch.load(quantizer_file_path, map_location="cpu") # convert quant keys to vLLM format if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): saved_quant_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict( @@ -242,12 +253,16 @@ def calibrate_loop(model: Any = None) -> None: ) # Update quant values + saved_quant_dict = process_state_dict_for_tp(saved_quant_dict, current_state_dict) for key, value in saved_quant_dict.items(): if key in current_state_dict: current_state_dict[key] = value.to(current_state_dict[key].device) model.load_state_dict(current_state_dict) - torch.distributed.barrier() + + # Only barrier if distributed is actually initialized (avoids deadlocks). + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + torch.distributed.barrier() if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: mtq.print_quant_summary(model) diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py index bbd77be5b..a72cb0b3c 100644 --- a/examples/vllm_serve/vllm_reload_utils.py +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -18,6 +18,7 @@ from typing import Any import torch +from vllm.distributed.parallel_state import get_tp_group def _values_equal(v1: Any, v2: Any) -> bool: @@ -92,15 +93,6 @@ def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: bmm_match = re.search(r"(.*\.self_attn)\.([qkv]_bmm_quantizer.*)$", key) if bmm_match: new_key = bmm_match.group(1) + ".attn." 
+ bmm_match.group(2) - # Debug: show device of amax values - if isinstance(value, dict): - for k, v in value.items(): - if isinstance(v, torch.Tensor): - print(f"Renamed {key} -> {new_key}, {k} device: {v.device}") - elif isinstance(value, torch.Tensor): - print(f"Renamed {key} -> {new_key}, device: {value.device}") - else: - print(f"Renamed {key} -> {new_key}") return ("copy", new_key, value) # Copy other quantizer keys as-is (like o_proj, down_proj) @@ -238,3 +230,43 @@ def convert_modelopt_state_to_vllm(modelopt_state: dict[str, Any]) -> dict[str, modelopt_state_dict[idx] = (current_mode[0], current_mode[1]) modelopt_state["modelopt_state_dict"] = modelopt_state_dict return modelopt_state + + +def process_state_dict_for_tp(saved_qstate_dict, current_state_dict): + """Shard quantizer tensors for tensor parallelism by matching expected shapes.""" + tp_group = get_tp_group() + tp_rank = tp_group.rank_in_group + tp_world_size = tp_group.world_size + + result = {} + for key, value in saved_qstate_dict.items(): + if key in current_state_dict: + expected_shape = current_state_dict[key].shape + if value.shape != expected_shape: + # Find the dimension that was tensor-parallel sharded. + # We expect exactly one dimension to satisfy: + # checkpoint_dim == expected_dim * tp_world_size + shard_dims = [ + d + for d in range(len(expected_shape)) + if value.shape[d] == expected_shape[d] * tp_world_size + ] + if len(shard_dims) != 1: + raise ValueError( + f"Cannot infer TP shard dim for {key}: " + f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value.shape)}, " + ) + + shard_dim = shard_dims[0] + shard_size = expected_shape[shard_dim] + start = tp_rank * shard_size + end = start + shard_size + if end > value.shape[shard_dim]: + raise ValueError( + f"TP shard out of bounds for {key}: " + f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value.shape)})" + ) + value = value.narrow(shard_dim, start, shard_size).contiguous() + result[key] = value + + return result From 45e637a26a7f02ec72b50ed7986f5b764fe13cb6 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 22 Jan 2026 23:33:06 +0000 Subject: [PATCH 05/13] minor Signed-off-by: Kinjal Patel --- .../torch/quantization/nn/modules/quant_module.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index d1756dc27..0037fa062 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -61,15 +61,18 @@ def modelopt_post_restore(self, prefix: str = "", model: "torch.nn.Module | None # If not found (e.g., container modules like vLLM's attn that only have child quantizers), # traverse up to parent's parent to find a module with parameters - if model is not None: + if non_tq_param_or_buffer is None and model is not None: parts = prefix.split(".") parent_prefix = ".".join(parts[: len(parts) - 1]) - parent_module = model.get_submodule(parent_prefix) + parent_module = model.get_submodule(parent_prefix) if parent_prefix else model # Look for any parameter in parent module (not just state_dict) - for param in parent_module.parameters(): - # Skip if param belongs to a TensorQuantizer - non_tq_param_or_buffer = param - break + for name, param in parent_module.named_parameters(): + # Skip params that belong to TensorQuantizer submodules + param_parent_name = name.rsplit(".", 1)[0] if "." 
in name else "" + param_parent = parent_module.get_submodule(param_parent_name) + if not isinstance(param_parent, TensorQuantizer): + non_tq_param_or_buffer = param + break if non_tq_param_or_buffer is None: warnings.warn( From 720192819650fe04a701c71139fb503f5ca4f9ab Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 04:26:57 +0000 Subject: [PATCH 06/13] updated test Signed-off-by: Kinjal Patel --- .../export/test_vllm_fakequant_hf_export.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py index a156ad126..2c2c44ef2 100644 --- a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py +++ b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py @@ -48,7 +48,7 @@ def forward_loop(model): model(input_ids) model = mtq.quantize(model, quant_cfg, forward_loop) - + quantizer_state_dict_before = mtq.utils.get_quantizer_state_dict(model) model_state_dict = deepcopy(model.state_dict()) # Export directory @@ -59,8 +59,10 @@ def forward_loop(model): export_hf_vllm_fq_checkpoint(model, export_dir=export_dir) # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quant_amax.pth" - assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + modelopt_state_file = export_dir / "vllm_fq_modelopt_state.pth" + assert modelopt_state_file.exists(), ( + f"vllm_fq_modelopt_state.pth file should be created in {export_dir}" + ) # make sure hf_quant_config.json file does not exist hf_quant_config_file = export_dir / "hf_quant_config.json" @@ -73,21 +75,19 @@ def forward_loop(model): model_after = model_after.cuda() model_after.eval() model_after_state_dict = model_after.state_dict() - amax_state_dict = {} for key, param in model_state_dict.items(): - if key.endswith("_amax"): - amax_state_dict[key] = param - continue - - assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( - f"Weight mismatch for {key}: " - f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " - f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" - ) - - # Verify amax values are correct - amax_dict = torch.load(quant_amax_file) - assert len(amax_dict) > 0, "amax_dict should not be empty" - assert amax_dict.keys() == amax_state_dict.keys(), ( - "amax keys mismatch between before and after export" + if "quantizer" not in key: + assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( + f"Weight mismatch for {key}: " + f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " + f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" + ) + + # Verify quantizer state dict values are correct + quantizer_state_dict = torch.load(modelopt_state_file)["modelopt_state_weights"] + assert len(quantizer_state_dict) > 0, ( + f"modelopt_state_weights should not be empty in {modelopt_state_file}" + ) + assert quantizer_state_dict.keys() == quantizer_state_dict_before.keys(), ( + "quantizer state dict keys mismatch between before and after export" ) From 365cbfa6618c25baf8d8830e852d75c6985d9a0c Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 15:41:56 +0000 Subject: [PATCH 07/13] test fix Signed-off-by: Kinjal Patel --- tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py 
b/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py index ea351db6a..2fff24ea1 100644 --- a/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py +++ b/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py @@ -99,8 +99,8 @@ def forward_loop(model): ) # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quant_amax.pth" - assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + quant_amax_file = export_dir / "quantizer_state.pth" + assert quant_amax_file.exists(), f"quantizer_state.pth file should be created in {export_dir}" # make sure hf_quant_config.json file does not exist hf_quant_config_file = export_dir / "hf_quant_config.json" From e03ff03787b96c07e00ee903dc36704cab03aee1 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 18:10:03 +0000 Subject: [PATCH 08/13] minor Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/plugins/custom.py | 2 +- modelopt/torch/quantization/plugins/megatron.py | 4 ++-- modelopt/torch/quantization/plugins/transformer_engine.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py index 4200aadc7..09a91796d 100644 --- a/modelopt/torch/quantization/plugins/custom.py +++ b/modelopt/torch/quantization/plugins/custom.py @@ -114,7 +114,7 @@ def _setup(self): # the dtype can change later. self.original_weight_dtype = None if self.weight is None else self.weight.dtype - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index f47451eb0..630b6e4ce 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -471,7 +471,7 @@ def _get_shard_axis_dict(self, state_dict): shard_axis_dict[k] = self._scale_tensor_shard_axis return shard_axis_dict - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): """Post restore to correctly configure the realquant scales. ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their @@ -715,7 +715,7 @@ def forward(self, query, key, value, *args, **kwargs): value = self.v_bmm_quantizer(value) return super().forward(query, key, value, *args, **kwargs) - def modelopt_post_restore(self, name=""): + def modelopt_post_restore(self, name="", *args, **kwargs): """Restore quantizer states after model loading.""" for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]: # TODO: Add support for non-scalar states such as diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index cec7ff956..70e6e7e2c 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -141,7 +141,7 @@ def _setup(self): # TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization # with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm. 
- def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning # self.weight0 to self.weight to run the quantizer states initialization. From 5c97d37584176bef82386e66556683ed19bbf9ec Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 21:35:33 +0000 Subject: [PATCH 09/13] created seperate script for vllm fq export Signed-off-by: Kinjal Patel --- examples/vllm_serve/hf_ptq_export.py | 314 ++++++++++++++++ examples/vllm_serve/vllm_fq_export.py | 337 ++++++++++++++++++ .../torch/export/plugins/vllm_fakequant_hf.py | 25 +- 3 files changed, 674 insertions(+), 2 deletions(-) create mode 100644 examples/vllm_serve/hf_ptq_export.py create mode 100644 examples/vllm_serve/vllm_fq_export.py diff --git a/examples/vllm_serve/hf_ptq_export.py b/examples/vllm_serve/hf_ptq_export.py new file mode 100644 index 000000000..7ee5c091d --- /dev/null +++ b/examples/vllm_serve/hf_ptq_export.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import random +import warnings + +import numpy as np +import torch +import transformers +from accelerate import infer_auto_device_map, init_empty_weights +from accelerate.utils import get_max_memory +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from modelopt.torch.export import export_hf_vllm_fq_checkpoint +from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.utils.dataset_utils import ( + create_forward_loop, + get_dataset_dataloader, + get_max_batch_size, + get_supported_datasets, +) +from modelopt.torch.utils.memory_monitor import launch_memory_monitor + +RAND_SEED = 1234 + +mto.enable_huggingface_checkpointing() + + +def load_model( + ckpt_path, + device="cuda", + gpu_mem_percentage=0.8, + trust_remote_code=False, + use_seq_device_map=False, +): + print(f"Initializing model from {ckpt_path}") + + config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} + try: + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + except Exception as e: + raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e + + # Pick the transformers model class to load. + architecture = hf_config.architectures[0] + use_auto_causallm = (not hasattr(transformers, architecture)) or ("Deepseek" in architecture) + if use_auto_causallm: + if not hasattr(transformers, architecture): + warnings.warn( + f"Architecture {architecture} not found in transformers: {transformers.__version__}. " + "Falling back to AutoModelForCausalLM." 
+ ) + assert trust_remote_code, ( + "Please set trust_remote_code=True if you want to use this architecture" + ) + model_cls = AutoModelForCausalLM + from_config = model_cls.from_config + else: + model_cls = getattr(transformers, architecture) + from_config = model_cls._from_config + + # Decide device_map and optional memory cap. + if device == "cpu": + device_map = "cpu" + elif use_seq_device_map: + device_map = "sequential" + else: + device_map = "auto" + + model_kwargs: dict[str, object] = dict(config_kwargs) + if device_map == "sequential": + max_memory = get_max_memory() + model_kwargs["max_memory"] = {k: v * gpu_mem_percentage for k, v in max_memory.items()} + + # Detect if the model would offload to CPU; if so, cap GPU memory for calibration. + with init_empty_weights(): + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) + empty_kwargs: dict[str, object] = dict(model_kwargs, torch_dtype=torch_dtype) + empty_kwargs.pop("max_memory", None) # only used by from_pretrained dispatch + if model_cls is not AutoModelForCausalLM: + empty_kwargs.pop("trust_remote_code", None) + empty_model = from_config(hf_config, **empty_kwargs) + + max_memory = get_max_memory() + inferred_device_map = infer_auto_device_map(empty_model, max_memory=max_memory) + if "cpu" in inferred_device_map.values(): + for dev_id, mem in list(max_memory.items()): + if isinstance(dev_id, int): + max_memory[dev_id] = mem * gpu_mem_percentage + print( + "Model does not fit to the GPU mem. " + f"We apply the following memory limit for calibration: \n{max_memory}\n" + "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " + "reduce the calibration `batch_size` manually." + ) + model_kwargs["max_memory"] = max_memory + + model = model_cls.from_pretrained(ckpt_path, device_map=device_map, **model_kwargs) + model.eval() + + # If device_map was disabled (None), manually move model to target device + if device_map is None and device != "cpu": + print(f"Moving model to {device} device...") + model = model.to(device) + + if device == "cuda" and not is_model_on_gpu(model): + print("Warning: Some parameters are not on a GPU. 
Calibration can be slow or hit OOM") + + return model + + +def is_model_on_gpu(model) -> bool: + """Returns if the model is fully loaded on GPUs.""" + return all("cuda" in str(param.device) for param in model.parameters()) + +def get_tokenizer(ckpt_path, trust_remote_code=False): + """Returns the tokenizer from the model ckpt_path.""" + print(f"Initializing tokenizer from {ckpt_path}") + tokenizer = AutoTokenizer.from_pretrained( + ckpt_path, + padding_side="left", + trust_remote_code=trust_remote_code, + ) + + # can't set attribute 'pad_token' for "" + if tokenizer.pad_token != "": + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer + +def quantize_and_export_model( + args: argparse.Namespace, +): + model = load_model( + args.pyt_ckpt_path, + device=args.device, + gpu_mem_percentage=args.gpu_max_mem_percentage, + trust_remote_code=args.trust_remote_code, + use_seq_device_map=args.use_seq_device_map, + ) + + if args.batch_size == 0: + args.batch_size = get_max_batch_size( + model, + max_sample_length=args.calib_seq, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + device = model.device + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + device=device, + include_labels=False, + ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + mtq_cfg = getattr(mtq, args.quant_cfg) + if args.kv_cache_quant_cfg is not None: + kv_cache_quant_cfg = getattr(mtq, args.kv_cache_quant_cfg) + mtq_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + mtq_cfg, kv_cache_quant_cfg["quant_cfg"] + ) + input_ids = next(iter(calib_dataloader))["input_ids"][0:1] + model_is_already_quantized = is_quantized(model) + if not model_is_already_quantized: + generated_str_before_ptq = tokenizer.decode(model.generate(input_ids)[0]) + quantized_model = mtq.quantize(model, mtq_cfg, calibrate_loop) + generated_str_after_ptq = tokenizer.decode(model.generate(input_ids)[0]) + else: + print("Model is already quantized, Skipping quantization...") + quantized_model = model + + mtq.print_quant_summary(quantized_model) + if not model_is_already_quantized: + print("--------") + print(f"example test input: {tokenizer.decode(input_ids[0])}") + print("--------") + print(f"example outputs before ptq: {generated_str_before_ptq}") + print("--------") + print(f"example outputs after ptq: {generated_str_after_ptq}") + + export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) + # from modelopt.torch.quantization.utils import get_quantizer_state_dict + # quantized_model.save_pretrained(args.export_path, state_dict=quantized_model.state_dict(), save_modelopt_state=False) + # modelopt_state = mto.modelopt_state(quantized_model) + # modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(quantized_model) + # torch.save(modelopt_state, f"{args.export_path}/modelopt_state.pth") + tokenizer.save_pretrained(args.export_path) + print(f"Model exported to {args.export_path}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--pyt_ckpt_path", + help="Specify where the PyTorch checkpoint path is", + required=True, + ) + parser.add_argument("--device", default="cuda") + parser.add_argument( + "--quant_cfg", + help="Quantization configuration.", + 
default="FP8_DEFAULT_CFG", + ) + parser.add_argument( + "--batch_size", + help="Batch size for calibration. Default to 0 as we calculate max batch size on-the-fly", + type=int, + default=0, + ) + parser.add_argument( + "--calib_size", + help=( + "Number of samples for calibration. If a comma separated list of values is provided, " + "each value will be used as the calibration size for the corresponding dataset. " + "This argument will be parsed and converted as a list of ints." + ), + type=str, + default="512", + ) + parser.add_argument( + "--calib_seq", + help="Maximum sequence length for calibration.", + type=int, + default=512, + ) + parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--dataset", + help=( + f"name of a dataset, or a comma separated list of datasets. " + f"dataset choices are {get_supported_datasets()}" + ), + type=str, + default=None, + ) + parser.add_argument( + "--kv_cache_quant_cfg", + required=False, + default=None, + help="Specify KV cache quantization configuration, default to None if not provided", + ) + parser.add_argument( + "--trust_remote_code", + help="Set trust_remote_code for Huggingface models and tokenizers", + default=False, + action="store_true", + ) + parser.add_argument( + "--gpu_max_mem_percentage", + help=( + "Specify the percentage of available GPU memory to use for loading the model when " + "device_map is set to sequential. " + "By default, 80%% of the available GPU memory is used." + ), + type=float, + default=0.8, + ) + parser.add_argument( + "--use_seq_device_map", + help=( + "Use device_map=sequential to load the model onto GPUs. This ensures the model is loaded " + "utilizing the percentage of available GPU memory as specified by the value passed with gpu_max_mem flag." + "Helpful in cases where device_map=auto loads model unevenly on GPUs causing GPU OOM during quantization." + ), + default=False, + action="store_true", + ) + + return parser.parse_args() + + +def main(args: argparse.Namespace): + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + # launch a memory monitor to read the currently used GPU memory. + launch_memory_monitor() + + # Force eager execution for all model types. + torch.compiler.set_stance("force_eager") + + # Quantize + quantize_and_export_model(args) + + +if __name__ == "__main__": + args = parse_args() + + args.dataset = args.dataset.split(",") if args.dataset else None + args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] + main(args) diff --git a/examples/vllm_serve/vllm_fq_export.py b/examples/vllm_serve/vllm_fq_export.py new file mode 100644 index 000000000..feeac3e92 --- /dev/null +++ b/examples/vllm_serve/vllm_fq_export.py @@ -0,0 +1,337 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import random +import time +import warnings +from typing import Any + +import numpy as np +import torch +from accelerate.hooks import remove_hook_from_module +from example_utils import ( + build_quant_cfg, + copy_custom_model_files, + get_model, + get_processor, + get_tokenizer, + is_enc_dec, + is_nemotron_vl, + run_nemotron_vl_preview, +) +from torch.utils.data import DataLoader +import transformers +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoProcessor, + PreTrainedTokenizer, + PreTrainedTokenizerBase, + PreTrainedTokenizerFast, + ProcessorMixin, + WhisperProcessor, +) +from accelerate import infer_auto_device_map, init_empty_weights +from accelerate.utils import get_max_memory +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import modelopt.torch.sparsity as mts +from modelopt.torch.export import ( + export_hf_checkpoint, + export_hf_vllm_fq_checkpoint, + export_tensorrt_llm_checkpoint, + get_model_type, +) +from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model +from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights +from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.utils.dataset_utils import ( + create_forward_loop, + get_dataset_dataloader, + get_max_batch_size, + get_supported_datasets, +) +from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor +from modelopt.torch.utils.memory_monitor import launch_memory_monitor +from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader +from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader + +RAND_SEED = 1234 + +mto.enable_huggingface_checkpointing() + +def load_model( + ckpt_path, + device="cuda", + gpu_mem_percentage=0.8, + trust_remote_code=False, + use_seq_device_map=False, +): + print(f"Initializing model from {ckpt_path}") + + device_map = "auto" + if device == "cpu": + device_map = "cpu" + + # Prepare config kwargs for loading + config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} + + # Load config once + try: + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + except Exception as e: + print(f"Error: Could not load config from {ckpt_path}: {e}") + raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e + + model_kwargs = config_kwargs.copy() + + if use_seq_device_map: + device_map = "sequential" + # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU + max_memory = get_max_memory() + max_memory = {key: value * gpu_mem_percentage for key, value in max_memory.items()} + model_kwargs["max_memory"] = max_memory + + architecture = hf_config.architectures[0] + + if not hasattr(transformers, architecture) or "Deepseek" in architecture: + if not hasattr(transformers, architecture): + warnings.warn( + f"Architecture {architecture} not found in transformers: {transformers.__version__}. " + "Falling back to AutoModelForCausalLM." 
+ ) + assert trust_remote_code, ( + "Please set trust_remote_code to True if you want to use this architecture" + ) + + auto_model_module = AutoModelForCausalLM + from_config = auto_model_module.from_config + else: + auto_model_module = getattr(transformers, architecture) + from_config = auto_model_module._from_config + + with init_empty_weights(): + # When computing the device_map, assuming bfloat16 precision by default, + # unless specified by the hf_config. + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) + model_kwargs2 = model_kwargs.copy() + if auto_model_module != AutoModelForCausalLM: + model_kwargs2.pop("trust_remote_code", None) + model_kwargs2["torch_dtype"] = torch_dtype + model_kwargs2.pop("max_memory", None) + model = from_config(hf_config, **model_kwargs2) + + max_memory = get_max_memory() + inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) + + on_cpu = "cpu" in inferred_device_map.values() + + if on_cpu: + for _device in max_memory: + if isinstance(_device, int): + max_memory[_device] *= gpu_mem_percentage + + print( + "Model does not fit to the GPU mem. " + f"We apply the following memory limit for calibration: \n{max_memory}\n" + "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " + "reduce the calibration `batch_size` manually." + ) + model_kwargs["max_memory"] = max_memory + + model = auto_model_module.from_pretrained( + ckpt_path, + device_map=device_map, + **model_kwargs, + ) + model.eval() + + # If device_map was disabled (None), manually move model to target device + if device_map is None and device != "cpu": + print(f"Moving model to {device} device...") + model = model.to(device) + + if device == "cuda" and not is_model_on_gpu(model): + print("Warning: Some parameters are not on a GPU. 
Calibration can be slow or hit OOM") + + return model + + +def is_model_on_gpu(model) -> bool: + """Returns if the model is fully loaded on GPUs.""" + return all("cuda" in str(param.device) for param in model.parameters()) + + +def quantize_and_export_model( + args: argparse.Namespace, +): + model = load_model( args.pyt_ckpt_path, + device=args.device, + gpu_mem_percentage=args.gpu_max_mem_percentage, + trust_remote_code=args.trust_remote_code, + use_seq_device_map=args.use_seq_device_map,) + + args.batch_size = get_max_batch_size( + model, + max_sample_length=args.calib_seq, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + device = model.device + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + device=device, + include_labels=False, + ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + mtq_cfg = getattr(mtq, args.quant_cfg) # type: ignore [arg-type] + if args.kv_cache_quant_cfg is not None: + kv_cache_quant_cfg = getattr(mtq, args.kv_cache_quant_cfg) # type: ignore [arg-type] + mtq_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + mtq_cfg["quant_cfg"], kv_cache_quant_cfg["quant_cfg"] + ) + input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0]) + generated_str_before_ptq = model.run(input_str) + + quantized_model = mtq.quantize(model, mtq_cfg, calibrate_loop) + mtq.print_quant_summary(quantized_model) + generated_str_after_ptq = model.run(input_str) + + print("--------") + print(f"example test input: {input_str}") + print("--------") + print(f"example outputs before ptq: {generated_str_before_ptq}") + print("--------") + print(f"example outputs after ptq: {generated_str_after_ptq}") + + export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) + print(f"Model exported to {args.export_path}") + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--pyt_ckpt_path", + help="Specify where the PyTorch checkpoint path is", + required=True, + ) + parser.add_argument("--device", default="cuda") + parser.add_argument( + "--quant_cfg", + help=( + "Quantization configuration." + ), + default="FP8_DEFAULT_CFG", + ) + parser.add_argument( + "--batch_size", + help="Batch size for calibration. Default to 0 as we calculate max batch size on-the-fly", + type=int, + default=0, + ) + parser.add_argument( + "--calib_size", + help=( + "Number of samples for calibration. If a comma separated list of values is provided, " + "each value will be used as the calibration size for the corresponding dataset. " + "This argument will be parsed and converted as a list of ints." + ), + type=str, + default="512", + ) + parser.add_argument( + "--calib_seq", + help="Maximum sequence length for calibration.", + type=int, + default=512, + ) + parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--dataset", + help=( + f"name of a dataset, or a comma separated list of datasets. 
" + f"dataset choices are {get_supported_datasets()}" + ), + type=str, + default=None, + ) + parser.add_argument( + "--kv_cache_quant_cfg", + required=False, + default=None, + help="Specify KV cache quantization configuration, default to None if not provided", + ) + parser.add_argument( + "--trust_remote_code", + help="Set trust_remote_code for Huggingface models and tokenizers", + default=False, + action="store_true", + ) + parser.add_argument( + "--gpu_max_mem_percentage", + help=( + "Specify the percentage of available GPU memory to use for loading the model when " + "device_map is set to sequential. " + "By default, 80%% of the available GPU memory is used." + ), + type=float, + default=0.8, + ) + parser.add_argument( + "--use_seq_device_map", + help=( + "Use device_map=sequential to load the model onto GPUs. This ensures the model is loaded " + "utilizing the percentage of available GPU memory as specified by the value passed with gpu_max_mem flag." + "Helpful in cases where device_map=auto loads model unevenly on GPUs causing GPU OOM during quantization." + ), + default=False, + action="store_true", + ) + + return parser.parse_args() + + +def main(args: argparse.Namespace): + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + # launch a memory monitor to read the currently used GPU memory. + launch_memory_monitor() + + # Force eager execution for all model types. + torch.compiler.set_stance("force_eager") + + # Quantize + quantize_and_export_model( + args, + + ) + + +if __name__ == "__main__": + args = parse_args() + + args.dataset = args.dataset.split(",") if args.dataset else None + args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] + main(args) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 03b191346..b30e7530d 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -15,6 +15,7 @@ """Export HuggingFace model to vLLM fakequant checkpoint.""" from pathlib import Path +from typing import Any import torch import torch.nn as nn @@ -26,6 +27,25 @@ __all__ = ["export_hf_vllm_fq_checkpoint"] +def cleanup_for_torch_save(x: Any) -> Any: + """Drop callables / local closures (e.g. `.new_forward`) before torch.save. + + ModelOpt stored state dict may contain local closures like `.new_forward` + which are not picklable. So we need to cleanup the state dict before saving. 
+    """
+    if isinstance(x, dict):
+        return {
+            k: cleanup_for_torch_save(v)
+            for k, v in x.items()
+            if not callable(v) and "<locals>" not in str(getattr(v, "__qualname__", ""))
+        }
+    if isinstance(x, list):
+        return [cleanup_for_torch_save(v) for v in x]
+    if isinstance(x, tuple):
+        return tuple(cleanup_for_torch_save(v) for v in x)
+    return x
+
+
 def export_hf_vllm_fq_checkpoint(
     model: nn.Module,
     export_dir: Path | str,
@@ -48,8 +68,9 @@ def export_hf_vllm_fq_checkpoint(
     quantizer_state_dict = get_quantizer_state_dict(model)
     modelopt_state = mto.modelopt_state(model)
-    modelopt_state["modelopt_state_weights"] = quantizer_state_dict
-    torch.save(modelopt_state, f"{export_dir}/vllm_fq_modelopt_state.pth")
+    modelopt_state = cleanup_for_torch_save(modelopt_state)
+    modelopt_state["modelopt_state_weights"] = cleanup_for_torch_save(quantizer_state_dict)
+    torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth")
 
     # remove quantizer from model
     for _, module in model.named_modules():
         if is_quantlinear(module):

From 995cbda25d6f44ed290dc1773d8214f5cee1eb88 Mon Sep 17 00:00:00 2001
From: Kinjal Patel
Date: Mon, 26 Jan 2026 21:36:45 +0000
Subject: [PATCH 10/13] removed vLLM fq export from hf_ptq

Signed-off-by: Kinjal Patel

---
 examples/llm_ptq/hf_ptq.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 3023161df..fd1c96cdb 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -50,7 +50,6 @@
 import modelopt.torch.sparsity as mts
 from modelopt.torch.export import (
     export_hf_checkpoint,
-    export_hf_vllm_fq_checkpoint,
     export_tensorrt_llm_checkpoint,
     get_model_type,
 )
@@ -623,10 +622,7 @@ def export_quantized(
             "Unified HF export format does not specify inference tensor parallel or pipeline parallel. "
             "They will be set at deployment time."
         )
-        export_fn = (
-            export_hf_vllm_fq_checkpoint if args.export_vllm_fq else export_hf_checkpoint
-        )
-        export_fn(
+        export_hf_checkpoint(
             full_model,
             export_dir=export_path,
         )
@@ -1083,12 +1079,6 @@ def parse_args() -> argparse.Namespace:
             "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
         ),
     )
-    parser.add_argument(
-        "--export_vllm_fq",
-        help="Export vLLM fakequant checkpoint.",
-        default=False,
-        action="store_true",
-    )
     return parser.parse_args()
 
 

From a9cefbec298b8bd243e5453d484783d387fe743e Mon Sep 17 00:00:00 2001
From: Kinjal Patel
Date: Mon, 26 Jan 2026 21:38:03 +0000
Subject: [PATCH 11/13] minor

Signed-off-by: Kinjal Patel

---
 examples/vllm_serve/README.md | 26 +-
 examples/vllm_serve/vllm_fq_export.py | 337 --------------------------
 2 files changed, 11 insertions(+), 352 deletions(-)
 delete mode 100644 examples/vllm_serve/vllm_fq_export.py

diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md
index 60002b747..8c15d75e9 100644
--- a/examples/vllm_serve/README.md
+++ b/examples/vllm_serve/README.md
@@ -58,24 +58,20 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=,
 
 ## Load QAT/PTQ model and serve in vLLM (WIP)
 
-Overwrite the calibrated amax value with prepared values from either QAT/PTQ.
+Step 1: export the model with bf16 weights and quantizer state. To export the model:
 
-Step 1: export the model with bf16 weights and amax values.
To export the model: +- For **HF** models, use `hf_ptq_export.py`: -- For **HF** models, you can use `modelopt.torch.export.export_hf_vllm_fq_checkpoint`: - - ```python - import torch - from modelopt.torch.export import export_hf_vllm_fq_checkpoint - - with torch.inference_mode(): - export_hf_vllm_fq_checkpoint( - model, # The quantized model. - export_dir, # The directory where the exported files will be stored. - ) - ``` +```bash +python hf_ptq_export.py\ + --pyt_ckpt_path \ + --quant_cfg NVFP4_DEFAULT_CFG \ + --export_path \ + --trust_remote_code +``` - Or run the example script `examples/llm_ptq/hf_ptq.py` with the `--export_vllm_fq` **flag** to export a vLLM-fakequant-compatible ModelOpt state (it generates `vllm_fq_modelopt_state.pth`, which you can use via `MODELOPT_STATE_PATH`). + This creates `/vllm_fq_modelopt_state.pth` (ModelOpt quantizer state for vLLM fake-quant reload) and saves the HF-exported model under `` (config/tokenizer/weights). + Note: `--pyt_ckpt_path` can point to either an HF checkpoint or a ModelOpt-saved checkpoint (e.g., a QAT/QAD checkpoint produced by `examples/llm_qat/main.py`). If the input checkpoint is already quantized, the script will **skip re-quantization** and only export artifacts for vLLM fakequant reload. - For **MCore** models, use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq`: diff --git a/examples/vllm_serve/vllm_fq_export.py b/examples/vllm_serve/vllm_fq_export.py deleted file mode 100644 index feeac3e92..000000000 --- a/examples/vllm_serve/vllm_fq_export.py +++ /dev/null @@ -1,337 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import random -import time -import warnings -from typing import Any - -import numpy as np -import torch -from accelerate.hooks import remove_hook_from_module -from example_utils import ( - build_quant_cfg, - copy_custom_model_files, - get_model, - get_processor, - get_tokenizer, - is_enc_dec, - is_nemotron_vl, - run_nemotron_vl_preview, -) -from torch.utils.data import DataLoader -import transformers -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoProcessor, - PreTrainedTokenizer, - PreTrainedTokenizerBase, - PreTrainedTokenizerFast, - ProcessorMixin, - WhisperProcessor, -) -from accelerate import infer_auto_device_map, init_empty_weights -from accelerate.utils import get_max_memory -import modelopt.torch.opt as mto -import modelopt.torch.quantization as mtq -import modelopt.torch.sparsity as mts -from modelopt.torch.export import ( - export_hf_checkpoint, - export_hf_vllm_fq_checkpoint, - export_tensorrt_llm_checkpoint, - get_model_type, -) -from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model -from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration -from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights -from modelopt.torch.quantization.utils import is_quantized -from modelopt.torch.utils.dataset_utils import ( - create_forward_loop, - get_dataset_dataloader, - get_max_batch_size, - get_supported_datasets, -) -from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor -from modelopt.torch.utils.memory_monitor import launch_memory_monitor -from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader -from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader - -RAND_SEED = 1234 - -mto.enable_huggingface_checkpointing() - -def load_model( - ckpt_path, - device="cuda", - gpu_mem_percentage=0.8, - trust_remote_code=False, - use_seq_device_map=False, -): - print(f"Initializing model from {ckpt_path}") - - device_map = "auto" - if device == "cpu": - device_map = "cpu" - - # Prepare config kwargs for loading - config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} - - # Load config once - try: - hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) - except Exception as e: - print(f"Error: Could not load config from {ckpt_path}: {e}") - raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e - - model_kwargs = config_kwargs.copy() - - if use_seq_device_map: - device_map = "sequential" - # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU - max_memory = get_max_memory() - max_memory = {key: value * gpu_mem_percentage for key, value in max_memory.items()} - model_kwargs["max_memory"] = max_memory - - architecture = hf_config.architectures[0] - - if not hasattr(transformers, architecture) or "Deepseek" in architecture: - if not hasattr(transformers, architecture): - warnings.warn( - f"Architecture {architecture} not found in transformers: {transformers.__version__}. " - "Falling back to AutoModelForCausalLM." 
- ) - assert trust_remote_code, ( - "Please set trust_remote_code to True if you want to use this architecture" - ) - - auto_model_module = AutoModelForCausalLM - from_config = auto_model_module.from_config - else: - auto_model_module = getattr(transformers, architecture) - from_config = auto_model_module._from_config - - with init_empty_weights(): - # When computing the device_map, assuming bfloat16 precision by default, - # unless specified by the hf_config. - torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) - model_kwargs2 = model_kwargs.copy() - if auto_model_module != AutoModelForCausalLM: - model_kwargs2.pop("trust_remote_code", None) - model_kwargs2["torch_dtype"] = torch_dtype - model_kwargs2.pop("max_memory", None) - model = from_config(hf_config, **model_kwargs2) - - max_memory = get_max_memory() - inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) - - on_cpu = "cpu" in inferred_device_map.values() - - if on_cpu: - for _device in max_memory: - if isinstance(_device, int): - max_memory[_device] *= gpu_mem_percentage - - print( - "Model does not fit to the GPU mem. " - f"We apply the following memory limit for calibration: \n{max_memory}\n" - "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " - "reduce the calibration `batch_size` manually." - ) - model_kwargs["max_memory"] = max_memory - - model = auto_model_module.from_pretrained( - ckpt_path, - device_map=device_map, - **model_kwargs, - ) - model.eval() - - # If device_map was disabled (None), manually move model to target device - if device_map is None and device != "cpu": - print(f"Moving model to {device} device...") - model = model.to(device) - - if device == "cuda" and not is_model_on_gpu(model): - print("Warning: Some parameters are not on a GPU. 
Calibration can be slow or hit OOM") - - return model - - -def is_model_on_gpu(model) -> bool: - """Returns if the model is fully loaded on GPUs.""" - return all("cuda" in str(param.device) for param in model.parameters()) - - -def quantize_and_export_model( - args: argparse.Namespace, -): - model = load_model( args.pyt_ckpt_path, - device=args.device, - gpu_mem_percentage=args.gpu_max_mem_percentage, - trust_remote_code=args.trust_remote_code, - use_seq_device_map=args.use_seq_device_map,) - - args.batch_size = get_max_batch_size( - model, - max_sample_length=args.calib_seq, - ) - args.batch_size = min(args.batch_size, sum(args.calib_size)) - - print(f"Use calib batch_size {args.batch_size}") - tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) - device = model.device - calib_dataloader = get_dataset_dataloader( - dataset_name=args.dataset, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_samples=args.calib_size, - device=device, - include_labels=False, - ) - calibrate_loop = create_forward_loop(dataloader=calib_dataloader) - mtq_cfg = getattr(mtq, args.quant_cfg) # type: ignore [arg-type] - if args.kv_cache_quant_cfg is not None: - kv_cache_quant_cfg = getattr(mtq, args.kv_cache_quant_cfg) # type: ignore [arg-type] - mtq_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - mtq_cfg["quant_cfg"], kv_cache_quant_cfg["quant_cfg"] - ) - input_str = tokenizer.decode(next(iter(calib_dataloader))["input_ids"][0]) - generated_str_before_ptq = model.run(input_str) - - quantized_model = mtq.quantize(model, mtq_cfg, calibrate_loop) - mtq.print_quant_summary(quantized_model) - generated_str_after_ptq = model.run(input_str) - - print("--------") - print(f"example test input: {input_str}") - print("--------") - print(f"example outputs before ptq: {generated_str_before_ptq}") - print("--------") - print(f"example outputs after ptq: {generated_str_after_ptq}") - - export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) - print(f"Model exported to {args.export_path}") - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--pyt_ckpt_path", - help="Specify where the PyTorch checkpoint path is", - required=True, - ) - parser.add_argument("--device", default="cuda") - parser.add_argument( - "--quant_cfg", - help=( - "Quantization configuration." - ), - default="FP8_DEFAULT_CFG", - ) - parser.add_argument( - "--batch_size", - help="Batch size for calibration. Default to 0 as we calculate max batch size on-the-fly", - type=int, - default=0, - ) - parser.add_argument( - "--calib_size", - help=( - "Number of samples for calibration. If a comma separated list of values is provided, " - "each value will be used as the calibration size for the corresponding dataset. " - "This argument will be parsed and converted as a list of ints." - ), - type=str, - default="512", - ) - parser.add_argument( - "--calib_seq", - help="Maximum sequence length for calibration.", - type=int, - default=512, - ) - parser.add_argument("--export_path", default="exported_model") - parser.add_argument( - "--dataset", - help=( - f"name of a dataset, or a comma separated list of datasets. 
" - f"dataset choices are {get_supported_datasets()}" - ), - type=str, - default=None, - ) - parser.add_argument( - "--kv_cache_quant_cfg", - required=False, - default=None, - help="Specify KV cache quantization configuration, default to None if not provided", - ) - parser.add_argument( - "--trust_remote_code", - help="Set trust_remote_code for Huggingface models and tokenizers", - default=False, - action="store_true", - ) - parser.add_argument( - "--gpu_max_mem_percentage", - help=( - "Specify the percentage of available GPU memory to use for loading the model when " - "device_map is set to sequential. " - "By default, 80%% of the available GPU memory is used." - ), - type=float, - default=0.8, - ) - parser.add_argument( - "--use_seq_device_map", - help=( - "Use device_map=sequential to load the model onto GPUs. This ensures the model is loaded " - "utilizing the percentage of available GPU memory as specified by the value passed with gpu_max_mem flag." - "Helpful in cases where device_map=auto loads model unevenly on GPUs causing GPU OOM during quantization." - ), - default=False, - action="store_true", - ) - - return parser.parse_args() - - -def main(args: argparse.Namespace): - if not torch.cuda.is_available(): - raise OSError("GPU is required for inference.") - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - - # launch a memory monitor to read the currently used GPU memory. - launch_memory_monitor() - - # Force eager execution for all model types. - torch.compiler.set_stance("force_eager") - - # Quantize - quantize_and_export_model( - args, - - ) - - -if __name__ == "__main__": - args = parse_args() - - args.dataset = args.dataset.split(",") if args.dataset else None - args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] - main(args) From ef0691949e22f9d3a10e6334e9b38eb75684300f Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 21:40:26 +0000 Subject: [PATCH 12/13] minor Signed-off-by: Kinjal Patel --- examples/llm_ptq/hf_ptq.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index fd1c96cdb..7c91ca97f 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -622,6 +622,7 @@ def export_quantized( "Unified HF export format does not specify inference tensor parallel or pipeline parallel. " "They will be set at deployment time." 
) + export_hf_checkpoint( full_model, export_dir=export_path, From 12f7fe994a74fb53818262f636f0ab6282eee100 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 26 Jan 2026 21:55:52 +0000 Subject: [PATCH 13/13] cleanup Signed-off-by: Kinjal Patel --- examples/vllm_serve/hf_ptq_export.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/vllm_serve/hf_ptq_export.py b/examples/vllm_serve/hf_ptq_export.py index 7ee5c091d..fda5a6dec 100644 --- a/examples/vllm_serve/hf_ptq_export.py +++ b/examples/vllm_serve/hf_ptq_export.py @@ -128,6 +128,7 @@ def is_model_on_gpu(model) -> bool: """Returns if the model is fully loaded on GPUs.""" return all("cuda" in str(param.device) for param in model.parameters()) + def get_tokenizer(ckpt_path, trust_remote_code=False): """Returns the tokenizer from the model ckpt_path.""" print(f"Initializing tokenizer from {ckpt_path}") @@ -143,6 +144,7 @@ def get_tokenizer(ckpt_path, trust_remote_code=False): return tokenizer + def quantize_and_export_model( args: argparse.Namespace, ): @@ -188,7 +190,7 @@ def quantize_and_export_model( else: print("Model is already quantized, Skipping quantization...") quantized_model = model - + mtq.print_quant_summary(quantized_model) if not model_is_already_quantized: print("--------") @@ -199,11 +201,6 @@ def quantize_and_export_model( print(f"example outputs after ptq: {generated_str_after_ptq}") export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) - # from modelopt.torch.quantization.utils import get_quantizer_state_dict - # quantized_model.save_pretrained(args.export_path, state_dict=quantized_model.state_dict(), save_modelopt_state=False) - # modelopt_state = mto.modelopt_state(quantized_model) - # modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(quantized_model) - # torch.save(modelopt_state, f"{args.export_path}/modelopt_state.pth") tokenizer.save_pretrained(args.export_path) print(f"Model exported to {args.export_path}")
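
For readers following the series, the tensor-parallel reload logic added in `vllm_reload_utils.py` (patch 04) can be illustrated in isolation. The sketch below mirrors its shape-matching rule without depending on vLLM's `get_tp_group()`; the helper name, ranks, and shapes are made up for illustration:

```python
import torch


def shard_for_rank(saved: torch.Tensor, expected_shape, tp_rank: int, tp_world_size: int):
    """Slice this rank's chunk out of a full (unsharded) checkpoint tensor.

    Mirrors the helper's rule: exactly one dimension must satisfy
    checkpoint_dim == expected_dim * tp_world_size.
    """
    if tuple(saved.shape) == tuple(expected_shape):
        return saved
    shard_dims = [
        d for d in range(len(expected_shape))
        if saved.shape[d] == expected_shape[d] * tp_world_size
    ]
    assert len(shard_dims) == 1, (
        f"cannot infer TP shard dim: {tuple(saved.shape)} vs {tuple(expected_shape)}"
    )
    d = shard_dims[0]
    return saved.narrow(d, tp_rank * expected_shape[d], expected_shape[d]).contiguous()


# A per-channel amax exported without TP (shape [8]) reloaded with tp_world_size=2:
full = torch.arange(8.0)
print(shard_for_rank(full, (4,), tp_rank=0, tp_world_size=2))  # tensor([0., 1., 2., 3.])
print(shard_for_rank(full, (4,), tp_rank=1, tp_world_size=2))  # tensor([4., 5., 6., 7.])
```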
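The `modelopt_post_restore` changes in patches 05 and 08 follow one pattern: the base hook in `quant_module.py` now accepts an optional `model` argument, and the plugin overrides widen their signatures to `*args, **kwargs` so they forward whatever the base may add later. A minimal standalone sketch of that pattern (the class names are invented for the example, not ModelOpt APIs):

```python
class BaseQuantModule:
    def modelopt_post_restore(self, prefix: str = "", model=None):
        # The base hook can now receive the owning model, used to look up a
        # reference parameter when the module itself has none.
        print(f"post-restore {prefix!r} (model provided: {model is not None})")


class PluginQuantModule(BaseQuantModule):
    def modelopt_post_restore(self, prefix: str = "", *args, **kwargs):
        # Plugin-specific state fix-ups would happen here; unknown arguments
        # are forwarded so new base parameters don't break the override.
        super().modelopt_post_restore(prefix, *args, **kwargs)


PluginQuantModule().modelopt_post_restore("decoder.layers.0.linear_fc1", model=object())
```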
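To sanity-check an export directory produced by `hf_ptq_export.py` / `export_hf_vllm_fq_checkpoint` before serving, the saved state can be inspected the same way the updated HF export test does. A small sketch, assuming a hypothetical export directory; `weights_only=False` may be needed on newer PyTorch because the file stores ModelOpt metadata rather than plain tensors:

```python
from pathlib import Path

import torch

export_dir = Path("exported_model")  # hypothetical --export_path
state = torch.load(
    export_dir / "vllm_fq_modelopt_state.pth",
    map_location="cpu",
    weights_only=False,
)

# Quantizer tensors live under "modelopt_state_weights"; an empty dict here
# would mean the cleanup step stripped too much before torch.save.
weights = state["modelopt_state_weights"]
print(f"{len(weights)} quantizer entries, e.g. {list(weights)[:3]}")
```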