diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e615627b2..10575a18e 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint. - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md `_ for more details on its usage. - Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow. +- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md `_ for more details. 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md index ff0c4eea3..8c15d75e9 100644 --- a/examples/vllm_serve/README.md +++ b/examples/vllm_serve/README.md @@ -23,9 +23,11 @@ You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`, |-----------------|--------------------------------------------------|---------------------| | QUANT_DATASET | Dataset name for calibration | cnn_dailymail | | QUANT_CALIB_SIZE| Number of samples used for calibration | 512 | -| QUANT_CFG | Quantization format | NVFP4_DEFAULT_CFG | -| KV_QUANT_CFG | Quantization format for KV Cache | None | -| AMAX_FILE_PATH | Optional path to amax file (for loading amax) | None | +| QUANT_CFG | Quantization config | None | +| KV_QUANT_CFG | KV-cache quantization config | None | +| QUANT_FILE_PATH | Optional path to the exported quantizer state dict `quantizer_state.pth` | None | +| MODELOPT_STATE_PATH | Optional path to the exported `vllm_fq_modelopt_state.pth` (restores quantizer state and parameters) | None | +| CALIB_BATCH_SIZE | Calibration batch size | 1 | Set these variables in your shell or Docker environment as needed to customize calibration. @@ -56,21 +58,53 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=, ## Load QAT/PTQ model and serve in vLLM (WIP) -Overwrite the calibrated amax value with prepared values from either QAT/PTQ. +Step 1: export the model with bf16 weights and the quantizer state. To export the model: -Step 1: export the model with bf16 weights and amax values. To export the model: +- For **HF** models, use `hf_ptq_export.py`: -- For HF model use `modelopt.torch.export.export_hf_vllm_fq_checkpoint` function. -- For MCore model use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq` function. +```bash +python hf_ptq_export.py \ + --pyt_ckpt_path \ + --quant_cfg NVFP4_DEFAULT_CFG \ + --export_path \ + --trust_remote_code +``` + + This creates `vllm_fq_modelopt_state.pth` (the ModelOpt quantizer state for vLLM fake-quant reload) in the export path and saves the HF-exported model (config/tokenizer/weights) alongside it. + Note: `--pyt_ckpt_path` can point to either an HF checkpoint or a ModelOpt-saved checkpoint (e.g., a QAT/QAD checkpoint produced by `examples/llm_qat/main.py`). If the input checkpoint is already quantized, the script will **skip re-quantization** and only export artifacts for vLLM fakequant reload.
+ +- For **MCore** models, use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq`: + + ```python + from modelopt.torch.export import export_mcore_gpt_to_hf_vllm_fq + export_mcore_gpt_to_hf_vllm_fq( + unwrapped_model, # Quantized MCore model + args.pretrained_model_name, # HF model id/path (for config/tokenizer) + export_dir=args.export_dir, # Directory where exported files will be stored + ) + + ``` + + This generates `quantizer_state.pth` under `export_dir`, which contains the quantizer tensors for vLLM reload via `QUANT_FILE_PATH`. + +Step 2: use the exported artifacts when serving: + +- **HF export**: pass the exported `vllm_fq_modelopt_state.pth` via the `MODELOPT_STATE_PATH` environment variable: + +```bash +# HF +MODELOPT_STATE_PATH= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 +``` -Step 2: configure from exported model using AMAX_FILE_PATH environment variable in step 1. For example: +- **MCore export**: pass the exported `quantizer_state.pth` via `QUANT_FILE_PATH` and set `QUANT_CFG` to match the MCore quantization recipe: ```bash -AMAX_FILE_PATH= QUANT_CFG= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 +# MCore +QUANT_CFG= QUANT_FILE_PATH= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000 ``` ## Known Problems -1. AWQ is not yet supported in vLLM. -2. QAT checkpoint export doesn't have KV Cache quantization enabled. KV Cache fake quantization works for PTQ. -3. Mixed precision checkpoint doesn't work currently. +1. **MCore reload does not use `MODELOPT_STATE_PATH`**; use `QUANT_FILE_PATH` and make sure `QUANT_CFG` matches the quantization recipe used for the original MCore model (otherwise quantizer keys/config won't align). +2. AWQ reload is not yet supported. +3. KV cache quantization export and reload are not yet supported for MCore. diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 772c6fe66..fe822c877 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -15,9 +15,7 @@ import dataclasses import os -import re import warnings -from collections import defaultdict from contextlib import contextmanager from typing import Any @@ -27,104 +25,17 @@ from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.gpu_worker import Worker as BaseWorker +from vllm_reload_utils import ( + convert_dict_to_vllm, + convert_modelopt_state_to_vllm, + process_state_dict_for_tp, +) import modelopt.torch.quantization as mtq +from modelopt.torch.opt.conversion import restore_from_modelopt_state from modelopt.torch.utils.dataset_utils import get_dataset_dataloader -def convert_amax_hf2vllm( - hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False - ) -> dict[str, torch.Tensor]: - """ - Convert amax values from HuggingFace format to vLLM format.
- - This function merges: - - q_proj, k_proj, v_proj amax values into qkv_proj (taking max) - - gate_proj, up_proj amax values into gate_up_proj (taking max) - - Args: - hf_state_dict: HuggingFace state dict containing amax values - - Returns: - vLLM format state dict with merged amax values - """ - vllm_state_dict = {} - - # Group keys by their base pattern (without the specific projection name) - merge_groups = defaultdict(list) - - for key, value in hf_state_dict.items(): - if "_amax" not in key: - # Copy non-amax keys as-is - vllm_state_dict[key] = value - continue - - # Check if this is a q/k/v projection that needs merging - qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key) - if qkv_match: - base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is an expert gate/up projection - # Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and - # model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax - # Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax - expert_gate_up_match = ( - "mixer" not in key - and fuse_experts - and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key) - ) - if expert_gate_up_match: - base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is a non-expert gate/up projection that needs merging - gate_up_match = ( - "mixer" not in key - and "experts" not in key - and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key) - ) - if gate_up_match: - base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is an expert down_proj - # Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax - # Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax - expert_down_match = ( - "mixer" not in key - and fuse_experts - and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key) - ) - if expert_down_match: - base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) - merge_groups[base_pattern].append((key, value)) - continue - - # Copy other amax keys as-is (like o_proj, down_proj) - vllm_state_dict[key] = value - - # Merge grouped amax values by taking the maximum - for merged_key, key_value_pairs in merge_groups.items(): - if len(key_value_pairs) > 1: - # Take the maximum across all values for this merged key - values = [value for _, value in key_value_pairs] - merged_value = torch.stack(values).max(dim=0)[0] - vllm_state_dict[merged_key] = merged_value - print(f"Merged {len(key_value_pairs)} keys into {merged_key}") - for orig_key, _ in key_value_pairs: - print(f" - {orig_key}") - else: - # Single key, just rename it - _, value = key_value_pairs[0] - vllm_state_dict[merged_key] = value - - return vllm_state_dict - - @contextmanager def disable_compilation(model): do_not_compile = True @@ -151,7 +62,9 @@ def disable_compilation(model): "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)), "quant_cfg": os.environ.get("QUANT_CFG", None), "kv_quant_cfg": os.environ.get("KV_QUANT_CFG", None), - "amax_file_path": os.environ.get("AMAX_FILE_PATH", None), + "quant_file_path": os.environ.get("QUANT_FILE_PATH", None), + "modelopt_state_path": os.environ.get("MODELOPT_STATE_PATH", None), + "calib_batch_size": int(os.environ.get("CALIB_BATCH_SIZE", 
1)), } @@ -194,137 +107,162 @@ def _fakequant_run_prolog_worker(self) -> None: if tokenizer.pad_token != "" or tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if quant_config["amax_file_path"]: - print("Will load amax, so only do a single sample calibration") - quant_config["calib_size"] = 1 - - calib_dataloader = get_dataset_dataloader( - dataset_name=quant_config["dataset"], - tokenizer=tokenizer, - batch_size=1, - num_samples=quant_config["calib_size"], - device=self.device, - ) - - def calibrate_loop(model: Any = None) -> None: - for batch_idx, batch in tqdm(enumerate(calib_dataloader)): - input_ids = batch["input_ids"][0] - - # Convert tensor to list of integers for vLLM compatibility - if torch.is_tensor(input_ids): - input_ids_list = input_ids.cpu().tolist() - else: - input_ids_list = list(input_ids) - - num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) - empty_block_ids = tuple([] for _ in range(num_groups)) - - req_id = f"req-{batch_idx}" - # Pass all possible parameters - the helper will filter based on vLLM version - new_req = _create_new_data_cls( - NewRequestData, - req_id=req_id, - prompt_token_ids=input_ids_list, - # Old API parameters - mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated - mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated - mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated - # New API parameter - mm_features=[], - sampling_params=SamplingParams(max_tokens=1), - pooling_params=None, - block_ids=empty_block_ids, - num_computed_tokens=0, - lora_request=None, - ) - - scheduler_output = _create_new_data_cls( - SchedulerOutput, - scheduled_new_reqs=[new_req], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={req_id: len(input_ids_list)}, - total_num_scheduled_tokens=len(input_ids_list), - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[0] * num_groups, - finished_req_ids=set(), - free_encoder_mm_hashes=[], - kv_connector_metadata=None, - # Old API parameters - structured_output_request_ids={}, # TODO: remove this when vllm <= 0.11 is outdated - grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated - ) - output = self.execute_model(scheduler_output) - if hasattr(self, "sample_tokens"): - if output is None: # TODO: make this default when vllm <= 0.11 is outdated - self.sample_tokens(None) - - quant_cfg = {} if quant_config["quant_cfg"] is None else getattr(mtq, quant_config["quant_cfg"]) - quant_kv_cfg = ( - {} if quant_config["kv_quant_cfg"] is None else getattr(mtq, quant_config["kv_quant_cfg"]) - ) - model = self.model_runner.model - if hasattr(model, "unwrap"): - model = model.unwrap() + if quant_config["modelopt_state_path"]: + print(f"Loading modelopt state from {quant_config['modelopt_state_path']}") + # Load on CPU to avoid failures when the checkpoint was saved from a different + # GPU mapping + modelopt_state = torch.load( + quant_config["modelopt_state_path"], weights_only=False, map_location="cpu" + ) + modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) + modelopt_state = convert_modelopt_state_to_vllm(modelopt_state) + restore_from_modelopt_state(model, modelopt_state) - # Check if model has MLA and update KV config accordingly - if quant_kv_cfg: - quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) + if modelopt_weights is not None: + modelopt_weights = convert_dict_to_vllm(modelopt_weights) + 
mtq.utils.set_quantizer_state_dict(model, modelopt_weights) - if quant_kv_cfg: - quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( - quant_cfg, quant_kv_cfg["quant_cfg"] + else: + if quant_config["quant_file_path"]: + print("Will load quant, so only do a single sample calibration") + quant_config["calib_size"] = 1 + calib_dataloader = get_dataset_dataloader( + dataset_name=quant_config["dataset"], + tokenizer=tokenizer, + batch_size=quant_config["calib_batch_size"], + num_samples=quant_config["calib_size"], + device=self.device, ) - with disable_compilation(model): - print("quantizing model...") - mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - - amax_file_path = quant_config["amax_file_path"] - if amax_file_path: - print(f"Loading amax values from {amax_file_path}") - saved_amax_dict = torch.load(amax_file_path) - # convert amax keys to vLLM format - if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): - saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict) - saved_amax_dict = { - key.replace("quantizer_amax", "quantizer._amax"): value - for key, value in saved_amax_dict.items() - if key.endswith("quantizer_amax") - } - saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True) - - current_state_dict = model.state_dict() - # Count amax keys in checkpoint and model - checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")] - model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")] - for key in checkpoint_amax_keys: - if key not in model_amax_keys: - print(f"Key {key} not found in model state dict, but exists in checkpoint") - for key in model_amax_keys: - if key not in checkpoint_amax_keys: - raise ValueError( - f"Key {key} not found in checkpoint state dict, but exists in model" + def calibrate_loop(model: Any = None) -> None: + for batch_idx, batch in tqdm(enumerate(calib_dataloader)): + input_ids = batch["input_ids"][0] + + # Convert tensor to list of integers for vLLM compatibility + if torch.is_tensor(input_ids): + input_ids_list = input_ids.cpu().tolist() + else: + input_ids_list = list(input_ids) + + num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) + empty_block_ids = tuple([] for _ in range(num_groups)) + + req_id = f"req-{batch_idx}" + # Pass all possible parameters - the helper will filter based on vLLM version + new_req = _create_new_data_cls( + NewRequestData, + req_id=req_id, + prompt_token_ids=input_ids_list, + # Old API parameters + mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated + mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated + mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated + # New API parameter + mm_features=[], + sampling_params=SamplingParams(max_tokens=1), + pooling_params=None, + block_ids=empty_block_ids, + num_computed_tokens=0, + lora_request=None, ) - checkpoint_amax_count = len(checkpoint_amax_keys) - model_amax_count = len(model_amax_keys) + scheduler_output = _create_new_data_cls( + SchedulerOutput, + scheduled_new_reqs=[new_req], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={req_id: len(input_ids_list)}, + total_num_scheduled_tokens=len(input_ids_list), + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0] * num_groups, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + kv_connector_metadata=None, + # Old API parameters + structured_output_request_ids={}, # TODO: remove this when vllm 
<= 0.11 is outdated + grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated + ) + output = self.execute_model(scheduler_output) + if hasattr(self, "sample_tokens"): + if output is None: # TODO: make this default when vllm <= 0.11 is outdated + self.sample_tokens(None) + + quant_cfg = getattr(mtq, quant_config["quant_cfg"]) if quant_config["quant_cfg"] else {} + quant_kv_cfg = ( + getattr(mtq, quant_config["kv_quant_cfg"]) if quant_config["kv_quant_cfg"] else {} + ) - # Ensure counts match - if checkpoint_amax_count != model_amax_count: - warnings.warn( - f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} " - f"amax keys but model has {model_amax_count} amax keys. This can happen if the model is using PP." + if hasattr(model, "unwrap"): + model = model.unwrap() + + # Check if model has MLA and update KV config accordingly + if quant_kv_cfg: + quant_kv_cfg["quant_cfg"] = update_kv_cfg_for_mla(model, quant_kv_cfg["quant_cfg"]) + + if quant_kv_cfg: + quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + quant_cfg, quant_kv_cfg["quant_cfg"] ) - # Update amax values - for key, value in saved_amax_dict.items(): - if key in current_state_dict: - current_state_dict[key] = value.to(current_state_dict[key].device) + with disable_compilation(model): + print("quantizing model...") + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) + + quantizer_file_path = quant_config["quant_file_path"] + if quantizer_file_path: + self.model_runner._dummy_run(1) + print(f"Loading quantizer values from {quantizer_file_path}") + # Load on CPU to avoid failures when the checkpoint was saved from a different + # GPU mapping + saved_quant_dict = torch.load(quantizer_file_path, map_location="cpu") + # convert quant keys to vLLM format + if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): + saved_quant_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict( + saved_quant_dict + ) + saved_quant_dict = { + key.replace("quantizer_", "quantizer._"): value + for key, value in saved_quant_dict.items() + if key.endswith("quantizer_") + } + saved_quant_dict = convert_dict_to_vllm(saved_quant_dict) + + current_state_dict = model.state_dict() + # Count quant keys in checkpoint and model + checkpoint_quant_keys = [key for key in saved_quant_dict if "quantizer" in key] + model_quant_keys = [key for key in current_state_dict if "quantizer" in key] + for key in checkpoint_quant_keys: + if key not in model_quant_keys: + print(f"Key {key} not found in model state dict, but exists in checkpoint") + for key in model_quant_keys: + if key not in checkpoint_quant_keys: + raise ValueError( + f"Key {key} not found in checkpoint state dict, but exists in model" + ) + + checkpoint_quant_count = len(checkpoint_quant_keys) + model_quant_count = len(model_quant_keys) + + # Ensure counts match + if checkpoint_quant_count != model_quant_count: + warnings.warn( + f"Mismatch in quantizer state key counts: checkpoint has {checkpoint_quant_count} " + f"quant keys but model has {model_quant_count} quantizer state keys. " + f"This can happen if the model is using PP." 
+ ) + + # Update quant values + saved_quant_dict = process_state_dict_for_tp(saved_quant_dict, current_state_dict) + for key, value in saved_quant_dict.items(): + if key in current_state_dict: + current_state_dict[key] = value.to(current_state_dict[key].device) + + model.load_state_dict(current_state_dict) - model.load_state_dict(current_state_dict) - torch.distributed.barrier() + # Only barrier if distributed is actually initialized (avoids deadlocks). + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + torch.distributed.barrier() if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: mtq.print_quant_summary(model) @@ -345,6 +283,10 @@ def determine_available_memory(self) -> int: return super().determine_available_memory() def compile_or_warm_up_model(self) -> None: - if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]: + if ( + quant_config["quant_cfg"] + or quant_config["kv_quant_cfg"] + or quant_config["modelopt_state_path"] + ): _fakequant_run_prolog_worker(self) super().compile_or_warm_up_model() diff --git a/examples/vllm_serve/hf_ptq_export.py b/examples/vllm_serve/hf_ptq_export.py new file mode 100644 index 000000000..fda5a6dec --- /dev/null +++ b/examples/vllm_serve/hf_ptq_export.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import random +import warnings + +import numpy as np +import torch +import transformers +from accelerate import infer_auto_device_map, init_empty_weights +from accelerate.utils import get_max_memory +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +from modelopt.torch.export import export_hf_vllm_fq_checkpoint +from modelopt.torch.quantization.utils import is_quantized +from modelopt.torch.utils.dataset_utils import ( + create_forward_loop, + get_dataset_dataloader, + get_max_batch_size, + get_supported_datasets, +) +from modelopt.torch.utils.memory_monitor import launch_memory_monitor + +RAND_SEED = 1234 + +mto.enable_huggingface_checkpointing() + + +def load_model( + ckpt_path, + device="cuda", + gpu_mem_percentage=0.8, + trust_remote_code=False, + use_seq_device_map=False, +): + print(f"Initializing model from {ckpt_path}") + + config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} + try: + hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + except Exception as e: + raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e + + # Pick the transformers model class to load. 
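+ # Prefer the concrete class named in config.architectures; fall back to AutoModelForCausalLM (which requires trust_remote_code) when that class is missing from the installed transformers version or for Deepseek remote-code variants.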
+ architecture = hf_config.architectures[0] + use_auto_causallm = (not hasattr(transformers, architecture)) or ("Deepseek" in architecture) + if use_auto_causallm: + if not hasattr(transformers, architecture): + warnings.warn( + f"Architecture {architecture} not found in transformers: {transformers.__version__}. " + "Falling back to AutoModelForCausalLM." + ) + assert trust_remote_code, ( + "Please set trust_remote_code=True if you want to use this architecture" + ) + model_cls = AutoModelForCausalLM + from_config = model_cls.from_config + else: + model_cls = getattr(transformers, architecture) + from_config = model_cls._from_config + + # Decide device_map and optional memory cap. + if device == "cpu": + device_map = "cpu" + elif use_seq_device_map: + device_map = "sequential" + else: + device_map = "auto" + + model_kwargs: dict[str, object] = dict(config_kwargs) + if device_map == "sequential": + max_memory = get_max_memory() + model_kwargs["max_memory"] = {k: v * gpu_mem_percentage for k, v in max_memory.items()} + + # Detect if the model would offload to CPU; if so, cap GPU memory for calibration. + with init_empty_weights(): + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) + empty_kwargs: dict[str, object] = dict(model_kwargs, torch_dtype=torch_dtype) + empty_kwargs.pop("max_memory", None) # only used by from_pretrained dispatch + if model_cls is not AutoModelForCausalLM: + empty_kwargs.pop("trust_remote_code", None) + empty_model = from_config(hf_config, **empty_kwargs) + + max_memory = get_max_memory() + inferred_device_map = infer_auto_device_map(empty_model, max_memory=max_memory) + if "cpu" in inferred_device_map.values(): + for dev_id, mem in list(max_memory.items()): + if isinstance(dev_id, int): + max_memory[dev_id] = mem * gpu_mem_percentage + print( + "Model does not fit to the GPU mem. " + f"We apply the following memory limit for calibration: \n{max_memory}\n" + "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " + "reduce the calibration `batch_size` manually." + ) + model_kwargs["max_memory"] = max_memory + + model = model_cls.from_pretrained(ckpt_path, device_map=device_map, **model_kwargs) + model.eval() + + # If device_map was disabled (None), manually move model to target device + if device_map is None and device != "cpu": + print(f"Moving model to {device} device...") + model = model.to(device) + + if device == "cuda" and not is_model_on_gpu(model): + print("Warning: Some parameters are not on a GPU. 
Calibration can be slow or hit OOM") + + return model + + +def is_model_on_gpu(model) -> bool: + """Returns if the model is fully loaded on GPUs.""" + return all("cuda" in str(param.device) for param in model.parameters()) + + +def get_tokenizer(ckpt_path, trust_remote_code=False): + """Returns the tokenizer from the model ckpt_path.""" + print(f"Initializing tokenizer from {ckpt_path}") + tokenizer = AutoTokenizer.from_pretrained( + ckpt_path, + padding_side="left", + trust_remote_code=trust_remote_code, + ) + + # can't set attribute 'pad_token' for "" + if tokenizer.pad_token != "": + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer + + +def quantize_and_export_model( + args: argparse.Namespace, +): + model = load_model( + args.pyt_ckpt_path, + device=args.device, + gpu_mem_percentage=args.gpu_max_mem_percentage, + trust_remote_code=args.trust_remote_code, + use_seq_device_map=args.use_seq_device_map, + ) + + if args.batch_size == 0: + args.batch_size = get_max_batch_size( + model, + max_sample_length=args.calib_seq, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) + + print(f"Use calib batch_size {args.batch_size}") + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + device = model.device + calib_dataloader = get_dataset_dataloader( + dataset_name=args.dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_samples=args.calib_size, + device=device, + include_labels=False, + ) + calibrate_loop = create_forward_loop(dataloader=calib_dataloader) + mtq_cfg = getattr(mtq, args.quant_cfg) + if args.kv_cache_quant_cfg is not None: + kv_cache_quant_cfg = getattr(mtq, args.kv_cache_quant_cfg) + mtq_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant( + mtq_cfg, kv_cache_quant_cfg["quant_cfg"] + ) + input_ids = next(iter(calib_dataloader))["input_ids"][0:1] + model_is_already_quantized = is_quantized(model) + if not model_is_already_quantized: + generated_str_before_ptq = tokenizer.decode(model.generate(input_ids)[0]) + quantized_model = mtq.quantize(model, mtq_cfg, calibrate_loop) + generated_str_after_ptq = tokenizer.decode(model.generate(input_ids)[0]) + else: + print("Model is already quantized, Skipping quantization...") + quantized_model = model + + mtq.print_quant_summary(quantized_model) + if not model_is_already_quantized: + print("--------") + print(f"example test input: {tokenizer.decode(input_ids[0])}") + print("--------") + print(f"example outputs before ptq: {generated_str_before_ptq}") + print("--------") + print(f"example outputs after ptq: {generated_str_after_ptq}") + + export_hf_vllm_fq_checkpoint(quantized_model, args.export_path) + tokenizer.save_pretrained(args.export_path) + print(f"Model exported to {args.export_path}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--pyt_ckpt_path", + help="Specify where the PyTorch checkpoint path is", + required=True, + ) + parser.add_argument("--device", default="cuda") + parser.add_argument( + "--quant_cfg", + help="Quantization configuration.", + default="FP8_DEFAULT_CFG", + ) + parser.add_argument( + "--batch_size", + help="Batch size for calibration. Default to 0 as we calculate max batch size on-the-fly", + type=int, + default=0, + ) + parser.add_argument( + "--calib_size", + help=( + "Number of samples for calibration. If a comma separated list of values is provided, " + "each value will be used as the calibration size for the corresponding dataset. 
" + "This argument will be parsed and converted as a list of ints." + ), + type=str, + default="512", + ) + parser.add_argument( + "--calib_seq", + help="Maximum sequence length for calibration.", + type=int, + default=512, + ) + parser.add_argument("--export_path", default="exported_model") + parser.add_argument( + "--dataset", + help=( + f"name of a dataset, or a comma separated list of datasets. " + f"dataset choices are {get_supported_datasets()}" + ), + type=str, + default=None, + ) + parser.add_argument( + "--kv_cache_quant_cfg", + required=False, + default=None, + help="Specify KV cache quantization configuration, default to None if not provided", + ) + parser.add_argument( + "--trust_remote_code", + help="Set trust_remote_code for Huggingface models and tokenizers", + default=False, + action="store_true", + ) + parser.add_argument( + "--gpu_max_mem_percentage", + help=( + "Specify the percentage of available GPU memory to use for loading the model when " + "device_map is set to sequential. " + "By default, 80%% of the available GPU memory is used." + ), + type=float, + default=0.8, + ) + parser.add_argument( + "--use_seq_device_map", + help=( + "Use device_map=sequential to load the model onto GPUs. This ensures the model is loaded " + "utilizing the percentage of available GPU memory as specified by the value passed with gpu_max_mem flag." + "Helpful in cases where device_map=auto loads model unevenly on GPUs causing GPU OOM during quantization." + ), + default=False, + action="store_true", + ) + + return parser.parse_args() + + +def main(args: argparse.Namespace): + if not torch.cuda.is_available(): + raise OSError("GPU is required for inference.") + + random.seed(RAND_SEED) + np.random.seed(RAND_SEED) + + # launch a memory monitor to read the currently used GPU memory. + launch_memory_monitor() + + # Force eager execution for all model types. + torch.compiler.set_stance("force_eager") + + # Quantize + quantize_and_export_model(args) + + +if __name__ == "__main__": + args = parse_args() + + args.dataset = args.dataset.split(",") if args.dataset else None + args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] + main(args) diff --git a/examples/vllm_serve/vllm_reload_utils.py b/examples/vllm_serve/vllm_reload_utils.py new file mode 100644 index 000000000..a72cb0b3c --- /dev/null +++ b/examples/vllm_serve/vllm_reload_utils.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +from collections import defaultdict +from typing import Any + +import torch +from vllm.distributed.parallel_state import get_tp_group + + +def _values_equal(v1: Any, v2: Any) -> bool: + """Compare values, handling dicts with tensors.""" + if isinstance(v1, dict) and isinstance(v2, dict): + if v1.keys() != v2.keys(): + return False + return all( + torch.equal(v1[k], v2[k]) if isinstance(v1[k], torch.Tensor) else v1[k] == v2[k] + for k in v1 + ) + elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): + return torch.equal(v1, v2) + return v1 == v2 + + +def _convert_key_for_vllm(key: str, value: Any) -> tuple[str, str | None, Any]: + """ + Transform a single key from HuggingFace format to vLLM format. + + Returns: + Tuple of (action, new_key_or_group, value) where action is one of: + - "copy": Copy value to new_key directly + - "group": Add to merge group identified by new_key + - "skip": Skip this key entirely + """ + if "quantizer" not in key: + return ("copy", key, value) + + # Skip softmax_quantizer and lm_head quantizers(not needed in vLLM) + if "softmax_quantizer" in key or (key.startswith("lm_head.") and "quantizer" in key): + return ("skip", None, None) + + # Check if this is a q/k/v projection that needs merging + qkv_match = re.search(r"(.*\.)([qkv])_proj\.([^.]+_quantizer)(\..+)?$", key) + if qkv_match: + suffix = qkv_match.group(4) or "" + group_key = qkv_match.group(1) + "qkv_proj." + qkv_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is an expert gate/up projection + if "mixer" not in key: + expert_gate_up_match = re.search( + r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key + ) + if expert_gate_up_match: + suffix = expert_gate_up_match.group(4) or "" + group_key = ( + expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + suffix + ) + return ("group", group_key, value) + + # Check if this is a non-expert gate/up projection that needs merging + if "mixer" not in key and "experts" not in key: + gate_up_match = re.search(r"(.*\.)(gate|up)_proj\.([^.]+_quantizer)(\..+)?$", key) + if gate_up_match: + suffix = gate_up_match.group(4) or "" + group_key = gate_up_match.group(1) + "gate_up_proj." + gate_up_match.group(3) + suffix + return ("group", group_key, value) + + # Check if this is an expert down_proj + if "mixer" not in key: + expert_down_match = re.search( + r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer)(\..+)?$", key + ) + if expert_down_match: + suffix = expert_down_match.group(3) or "" + group_key = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) + suffix + return ("group", group_key, value) + + # Transform bmm_quantizer keys: self_attn.q/k/v_bmm_quantizer -> self_attn.attn.q/k/v_bmm_quantizer + bmm_match = re.search(r"(.*\.self_attn)\.([qkv]_bmm_quantizer.*)$", key) + if bmm_match: + new_key = bmm_match.group(1) + ".attn." + bmm_match.group(2) + return ("copy", new_key, value) + + # Copy other quantizer keys as-is (like o_proj, down_proj) + return ("copy", key, value) + + +def _group_keys_for_vllm( + state_dict: dict[str, Any], +) -> tuple[dict[str, Any], defaultdict[str, list[tuple[str, Any]]]]: + """ + Process state dict and group keys that need merging. 
+ + Returns: + Tuple of (direct_copy_dict, merge_groups) + """ + vllm_state_dict = {} + merge_groups = defaultdict(list) + + for key, value in state_dict.items(): + action, new_key, new_value = _convert_key_for_vllm(key, value) + if new_key is None or new_value is None: + assert action == "skip", ( + f"Expected action to be 'skip' for key {key}, value {value}, got {action}" + ) + continue + if action == "copy": + vllm_state_dict[new_key] = new_value + elif action == "group": + merge_groups[new_key].append((key, new_value)) + # action == "skip" does nothing + + return vllm_state_dict, merge_groups + + +def _merge_values_by_max_or_concat(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any: + """ + Merge values by taking max for amax, concatenating for others. + Used for quantizer state weights (tensor values). + """ + values = [value for _, value in key_value_pairs] + + # Check if values are dicts (OrderedDict) containing tensors + if isinstance(values[0], dict): + merged_value = {} + for dict_key in values[0]: + tensors = [v[dict_key] for v in values] + if "_amax" in dict_key: + merged_value[dict_key] = torch.stack(tensors).max(dim=0)[0] + else: + merged_value[dict_key] = torch.cat(tensors, dim=0) + return merged_value + else: + # Values are tensors directly + if "_amax" in merged_key: + merged_value = torch.stack(values).max(dim=0)[0] + else: + merged_value = torch.cat(values, dim=0) + return merged_value + + +def _merge_values_require_identical(merged_key: str, key_value_pairs: list[tuple[str, Any]]) -> Any: + """ + Merge values by requiring all values to be identical. + Used for quantizer state (config/metadata). + """ + keys = [k for k, _ in key_value_pairs] + values = [v for _, v in key_value_pairs] + first_value = values[0] + + for i, val in enumerate(values[1:], start=1): + if not _values_equal(val, first_value): + raise ValueError( + f"Cannot merge keys into '{merged_key}': values differ.\n" + f" '{keys[0]}' has value: {first_value}\n" + f" '{keys[i]}' has value: {val}" + ) + return first_value + + +def convert_dict_to_vllm( + state_dict: dict[str, Any], merge_mode: str = "max_or_concat" +) -> dict[str, Any]: + """ + Common implementation for converting quantizer state from HF to vLLM format. + + Args: + state_dict: Input state dict + fuse_experts: Whether to fuse expert projections + merge_mode: Mode to merge grouped values, "max_or_concat" or "require_identical" + """ + vllm_state_dict, merge_groups = _group_keys_for_vllm(state_dict) + + merge_fn = ( + _merge_values_require_identical + if merge_mode == "require_identical" + else _merge_values_by_max_or_concat + ) + + # Merge grouped values + for merged_key, key_value_pairs in merge_groups.items(): + if len(key_value_pairs) > 1: + merged_value = merge_fn(merged_key, key_value_pairs) + vllm_state_dict[merged_key] = merged_value + else: + # Single key, just rename it + _, value = key_value_pairs[0] + vllm_state_dict[merged_key] = value + + return vllm_state_dict + + +def convert_modelopt_state_to_vllm(modelopt_state: dict[str, Any]) -> dict[str, Any]: + """ + Convert modelopt state from HuggingFace format to vLLM compatible format. + + This function converts the quantizer state from HuggingFace format to vLLM compatible format. 
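+ Keys inside the nested `quantizer_state` metadata are renamed and merged (q/k/v projections into qkv, gate/up into gate_up, per-expert entries into the fused w13/w2 names) so they line up with vLLM's fused modules; merged entries are required to be identical across the source keys.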
+ + Args: + modelopt_state: HuggingFace modelopt state dict + + Returns: + vLLM compatible modelopt state dict + """ + modelopt_state_dict = modelopt_state.pop("modelopt_state_dict", []) + for idx, current_mode in enumerate(modelopt_state_dict): + current_mode_metadata = current_mode[1].pop("metadata", {}) + current_mode_quant_state = current_mode_metadata.pop("quantizer_state", {}) + if current_mode_quant_state: + current_mode_metadata["quantizer_state"] = convert_dict_to_vllm( + current_mode_quant_state, merge_mode="require_identical" + ) + else: + current_mode_metadata.pop("quantizer_state", None) + current_mode[1]["metadata"] = current_mode_metadata + modelopt_state_dict[idx] = (current_mode[0], current_mode[1]) + modelopt_state["modelopt_state_dict"] = modelopt_state_dict + return modelopt_state + + +def process_state_dict_for_tp(saved_qstate_dict, current_state_dict): + """Shard quantizer tensors for tensor parallelism by matching expected shapes.""" + tp_group = get_tp_group() + tp_rank = tp_group.rank_in_group + tp_world_size = tp_group.world_size + + result = {} + for key, value in saved_qstate_dict.items(): + if key in current_state_dict: + expected_shape = current_state_dict[key].shape + if value.shape != expected_shape: + # Find the dimension that was tensor-parallel sharded. + # We expect exactly one dimension to satisfy: + # checkpoint_dim == expected_dim * tp_world_size + shard_dims = [ + d + for d in range(len(expected_shape)) + if value.shape[d] == expected_shape[d] * tp_world_size + ] + if len(shard_dims) != 1: + raise ValueError( + f"Cannot infer TP shard dim for {key}: " + f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value.shape)}, " + ) + + shard_dim = shard_dims[0] + shard_size = expected_shape[shard_dim] + start = tp_rank * shard_size + end = start + shard_size + if end > value.shape[shard_dim]: + raise ValueError( + f"TP shard out of bounds for {key}: " + f"expected_shape={tuple(expected_shape)}, checkpoint_shape={tuple(value.shape)})" + ) + value = value.narrow(shard_dim, start, shard_size).contiguous() + result[key] = value + + return result diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py index 25483f2be..c32593005 100644 --- a/examples/vllm_serve/vllm_serve_fakequant.py +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -74,8 +74,10 @@ "QUANT_DATASET", "QUANT_CALIB_SIZE", "QUANT_CFG", - "AMAX_FILE_PATH", + "QUANT_FILE_PATH", "KV_QUANT_CFG", + "MODELOPT_STATE_PATH", + "CALIB_BATCH_SIZE", } RayDistributedExecutor.ADDITIONAL_ENV_VARS.update(additional_env_vars) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index 54987b40c..b30e7530d 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -15,16 +15,37 @@ """Export HuggingFace model to vLLM fakequant checkpoint.""" from pathlib import Path +from typing import Any import torch import torch.nn as nn -from modelopt.torch.export.layer_utils import is_quantlinear +import modelopt.torch.opt as mto +from modelopt.torch.export.layer_utils import is_attention, is_quantlinear from modelopt.torch.quantization.utils import get_quantizer_state_dict __all__ = ["export_hf_vllm_fq_checkpoint"] +def cleanup_for_torch_save(x: Any) -> Any: + """Drop callables / local closures (e.g. `.new_forward`) before torch.save. + + ModelOpt stored state dict may contain local closures like `.new_forward` + which are not picklable. 
So we need to cleanup the state dict before saving. + """ + if isinstance(x, dict): + return { + k: cleanup_for_torch_save(v) + for k, v in x.items() + if not callable(v) and "" not in str(getattr(v, "__qualname__", "")) + } + if isinstance(x, list): + return [cleanup_for_torch_save(v) for v in x] + if isinstance(x, tuple): + return tuple(cleanup_for_torch_save(v) for v in x) + return x + + def export_hf_vllm_fq_checkpoint( model: nn.Module, export_dir: Path | str, @@ -44,12 +65,12 @@ def export_hf_vllm_fq_checkpoint( export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) - amax_dict = { - name + "._amax": param["_amax"].detach().clone().cpu() - for name, param in get_quantizer_state_dict(model).items() - if "_amax" in param - } + quantizer_state_dict = get_quantizer_state_dict(model) + modelopt_state = mto.modelopt_state(model) + modelopt_state = cleanup_for_torch_save(modelopt_state) + modelopt_state["modelopt_state_weights"] = cleanup_for_torch_save(quantizer_state_dict) + torch.save(modelopt_state, export_dir / "vllm_fq_modelopt_state.pth") # remove quantizer from model for _, module in model.named_modules(): if is_quantlinear(module): @@ -57,6 +78,15 @@ def export_hf_vllm_fq_checkpoint( if hasattr(module, attr): delattr(module, attr) module.export() - torch.save(amax_dict, f"{export_dir}/quant_amax.pth") + if is_attention(module): + for attr in [ + "q_bmm_quantizer", + "k_bmm_quantizer", + "v_bmm_quantizer", + "softmax_quantizer", + ]: + if hasattr(module, attr): + delattr(module, attr) + # Save model model.save_pretrained(export_dir, state_dict=model.state_dict(), save_modelopt_state=False) diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 95b194c3f..0549e2e9d 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -22,6 +22,7 @@ from modelopt.torch.export.model_config import QUANTIZATION_NONE from modelopt.torch.export.unified_export_megatron import GPTModelExporter +from modelopt.torch.quantization.utils import get_quantizer_state_dict __all__ = ["export_mcore_gpt_to_hf_vllm_fq"] @@ -38,8 +39,8 @@ def gather_mcore_vllm_fq_quantized_state_dict( Returns: The state dictionary of the module without quantized state. 
""" - amax_state_dict = { - k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax") + quantizer_state_dict = { + k: v.detach().clone().cpu() for k, v in state_dict.items() if "quantizer" in k } # Gather all amax dicts to rank 0 @@ -48,20 +49,19 @@ def gather_mcore_vllm_fq_quantized_state_dict( if rank == 0: # Rank 0 will collect all amax values - all_amax_dicts = [None] * world_size - torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0) + all_quantizer_state_dicts = [None] * world_size + torch.distributed.gather_object(quantizer_state_dict, all_quantizer_state_dicts, dst=0) - # Merge all amax dicts into one - merged_amax_dict = {} - for amax_dict in all_amax_dicts: - if amax_dict is not None: - merged_amax_dict.update(amax_dict) + # Merge all quantizer state dicts into one + merged_quantizer_state_dict = {} + for quantizer_state_dict in all_quantizer_state_dicts: + if quantizer_state_dict is not None: + merged_quantizer_state_dict.update(quantizer_state_dict) - print(f"Total amax entries from all ranks: {len(merged_amax_dict.keys())}") - torch.save(merged_amax_dict, save_directory + "/quant_amax.pth") + torch.save(merged_quantizer_state_dict, save_directory + "/quantizer_state.pth") else: # Other ranks just send their amax values - torch.distributed.gather_object(amax_state_dict, None, dst=0) + torch.distributed.gather_object(quantizer_state_dict, None, dst=0) torch.distributed.barrier() @@ -76,6 +76,13 @@ def save_pretrained( ): os.makedirs(save_directory, exist_ok=True) gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory) + + # NOTE: `self.state_dict` is an OrderedDict; mutating it while iterating + # over its keys raises "OrderedDict mutated during iteration". + keys_to_remove = [k for k in self.state_dict if "quantizer" in k] + for k in keys_to_remove: + self.state_dict.pop(k, None) + assert not (self.is_multimodal and pretrained_model_name_or_path is not None), ( "Exporting weights in bf16 and amax values is not supported for multimodal models " "when pretrained_model_name_or_path is not None" @@ -88,6 +95,37 @@ def save_pretrained( def _get_quantization_format(self, module: torch.nn.Module): return QUANTIZATION_NONE + def _get_quantized_state( + self, + module: torch.nn.Module, + dtype: torch.dtype = torch.float16, + ) -> tuple[dict[str, torch.Tensor], str, int]: + """Return a state_dict, quantization format, and block_size of the module. + + Args: + module: The target module to perform real quantization. + dtype: The default data type. + + Returns: + Tuple: state_dict, quantization format, and block_size of the module. + """ + name_to_value = {} + qformat: str = self._get_quantization_format(module) + block_size = 0 + + if hasattr(module, "weight") and module.weight is not None: + weight = module.weight.to(dtype).cpu() + name_to_value["weight"] = weight + else: + return name_to_value, qformat, block_size + + if hasattr(module, "bias") and module.bias is not None: + name_to_value["bias"] = module.bias.to(dtype).cpu() + for name, param in get_quantizer_state_dict(module).items(): + for key, value in param.items(): + name_to_value[name + "." 
+ key] = value.to(dtype).cpu() + return name_to_value, qformat, block_size + def export_mcore_gpt_to_hf_vllm_fq( model: torch.nn.Module, diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index c93ea546f..fce232eb6 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -130,7 +130,7 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata: for name, module in model.named_modules(): if isinstance(module, QuantModule): name = get_unwrapped_name(name, model) - module.modelopt_post_restore(name) + module.modelopt_post_restore(name, model=model) return model diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 12aaee3f8..0037fa062 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -37,7 +37,7 @@ class QuantModule(DynamicModule): """A base class for quantized modules.""" - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", model: "torch.nn.Module | None" = None): """Post-restore to correctly configure the TensorQuantizer states. TensorQuantizer states are restored to their shape before saving. Now we need to further configure them. @@ -46,6 +46,10 @@ def modelopt_post_restore(self, prefix: str = ""): 2. For sharded modules the restored states of TensorQuantizer could be incorrect. This is because parallelism such as TP might have been changed between saving and resoring. So we need to re-calculate the state shapes. Hence such modules should override this and implement their own logic. + + Args: + prefix: The module name prefix for error messages. + model: Optional main model to search for device if not found in this module. """ # Get a parameter or buffer that does not belong to a TensorQuantizer non_tq_param_or_buffer = None @@ -55,6 +59,21 @@ def modelopt_post_restore(self, prefix: str = ""): non_tq_param_or_buffer = param_or_buffer break + # If not found (e.g., container modules like vLLM's attn that only have child quantizers), + # traverse up to parent's parent to find a module with parameters + if non_tq_param_or_buffer is None and model is not None: + parts = prefix.split(".") + parent_prefix = ".".join(parts[: len(parts) - 1]) + parent_module = model.get_submodule(parent_prefix) if parent_prefix else model + # Look for any parameter in parent module (not just state_dict) + for name, param in parent_module.named_parameters(): + # Skip params that belong to TensorQuantizer submodules + param_parent_name = name.rsplit(".", 1)[0] if "." in name else "" + param_parent = parent_module.get_submodule(param_parent_name) + if not isinstance(param_parent, TensorQuantizer): + non_tq_param_or_buffer = param + break + if non_tq_param_or_buffer is None: warnings.warn( f"Could not identify the device for TensorQuantizer states of {prefix}. " diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py index 4200aadc7..09a91796d 100644 --- a/modelopt/torch/quantization/plugins/custom.py +++ b/modelopt/torch/quantization/plugins/custom.py @@ -114,7 +114,7 @@ def _setup(self): # the dtype can change later. 
self.original_weight_dtype = None if self.weight is None else self.weight.dtype - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index f47451eb0..630b6e4ce 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -471,7 +471,7 @@ def _get_shard_axis_dict(self, state_dict): shard_axis_dict[k] = self._scale_tensor_shard_axis return shard_axis_dict - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): """Post restore to correctly configure the realquant scales. ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their @@ -715,7 +715,7 @@ def forward(self, query, key, value, *args, **kwargs): value = self.v_bmm_quantizer(value) return super().forward(query, key, value, *args, **kwargs) - def modelopt_post_restore(self, name=""): + def modelopt_post_restore(self, name="", *args, **kwargs): """Restore quantizer states after model loading.""" for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]: # TODO: Add support for non-scalar states such as diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index cec7ff956..70e6e7e2c 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -141,7 +141,7 @@ def _setup(self): # TODO: GroupedLinear supports weights split by `num_gemms`, to support quantization # with static parameters beyond per-tensor, we need to support a unique quantizer for each gemm. - def modelopt_post_restore(self, prefix: str = ""): + def modelopt_post_restore(self, prefix: str = "", *args, **kwargs): # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning # self.weight0 to self.weight to run the quantizer states initialization. 
diff --git a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py index a156ad126..2c2c44ef2 100644 --- a/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py +++ b/tests/gpu/torch/export/test_vllm_fakequant_hf_export.py @@ -48,7 +48,7 @@ def forward_loop(model): model(input_ids) model = mtq.quantize(model, quant_cfg, forward_loop) - + quantizer_state_dict_before = mtq.utils.get_quantizer_state_dict(model) model_state_dict = deepcopy(model.state_dict()) # Export directory @@ -59,8 +59,10 @@ def forward_loop(model): export_hf_vllm_fq_checkpoint(model, export_dir=export_dir) # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quant_amax.pth" - assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + modelopt_state_file = export_dir / "vllm_fq_modelopt_state.pth" + assert modelopt_state_file.exists(), ( + f"vllm_fq_modelopt_state.pth file should be created in {export_dir}" + ) # make sure hf_quant_config.json file does not exist hf_quant_config_file = export_dir / "hf_quant_config.json" @@ -73,21 +75,19 @@ def forward_loop(model): model_after = model_after.cuda() model_after.eval() model_after_state_dict = model_after.state_dict() - amax_state_dict = {} for key, param in model_state_dict.items(): - if key.endswith("_amax"): - amax_state_dict[key] = param - continue - - assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( - f"Weight mismatch for {key}: " - f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " - f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" - ) - - # Verify amax values are correct - amax_dict = torch.load(quant_amax_file) - assert len(amax_dict) > 0, "amax_dict should not be empty" - assert amax_dict.keys() == amax_state_dict.keys(), ( - "amax keys mismatch between before and after export" + if "quantizer" not in key: + assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( + f"Weight mismatch for {key}: " + f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " + f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" + ) + + # Verify quantizer state dict values are correct + quantizer_state_dict = torch.load(modelopt_state_file)["modelopt_state_weights"] + assert len(quantizer_state_dict) > 0, ( + f"modelopt_state_weights should not be empty in {modelopt_state_file}" + ) + assert quantizer_state_dict.keys() == quantizer_state_dict_before.keys(), ( + "quantizer state dict keys mismatch between before and after export" ) diff --git a/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py b/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py index ea351db6a..2fff24ea1 100644 --- a/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py +++ b/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py @@ -99,8 +99,8 @@ def forward_loop(model): ) # check if quant_amax.pth file exists - quant_amax_file = export_dir / "quant_amax.pth" - assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + quant_amax_file = export_dir / "quantizer_state.pth" + assert quant_amax_file.exists(), f"quantizer_state.pth file should be created in {export_dir}" # make sure hf_quant_config.json file does not exist hf_quant_config_file = export_dir / "hf_quant_config.json"