From 5ee2c2d02ab246a662725824484f776efe934777 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 22 Oct 2025 22:32:07 -0400 Subject: [PATCH 01/52] init moe support Signed-off-by: yiliu30 --- auto_round/experimental/vllm_ext/__init__.py | 16 + .../experimental/vllm_ext/auto_round_ext.py | 49 ++ auto_round/experimental/vllm_ext/envs_ext.py | 32 + auto_round/experimental/vllm_ext/fp4_utils.py | 125 ++++ .../experimental/vllm_ext/moe_impl_mxfp4.py | 556 ++++++++++++++++++ .../experimental/vllm_ext/mxfp4_qdq_utils.py | 218 +++++++ .../experimental/vllm_ext/quant_method_moe.py | 72 +++ .../experimental/vllm_ext/sitecustomize.py | 25 + auto_round/experimental/vllm_ext/utils.py | 81 +++ 9 files changed, 1174 insertions(+) create mode 100644 auto_round/experimental/vllm_ext/__init__.py create mode 100644 auto_round/experimental/vllm_ext/auto_round_ext.py create mode 100644 auto_round/experimental/vllm_ext/envs_ext.py create mode 100644 auto_round/experimental/vllm_ext/fp4_utils.py create mode 100644 auto_round/experimental/vllm_ext/moe_impl_mxfp4.py create mode 100644 auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py create mode 100644 auto_round/experimental/vllm_ext/quant_method_moe.py create mode 100644 auto_round/experimental/vllm_ext/sitecustomize.py create mode 100644 auto_round/experimental/vllm_ext/utils.py diff --git a/auto_round/experimental/vllm_ext/__init__.py b/auto_round/experimental/vllm_ext/__init__.py new file mode 100644 index 000000000..aaadf910c --- /dev/null +++ b/auto_round/experimental/vllm_ext/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +# ==---------------------------------------------------------------------------== +# Apply the extension +# ==---------------------------------------------------------------------------== + + +def apply(): + import vllm.model_executor.layers.quantization.auto_round as auto_round_module + + from .auto_round_ext import AutoRoundExtensionConfig + + auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig + from .envs_ext import extra_environment_variables diff --git a/auto_round/experimental/vllm_ext/auto_round_ext.py b/auto_round/experimental/vllm_ext/auto_round_ext.py new file mode 100644 index 000000000..cc75eb3b4 --- /dev/null +++ b/auto_round/experimental/vllm_ext/auto_round_ext.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig + +from auto_round.schemes import QuantizationScheme + +from .quant_method_moe import AutoRoundMoEMethod + +logger = init_logger(__name__) + + +class AutoRoundExtensionConfig(AutoRoundConfig): + SUPPORTED_DTYPES = AutoRoundConfig.SUPPORTED_DTYPES.union({"mx_fp"}) + SUPPORTED_FORMATS = AutoRoundConfig.SUPPORTED_FORMATS.union({"auto_round:llm_compressor"}) + + def get_quant_method(self, layer: torch.nn.Module, prefix: str): + # FIXME: (yi) make it compatible with `AutoRoundConfig` + if isinstance(layer, FusedMoE): + quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix) + return quant_method + + return super().get_quant_method(layer, prefix) + + @staticmethod + def _parse_quant_scheme(config: dict): + quant_scheme_attrs = QuantizationScheme.get_attributes() + filter_config = {key: value for key, value in 
config.items() if key in quant_scheme_attrs} + quant_scheme = QuantizationScheme.from_dict(filter_config) + return quant_scheme + + @classmethod + def from_config(cls, config: dict[str, Any]) -> AutoRoundConfig: + ar_config = super().from_config(config) + # TODO: (yi) refine below implementation + quant_scheme = AutoRoundExtensionConfig._parse_quant_scheme(config) + layer_schemes = {} + layer_schemes = {} # ensure dict + extra_config = getattr(ar_config, "extra_config", None) + if extra_config is not None: + for layer_name, layer_config in extra_config.items(): + layer_schemes[layer_name] = AutoRoundExtensionConfig._parse_quant_scheme(layer_config) + ar_config.quant_scheme = quant_scheme + ar_config.layer_schemes = layer_schemes + return ar_config diff --git a/auto_round/experimental/vllm_ext/envs_ext.py b/auto_round/experimental/vllm_ext/envs_ext.py new file mode 100644 index 000000000..964272ea1 --- /dev/null +++ b/auto_round/experimental/vllm_ext/envs_ext.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from typing import Any, Callable + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# Define extra environment variables +extra_environment_variables: dict[str, Callable[[], Any]] = { + "VLLM_AR_MXFP8_DISABLE_INPUT_QDQ": lambda: os.getenv("VLLM_AR_MXFP8_DISABLE_INPUT_QDQ", "0") + in ("1", "true", "True"), + "VLLM_AR_MXFP4_DISABLE_INPUT_QDQ": lambda: os.getenv("VLLM_AR_MXFP4_DISABLE_INPUT_QDQ", "0") + in ("1", "true", "True"), + "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "0") in ("1", "true", "True"), + "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "0") in ("1", "true", "True"), + "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "1") in ("1", "true", "True"), +} +# Add the extra environment variables to vllm.envs +import vllm.envs as envs +from vllm.envs import environment_variables + +# Merge the environment variables +all_environment_variables = {**environment_variables, **extra_environment_variables} + + +for name, value_fn in extra_environment_variables.items(): + setattr(envs, name, value_fn()) + +logger.warning_once(f"Added extra environment variables: {list(extra_environment_variables.keys())}") diff --git a/auto_round/experimental/vllm_ext/fp4_utils.py b/auto_round/experimental/vllm_ext/fp4_utils.py new file mode 100644 index 000000000..26e97e633 --- /dev/null +++ b/auto_round/experimental/vllm_ext/fp4_utils.py @@ -0,0 +1,125 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
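+# MXFP4 element format (E2M1): 1 sign bit, 2 exponent bits and 1 mantissa bit,
+# giving the eight magnitudes 0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0 and 6.0 listed in
+# FLOAT_TO_E2M1 below. Each value is stored as a 3-bit magnitude index plus the
+# sign in bit 3, and two such 4-bit codes are packed per uint8 with the first
+# element of a pair in the low nibble. For example, the pair (1.5, -4.0) encodes
+# as nibbles 0x3 and 0xE (0x6 with the 0x8 sign bit), which pack_fp4_to_uint8
+# packs into the single byte 0xE3.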
+ +from typing import Optional + +import torch + +FLOAT_TO_E2M1 = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, +] + +# Module-level device tensor cache +_DEVICE_E2M1_TENSORS = {} + +# Constants for FP4 values (E2M1 format) +_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + + +def get_e2m1_tensor(device): + """Get device-specific E2M1 lookup tensor, creating it if needed.""" + device_str = str(device) + if device_str not in _DEVICE_E2M1_TENSORS: + _DEVICE_E2M1_TENSORS[device_str] = torch.tensor(_E2M1_VALUES, dtype=torch.float32, device=device) + return _DEVICE_E2M1_TENSORS[device_str] + + +def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: + m, n = x.shape + device = x.device + + # Create lookup table for FP4 values to indices + # Map the absolute values to 0-7 indices + kE2M1 = get_e2m1_tensor(x.device) + + # Find closest valid FP4 value index for each element + abs_x = torch.abs(x) + abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8] + abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n] + + # Apply sign bit (bit 3) to get final 4-bit representation + indices = abs_indices + (torch.signbit(x).to(torch.long) << 3) + + # Reshape to prepare for packing pairs of values + indices = indices.reshape(-1) + + # Handle odd length by padding if necessary + if indices.numel() % 2 != 0: + indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)]) + + # Reshape to pair consecutive elements + indices = indices.reshape(-1, 2) + + # Pack pairs of 4-bit values into 8-bit values + packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8) + + return packed.reshape(m, n // 2) + + +def unpack_fp4_from_uint8( + a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 +) -> torch.Tensor: + """ + Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values + (i.e. first four bits correspond to one fp4 value, last four correspond to a + consecutive fp4 value). The bits represent an index, which are mapped to an fp4 + value. 
+ + :param a: tensor to unpack + :param m: original dim 0 size of the unpacked tensor + :param n: original dim 1 size of the unpacked tensor + :param dtype: dense dtype to cast the unpacked tensor to + """ + assert a.dtype == torch.uint8, f"expected uint8, got {a.dtype}" + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = get_e2m1_tensor(a.device) + + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n).to(dtype=dtype) + + +def cast_to_fp4(x): + sign = torch.sign(x) + x = torch.abs(x) + x[(x >= 0.0) & (x <= 0.25)] = 0.0 + x[(x > 0.25) & (x < 0.75)] = 0.5 + x[(x >= 0.75) & (x <= 1.25)] = 1.0 + x[(x > 1.25) & (x < 1.75)] = 1.5 + x[(x >= 1.75) & (x <= 2.5)] = 2.0 + x[(x > 2.5) & (x < 3.5)] = 3.0 + x[(x >= 3.5) & (x <= 5.0)] = 4.0 + x[x > 5.0] = 6.0 + return x * sign diff --git a/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py b/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py new file mode 100644 index 000000000..426fc535e --- /dev/null +++ b/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py @@ -0,0 +1,556 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# ==-------------------------------------------------------------------------== +# MOE MXFP4 +# ==------------------------------------------------------------------------== + + +from typing import Callable, Optional, Union + +import torch +import torch.nn.functional as F +import vllm.envs as envs +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) +from .quant_method_moe import AutoRoundMoEMethod + + +def apply_act(local_w1_out: torch.Tensor, local_w3_out: torch.Tensor, activation: str) -> torch.Tensor: + if activation == "silu": + act_fn = F.silu + w13_out = act_fn(local_w1_out) * local_w3_out + elif activation == "swigluoai": + limit = 7.0 + alpha = 1.702 + local_w1_out = local_w1_out.clamp(min=None, max=limit) + local_w3_out = local_w3_out.clamp(min=-limit, max=limit) + glu = (local_w1_out) * F.sigmoid(local_w1_out * alpha) + w13_out = (local_w3_out + 1) * glu + else: + raise NotImplementedError(f"Activation {activation} is not implemented.") + return w13_out + + +class AutoRoundMoEMethodMXFp4Impl(AutoRoundMoEMethod): + def __init__( + self, + quant_config: "AutoRoundConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, + ): + super().__init__(moe) + self.use_marlin = False + self.group_size = 32 + self.quant_config = quant_config + self.has_bias = self.moe.has_bias + self.mask_weights_buffer = None + self.experts_mask_buffer = None + self.num_all_tokens_threshold = 16 * 1024 + # self.num_all_tokens_threshold = 0 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: 
torch.dtype, + **extra_weight_attrs, + ): + layer.num_experts = num_experts + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + E = num_experts + H = hidden_size + IN = intermediate_size_per_partition + if self.has_bias: + # TODO: yiliu30 use the dtype in CK + bias_dtype = torch.bfloat16 + w13_bias = torch.nn.Parameter(torch.zeros(E, 2 * IN, dtype=bias_dtype), requires_grad=False) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=bias_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + if envs.VLLM_AR_MXFP4_MODULAR_MOE: + from vllm.model_executor.layers.fused_moe.config import ( + ocp_mx_moe_quant_config, + ) + + self.input_dtype = "mxfp4" + self.weight_dtype = "mxfp4" + return ocp_mx_moe_quant_config( + quant_dtype=self.input_dtype, + weight_dtype=self.weight_dtype, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + w1_bias=layer.w13_bias if self.has_bias else None, + w2_bias=layer.w2_bias if self.has_bias else None, + block_shape=None, + ) + return None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if envs.VLLM_ENABLE_STATIC_MOE: + if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: + weight_name_lst = ["w13_weight", "w2_weight"] + from .mxfp4_qdq_utils import dequant_mxfp4_to_fp8 + + for weight_name_prefix in weight_name_lst: + weight_name = f"{weight_name_prefix}_packed" + weight = getattr(layer, weight_name) + weight_scale_name = f"{weight_name_prefix}_scale" + weight_scale = getattr(layer, weight_scale_name) + new_weight_name = f"{weight_name_prefix}_unpacked" + new_scale_name = weight_scale_name + num_experts, _, _ = weight.shape 
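+                    # Each expert's packed MXFP4 weight is dequantized into an
+                    # fp8 (e4m3fn) data tensor plus a bf16 per-block scale; the
+                    # packed parameters are dropped and the results are
+                    # re-registered as "<prefix>_unpacked" and "<prefix>_scale".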
+ unpacked_weight_lst = [] + scale_list = [] + for expert_index in range(num_experts): + weight_fp8, scale_bf16 = dequant_mxfp4_to_fp8( + data_lp=weight[expert_index], + scale_e8m0=weight_scale[expert_index], + ) + + unpacked_weight_lst.append(weight_fp8) + scale_list.append(scale_bf16) + unpacked_weight_fp8 = torch.stack(unpacked_weight_lst, dim=0) + scale_bf16 = torch.stack(scale_list, dim=0) + assert unpacked_weight_fp8.shape[0] == num_experts, ( + f"Expected {num_experts} unpacked weights, got " f"{unpacked_weight_fp8.shape[0]}" + ) + delattr(layer, weight_name) + delattr(layer, weight_scale_name) + layer.register_parameter( + new_weight_name, + torch.nn.Parameter( + unpacked_weight_fp8, + requires_grad=False, + ), + ) + layer.register_parameter( + new_scale_name, + torch.nn.Parameter( + scale_bf16, + requires_grad=False, + ), + ) + + elif envs.VLLM_AR_MXFP4_MODULAR_MOE: + + def revert_interleaved_bias(bias): + """ + Convert from blocked bias format to interleaved format. + + Args: + bias: Tensor of shape [E, intermediate_size*2] where the first half contains + w1 biases and second half contains w3 biases for each expert + + Returns: + Tensor with interleaved w1 and w3 biases + """ + # loaded bias[0]: [e0_w1_bias_0, e0_w1_bias_1, ..., e0_w3_bias_0, e0_w3_bias_1...] + # Expected bias[0]: [e0_w1_bias_0, e0_w3_bias_0, e0_w1_bias_1, e0_w3_bias_1...] + # The expected bias order is used in triton kernel https://github.com/vllm-project/vllm/pull/22508 + revert_bias = torch.zeros_like(bias, device=bias.device) + E, two_IN = bias.shape + + # Verify the shape is as expected + if two_IN % 2 != 0: + raise ValueError(f"Expected even number of bias elements, got {two_IN}") + + revert_bias[..., ::2] = bias[..., : two_IN // 2] + revert_bias[..., 1::2] = bias[..., two_IN // 2 :] + + return revert_bias + + # breakpoint() + w13_bias_swapped = revert_interleaved_bias(layer.w13_bias) + layer.w13_bias.data.copy_(w13_bias_swapped) + + if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: + import vllm.model_executor.layers.quantization.auto_round_vllm_extension.mxfp4_qdq_utils as mxfp4_utils + + w1 = layer.w13_weight_packed + w1_scale = layer.w13_weight_scale + w1 = mxfp4_utils.to_dtype( + data_lp=w1, + scale_e8m0=w1_scale, + elem_dtype="fp4_e2m1", + block_size=32, + target_dtype=torch.bfloat16, + ) + + def revert_interleaved_w1(w1): + """ + w1[0,:8,:4] + tensor([[ 0.0000, 0.0000, 0.0000, -0.0625], + [ 0.0234, 0.0156, 0.0234, 0.0312], + [-0.0312, 0.0156, 0.0312, -0.0156], + [ 0.0078, 0.0078, 0.0078, 0.0156], + [ 0.0078, 0.0000, 0.0000, 0.0234], + [ 0.0000, -0.0078, -0.0156, -0.0469], + [-0.0469, 0.0078, -0.0625, 0.0234], + [ 0.0156, 0.0000, -0.0312, 0.0156]], device='cuda:0', + dtype=torch.bfloat16) + + + tensor([[ 0.0000, 0.0000, 0.0000, -0.0625], + [ 0.0156, -0.0000, -0.0000, -0.0938], + [ 0.0234, 0.0156, 0.0234, 0.0312], + [ 0.0312, 0.0000, -0.0156, 0.0000], + [-0.0312, 0.0156, 0.0312, -0.0156], + [-0.0234, 0.0469, -0.0312, -0.0312], + [ 0.0078, 0.0078, 0.0078, 0.0156], + [ 0.0469, -0.0312, -0.0469, 0.0625]], device='cuda:0', + dtype=torch.bfloat16) + + """ + new_w1 = torch.zeros_like(w1) + E, N, H = w1.shape + new_w1[:, ::2, :] = w1[:, : N // 2, :] + new_w1[:, 1::2, :] = w1[:, N // 2 :, :] + return new_w1 + + w1 = revert_interleaved_w1(w1) + + w1_scale = None + w2 = layer.w2_weight_packed + w2_scale = layer.w2_weight_scale + w2 = mxfp4_utils.to_dtype( + data_lp=w2, + scale_e8m0=w2_scale, + elem_dtype="fp4_e2m1", + block_size=32, + target_dtype=torch.bfloat16, + ) + w2_scale = None + del layer.w13_weight_packed + del 
layer.w13_weight_scale + del layer.w2_weight_packed + del layer.w2_weight_scale + layer.w13_weight_scale = None + layer.w2_weight_scale = None + layer.register_parameter( + "w13_weight_unpacked", + torch.nn.Parameter( + w1, + requires_grad=False, + ), + ) + layer.register_parameter( + "w2_weight_unpacked", + torch.nn.Parameter( + w2, + requires_grad=False, + ), + ) + + else: + raise NotImplementedError("process_weights_after_loading is not implemented for now.") + + @torch.inference_mode() + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + topk_weights, topk_ids, _ = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + assert self.fused_experts is None + + # There are three implementations: + + if envs.VLLM_AR_MXFP4_MODULAR_MOE: + from vllm.model_executor.layers.fused_moe import fused_experts + + if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: + w1 = layer.w13_weight_unpacked + w2 = layer.w2_weight_unpacked + else: + w1 = layer.w13_weight_packed + w2 = layer.w2_weight_packed + out = fused_experts( + x, + w1, + w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) + return out + + num_all_tokens, hidden_dim = x.shape + num_experts = layer.local_num_experts + total_num_experts = router_logits.size(-1) + if ( + self.mask_weights_buffer is None + or self.mask_weights_buffer.dtype != x.dtype + or self.mask_weights_buffer.device != x.device + or self.mask_weights_buffer.shape[0] < num_all_tokens + or self.mask_weights_buffer.shape[1] < total_num_experts + ): + if num_all_tokens > self.num_all_tokens_threshold: + mask_weights = torch.zeros((num_all_tokens, total_num_experts), dtype=x.dtype, device=x.device) + if self.mask_weights_buffer is None and self.num_all_tokens_threshold != 0: + self.mask_weights_buffer = torch.zeros( + (self.num_all_tokens_threshold, total_num_experts), + dtype=x.dtype, + device=x.device, + ) + self.experts_mask_buffer = torch.zeros( + (self.num_all_tokens_threshold, total_num_experts), + dtype=x.dtype, + device=x.device, + ) + else: + self.mask_weights_buffer = torch.zeros( + (num_all_tokens, total_num_experts), dtype=x.dtype, device=x.device + ) + self.experts_mask_buffer = torch.zeros( + (num_all_tokens, total_num_experts), dtype=x.dtype, device=x.device + ) + mask_weights = self.mask_weights_buffer + 
experts_mask = self.experts_mask_buffer + else: + self.mask_weights_buffer.zero_() + mask_weights = self.mask_weights_buffer + self.experts_mask_buffer.zero_() + experts_mask = self.experts_mask_buffer + + topk_ids = topk_ids.to(torch.int64) + topk_weights = topk_weights.to(x.dtype) + experts_mask.scatter_(-1, topk_ids, topk_weights) + mask_weights.scatter_(-1, topk_ids, 1) + mask_weights = mask_weights[:num_all_tokens, :total_num_experts] + mask_weights = mask_weights.transpose(0, 1) + experts_mask = experts_mask[:num_all_tokens, :total_num_experts] + experts_mask = experts_mask.transpose(0, 1) + # Note: ep_size equal tp_size + ep_rank = get_tensor_model_parallel_rank() if expert_map is not None else 0 + ep_shift = ep_rank * num_experts + + if envs.VLLM_ENABLE_STATIC_MOE and not envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: + num_experts, intermediate_size_per_partition_x2, _ = layer.w13_weight_packed.shape + intermediate_size_per_partition = intermediate_size_per_partition_x2 // 2 + for expert_index in range(num_experts): + mask_weight = mask_weights[expert_index + ep_shift].unsqueeze(1) + current_state_static = x * mask_weight + + local_w13_packed = layer.w13_weight_packed[expert_index] + local_w13_scale = layer.w13_weight_scale[expert_index] + local_w2_packed = layer.w2_weight_packed[expert_index] + local_w2_scale = layer.w2_weight_scale[expert_index] + + local_w1_packed = local_w13_packed[:intermediate_size_per_partition, ...] + local_w1_scale = local_w13_scale[:intermediate_size_per_partition, ...] + + local_w3_packed = local_w13_packed[intermediate_size_per_partition:, ...] + local_w3_scale = local_w13_scale[intermediate_size_per_partition:, ...] + + from .mxfp4_qdq_utils import run_mxfp4_emulations + + local_w1_bias = None + local_w2_bias = None + local_w3_bias = None + if self.has_bias: + local_w13_bias = layer.w13_bias[expert_index] + local_w1_bias = local_w13_bias[:intermediate_size_per_partition] + local_w3_bias = local_w13_bias[intermediate_size_per_partition:] + local_w2_bias = layer.w2_bias[expert_index] + + local_w1_out = run_mxfp4_emulations( + x=current_state_static, + weight=local_w1_packed, + weight_scale=local_w1_scale, + bias=local_w1_bias, + ) + local_w3_out = run_mxfp4_emulations( + x=current_state_static, + weight=local_w3_packed, + weight_scale=local_w3_scale, + bias=local_w3_bias, + ) + + w13_out = apply_act(local_w1_out, local_w3_out, activation) + + local_w2_out = run_mxfp4_emulations( + x=w13_out, + weight=local_w2_packed, + weight_scale=local_w2_scale, + bias=local_w2_bias, + ) + padded_weight = experts_mask[expert_index + ep_shift].unsqueeze(1) + local_w2_out = local_w2_out * padded_weight + if expert_index == 0: + final_hidden_states = local_w2_out + else: + final_hidden_states += local_w2_out + return final_hidden_states + if envs.VLLM_ENABLE_STATIC_MOE and envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: + num_experts, intermediate_size_per_partition_x2, _ = layer.w13_weight_unpacked.shape + intermediate_size_per_partition = intermediate_size_per_partition_x2 // 2 + + for expert_index in range(num_experts): + mask_weight = mask_weights[expert_index + ep_shift].unsqueeze(1) + current_state_static = x * mask_weight + + local_unpacked_w13 = layer.w13_weight_unpacked[expert_index] + local_w13_scale = layer.w13_weight_scale[expert_index] + + local_unpacked_w2 = layer.w2_weight_unpacked[expert_index] + local_w2_scale = layer.w2_weight_scale[expert_index] + + local_unpacked_w1 = local_unpacked_w13[:intermediate_size_per_partition, ...] 
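+                # The per-block scales are stacked the same way as w13 (w1 rows
+                # first, then w3 rows), so the first half of the scale rows
+                # belongs to w1 and the second half to w3.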
+ half_scale = local_w13_scale.shape[0] // 2 + local_w1_scale = local_w13_scale[:half_scale, ...] + local_unpacked_w3 = local_unpacked_w13[intermediate_size_per_partition:, ...] + local_w3_scale = local_w13_scale[half_scale:, ...] + + local_w1_bias = None + local_w2_bias = None + local_w3_bias = None + if self.has_bias: + local_w13_bias = layer.w13_bias[expert_index] + local_w1_bias = local_w13_bias[:intermediate_size_per_partition] + local_w3_bias = local_w13_bias[intermediate_size_per_partition:] + local_w2_bias = layer.w2_bias[expert_index] + + from .mxfp4_qdq_utils import mxfp4_gemm_with_unpacked_weight + + local_w1_out = mxfp4_gemm_with_unpacked_weight( + x=current_state_static, + weight_fp8=local_unpacked_w1, + weight_scale_bf16=local_w1_scale, + bias=local_w1_bias, + ) + local_w3_out = mxfp4_gemm_with_unpacked_weight( + x=current_state_static, + weight_fp8=local_unpacked_w3, + weight_scale_bf16=local_w3_scale, + bias=local_w3_bias, + ) + + w13_out = apply_act(local_w1_out, local_w3_out, activation) + + local_w2_out = mxfp4_gemm_with_unpacked_weight( + x=w13_out, + weight_fp8=local_unpacked_w2, + weight_scale_bf16=local_w2_scale, + bias=local_w2_bias, + ) + + padded_weight = experts_mask[expert_index + ep_shift].unsqueeze(1) + local_w2_out = local_w2_out * padded_weight + if expert_index == 0: + final_hidden_states = local_w2_out + else: + final_hidden_states += local_w2_out + return final_hidden_states + raise NotImplementedError("Not implemented for now.") diff --git a/auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py b/auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py new file mode 100644 index 000000000..d90272c55 --- /dev/null +++ b/auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py @@ -0,0 +1,218 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import torch + +from .fp4_utils import cast_to_fp4, pack_fp4_to_uint8, unpack_fp4_from_uint8 +from .utils import _to_mx_rceil, get_fp_scale + +F4_E2M1_MAX = 6.0 +F32_EXP_BIAS = 127 + +F32_MIN_NORMAL = 2 ** (-F32_EXP_BIAS + 1) + + +def to_mxfp4_rceil( + data_hp: torch.Tensor, + elem_dtype: Union[torch.dtype, str], + block_size: int, +): + + assert data_hp.dtype in ( + torch.bfloat16, + torch.float, + ), f"{data_hp.dtype} is not supported yet" + # TODO(future PR): consider supporting padding + assert data_hp.numel() % block_size == 0, f"data size must be multiple of block_size={block_size}" + assert data_hp.is_contiguous(), f"data must be contiguous, got {data_hp.stride()}" + + # calculate the scale in e8m0 format + + orig_shape = data_hp.shape + data_hp = data_hp.reshape(-1, block_size) + + # find max value of the data + # Note: this only implements the `minimally supported` version of + # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + # section 6.3. + max_abs = torch.amax(torch.abs(data_hp), 1) + + # Add an epsilon to prevent the log2 function call for returning -inf + # where the values are zero. 
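+    # F32_MIN_NORMAL is 2**-126 (the smallest normal float32), so this epsilon
+    # is nonzero only for all-zero blocks and leaves nonzero block scales intact.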
+ eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype) + + # Set X to be the largest power-of-two less than or equal to + # max_abs(v), divided by the largest power of two representable + # in the element data type, and get the mbits at the same time + + max_pos = F4_E2M1_MAX + scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) + + data_lp = data_lp.reshape(orig_shape) + orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2] + data_lp = cast_to_fp4(data_lp) + data_lp = pack_fp4_to_uint8(data_lp) + + scale_e8m0_biased = scale_e8m0_biased.view(torch.uint8) + return scale_e8m0_biased, data_lp + + +def to_dtype( + data_lp, + scale_e8m0, + elem_dtype, + block_size, + target_dtype, + scale_dtype=None, + return_scale=False, +): + orig_shape = data_lp.shape + last_dim = orig_shape[-1] + data_lp = data_lp.reshape(-1, last_dim) + result_shape = orig_shape[:-1] + (last_dim * 2,) + assert data_lp.is_contiguous, f"Data must be contiguous, got {data_lp.stride()}" + + assert elem_dtype == "fp4_e2m1", f"Expected 'fp4_e2m1', got {elem_dtype}" + + m, half_n = data_lp.shape + n = half_n * 2 + data_hp = unpack_fp4_from_uint8(data_lp, m, n, dtype=target_dtype) + + data_hp = data_hp.reshape(-1, block_size) + + if scale_dtype is None: + scale_dtype = target_dtype + s_fp = get_fp_scale(scale_e8m0).reshape(-1, 1).to(scale_dtype) + if return_scale: + return data_hp.reshape(result_shape), s_fp + + data_hp = data_hp * s_fp + data_hp = data_hp.reshape(result_shape) + + return data_hp + + +def run_mxfp4_emulations( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + bias: torch.Tensor | None = None, +): + group_size = 32 + # quantize input to (FP4 and interleaved block scale) + input_scale, x_q = to_mxfp4_rceil( + data_hp=x, + elem_dtype="fp4_e2m1", + block_size=group_size, + ) + + # dequantize input + x_dq = to_dtype( + data_lp=x_q, + scale_e8m0=input_scale, + elem_dtype="fp4_e2m1", + block_size=group_size, + target_dtype=x.dtype, + ) + + # dequantize weight + w_dq = to_dtype( + data_lp=weight, + scale_e8m0=weight_scale, + elem_dtype="fp4_e2m1", + block_size=group_size, + target_dtype=x.dtype, + ) + + # matmul + out = torch.matmul(x_dq, w_dq.t()) + if bias is not None: + out += bias + return out + + +def dequant_mxfp4_to_fp8(data_lp, scale_e8m0): + data_fp8, scale_float = to_dtype( + data_lp=data_lp, + scale_e8m0=scale_e8m0, + elem_dtype="fp4_e2m1", + block_size=32, + target_dtype=torch.float8_e4m3fn, + scale_dtype=torch.bfloat16, + return_scale=True, + ) + return data_fp8, scale_float + + +def mxfp4_fp8_weight_to_bf16(weight_fp8, scale_bf16): + origin_shape = weight_fp8.shape + weight_fp8 = weight_fp8.reshape(-1, 32) + assert weight_fp8.shape[0] == scale_bf16.shape[0], f"shape mismatch: {weight_fp8.shape} vs {scale_bf16.shape}" + dequant_weight_bf16 = weight_fp8.to(torch.bfloat16) * scale_bf16 + dequant_weight_bf16 = dequant_weight_bf16.reshape(origin_shape) + return dequant_weight_bf16 + + +def mxfp4_gemm_with_unpacked_weight(x, weight_fp8, weight_scale_bf16, bias=None): + x = qdq_mxfp4(x) + + # dequantize weight + w_dq = mxfp4_fp8_weight_to_bf16(weight_fp8, weight_scale_bf16) + # matmul + out = torch.matmul(x, w_dq.t()) + if bias is not None: + out += bias + return out + + +def fp4_121_positive(x: torch.Tensor) -> torch.Tensor: + step1 = torch.round(2.0 * x) / 2.0 + step2 = torch.round(x) + step3 = 2.0 * torch.round(x / 2.0) + + mask1 = x < 2.0 + mask2 = x < 4.0 + + return step1 * mask1 + step2 * (~mask1) * mask2 + step3 * (~mask1) * (~mask2) + + +def 
fp4_121_scaled_even_rounding(x: torch.Tensor) -> torch.Tensor: + sign = x.sign() + x_abs = x.abs() + amax_x = x_abs.max(dim=-1, keepdim=True)[0] + scale_tmp = torch.floor(torch.log2(amax_x)) - 2.0 + scale_clamp = torch.clamp(scale_tmp, min=-127, max=127) + scale = torch.pow(2.0, scale_clamp) + + scale = torch.where((0 < scale) * (scale < torch.inf), scale, 1.0) + + x_fp4_abs = fp4_121_positive(x_abs / scale) * scale + return sign * x_fp4_abs + + +# https://github.com/Anonymous1252022/fp4-all-the-way/blob/main/experimental/fp4.py +def qdq_mxfp4( + x: torch.Tensor, +) -> torch.Tensor: + block_size = 32 + shape = x.shape + x = x.reshape(-1, block_size) + + x = fp4_121_scaled_even_rounding(x) + + x = x.reshape(shape) + + return x diff --git a/auto_round/experimental/vllm_ext/quant_method_moe.py b/auto_round/experimental/vllm_ext/quant_method_moe.py new file mode 100644 index 000000000..60206e947 --- /dev/null +++ b/auto_round/experimental/vllm_ext/quant_method_moe.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import ( + FusedMoEConfig, + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig + +from auto_round.schemes import QuantizationScheme + +logger = init_logger(__name__) + + +QMOE_METHODS_DISPATCH_TABLE = {} + + +def _is_mxfp4_w4a4(scheme: QuantizationScheme): + # FIXME: below impl is incomplete + return scheme.bits == 4 and scheme.group_size == 32 + + +class AutoRoundMoEMethod(FusedMoEMethodBase): + def __init__(self, moe: FusedMoEConfig): + super().__init__(moe) + + @staticmethod + def get_moe_method( + quant_config: AutoRoundConfig, + layer: torch.nn.Module, + prefix: str, + ) -> "AutoRoundMoEMethod": + + def get_scheme(quant_config: AutoRoundConfig, prefix: str): + # Check extra_config first + layer_schemes = quant_config.layer_schemes + # FIXME: make more robust + for name, scheme in layer_schemes.items(): + if prefix.startswith(name): + return scheme + # If not found, use default + return quant_config.quant_scheme + + def check_quantized(weight_bits: int) -> bool: + return weight_bits < 16 + + def get_impl(scheme: QuantizationScheme): + if not check_quantized(scheme.bits): + from vllm.model_executor.layers.fused_moe.layer import ( + UnquantizedFusedMoEMethod, + ) + + return UnquantizedFusedMoEMethod(layer.moe_config) + + elif _is_mxfp4_w4a4(scheme): + from .moe_impl_mxfp4 import AutoRoundMoEMethodMXFp4Impl + + return AutoRoundMoEMethodMXFp4Impl(quant_config, layer.moe_config) + + raise ValueError(f"Unsupported FusedMoe scheme: {scheme}") + + layer_scheme = get_scheme(quant_config, prefix) + impl = get_impl(layer_scheme) + logger.debug("Apply %s to %s", impl.__class__.__name__, prefix) + return impl + + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + return self.impl.get_fused_moe_quant_config(layer) diff --git a/auto_round/experimental/vllm_ext/sitecustomize.py b/auto_round/experimental/vllm_ext/sitecustomize.py new file mode 100644 index 000000000..fcb2bc5a8 --- /dev/null +++ b/auto_round/experimental/vllm_ext/sitecustomize.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# 
PYTHONPATH=/home/yliu7/workspace/inc/3rd-party/vllm/vllm/model_executor/layers/quantization/auto_round_vllm_extension/:$PYTHONPATH + +import os + +VLLM_ENABLE_AR_EXT = os.environ.get("VLLM_ENABLE_AR_EXT", "") in [ + "1", + "true", + "True", +] + +if VLLM_ENABLE_AR_EXT: + print("*****************************************************************************") + print(f"* !!! VLLM_ENABLE_AR_EXT is set to {VLLM_ENABLE_AR_EXT}, applying auto_round_vllm_extension *") + print("*****************************************************************************") + from vllm.model_executor.layers.quantization import auto_round_vllm_extension as auto_round_ext + + auto_round_ext.apply() +else: + print("*****************************************************************************") + print( + f"* Sitecustomize is loaded, but VLLM_ENABLE_AR_EXT is set to {VLLM_ENABLE_AR_EXT}, skipping auto_round_vllm_extension *" + ) + print("*****************************************************************************") diff --git a/auto_round/experimental/vllm_ext/utils.py b/auto_round/experimental/vllm_ext/utils.py new file mode 100644 index 000000000..f98eb2325 --- /dev/null +++ b/auto_round/experimental/vllm_ext/utils.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +E8M0_EXPONENT_BIAS = 127 +E8M0_EXPONENT_NAN_VAL = 255 + + +def get_fp_scale(scale_e8m0): + scale_e8m0 = scale_e8m0.view(torch.uint8) + s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS + # TODO(later): it would be nice if there was a way to do the 2^x operation + # in PyTorch without creating a tensor of twos + two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device) + # pow(two, s_offset) can be out of range of floating point formats. + # TODO(later): handle this for float16 if we decide to support float16 + # scales. + s_fp = torch.pow(two, s_offset) + + # If a block exponent was 255, set values of that block to NaN + s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan")) + + return s_fp + + +def _to_mx_rceil( + data_hp: torch.Tensor, + max_abs: torch.Tensor, + max_pos: float, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + A prototype implementation of MXFP scale factor derivation method described in + https://docs.nvidia.com/cuda/cublas/#d-block-quantization + + For Nvidia GPU with Blackwell+ architecture, the scale factor derivation method + could be accelerated by the `cvt.rp.satfinite.ue8m0x2.f32` instruction. + + Args: + data_hp: High precision data. + max_abs: Maximum absolute value for data_hp along specified dimension/block_size. + max_pos: The maximum value of the low precision data type. + + Returns: + exponent: The biased exponent with dtype E8M0 in uint8 container. + data_lp: The targeted low precision data, in high precision container + (requires cast to low precision data type). + """ + descale = max_abs / max_pos + # TODO: nan/inf needs to be set for any value + # of nan/inf in input not just amax. 
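+    # Worked example: for a block with max_abs = 3.0 and max_pos = 6.0,
+    # descale = 0.5 and ceil(log2(0.5)) = -1, so the stored e8m0 exponent is
+    # -1 + 127 = 126. descale_fp below becomes 2**(127 - 126) = 2, the block is
+    # scaled onto the FP4 grid, and dequantization later multiplies by
+    # 2**(126 - 127) = 0.5 to restore the original magnitude.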
+ exponent = torch.where( + torch.isnan(descale), + 0xFF, # Handle biased exponent for nan + # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping + ( + torch.clamp( + torch.ceil(torch.log2(descale)), + min=-E8M0_EXPONENT_BIAS, + max=E8M0_EXPONENT_BIAS, + ) + + E8M0_EXPONENT_BIAS + ).to(torch.uint8), + ) + + descale_fp = torch.where(exponent == 0, 1.0, torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32))) + + # scale and saturated cast the data elements to max of target dtype + data_lp = torch.clamp(data_hp * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos) + return exponent, data_lp From c278f9dde5718215304095facd5741e5ce535f29 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 22 Oct 2025 22:38:24 -0400 Subject: [PATCH 02/52] add test Signed-off-by: yiliu30 --- .../experimental/vllm_ext/tests/conftest.py | 1303 +++++++++++++++++ .../vllm_ext/tests/test_mxfp4_moe.py | 26 + 2 files changed, 1329 insertions(+) create mode 100644 auto_round/experimental/vllm_ext/tests/conftest.py create mode 100644 auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py diff --git a/auto_round/experimental/vllm_ext/tests/conftest.py b/auto_round/experimental/vllm_ext/tests/conftest.py new file mode 100644 index 000000000..a73fe7e74 --- /dev/null +++ b/auto_round/experimental/vllm_ext/tests/conftest.py @@ -0,0 +1,1303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Note: Copied from vllm/tests/conftest.py + +# ruff: noqa + +from tblib import pickling_support + +# Install support for pickling exceptions so that we can nicely propagate +# failures from tests running in a subprocess. +# This should be run before any custom exception subclasses are defined. +pickling_support.install() + +import http.server +import json +import math +import mimetypes +import os +import socket +import tempfile +import threading +import warnings +from collections.abc import Generator, Sequence +from contextlib import nullcontext +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast + +import numpy as np +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from huggingface_hub import snapshot_download +from PIL import Image +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BatchEncoding, + BatchFeature, + PretrainedConfig, +) +from transformers.models.auto.auto_factory import _BaseAutoModelClass +from vllm.config.model import ModelConfig, ModelDType, RunnerOption +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs +from vllm.multimodal.processing import InputProcessingContext +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config + +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * List of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. +TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]]] + +# Allow for tokens to be represented as str's rather than IDs; +# tuple of +# * Token string representations list +# * String +# * Optional list of top sample logprobs for each sampled token +# +# Assumes prompt logprobs were not requested. 
+TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]], list[dict[str, Logprob]]]]] + +# Representation of generated sequence as a tuple of +# * Token ID list +# * String +# * Optional list of top sample logprobs for each sampled token +# * Optional list of top prompt logprobs for each prompt token +# +# Allows prompt logprobs to be requested. +TokensTextLogprobsPromptLogprobs = tuple[ + list[int], + str, + Optional[Union[list[dict[int, float]], SampleLogprobs]], + Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]], +] + + +# from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs + + +from vllm import LLM, SamplingParams +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype +from vllm.connections import global_http_connection +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger +from vllm.logprobs import Logprob +from vllm.multimodal.utils import fetch_image +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams +from vllm.transformers_utils.utils import maybe_model_redirect +from vllm.utils import is_list_of, set_default_torch_num_threads + +logger = init_logger(__name__) + +_TEST_DIR = os.path.dirname(__file__) +_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] +_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] +_SYS_MSG = os.path.join(_TEST_DIR, "system_messages", "sonnet3.5_nov2024.txt") + +_M = TypeVar("_M") + +_PromptMultiModalInput = Union[list[_M], list[list[_M]]] + +PromptImageInput = _PromptMultiModalInput[Image.Image] +PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]] +PromptVideoInput = _PromptMultiModalInput[np.ndarray] + + +def _read_prompts(filename: str) -> list[str]: + with open(filename) as f: + prompts = f.readlines() + return prompts + + +class ImageAssetPrompts(TypedDict): + stop_sign: str + cherry_blossom: str + + +class ImageTestAssets(list[ImageAsset]): + def __init__(self) -> None: + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) + + def prompts(self, prompts: ImageAssetPrompts) -> list[str]: + """ + Convenience method to define the prompt for each test image. + + The order of the returned prompts matches the order of the + assets when iterating through this object. 
+ """ + return [prompts["stop_sign"], prompts["cherry_blossom"]] + + +class VideoAssetPrompts(TypedDict): + baby_reading: str + + +class VideoTestAssets(list[VideoAsset]): + def __init__(self) -> None: + super().__init__( + [ + VideoAsset("baby_reading"), + ] + ) + + def prompts(self, prompts: VideoAssetPrompts) -> list[str]: + return [prompts["baby_reading"]] + + +class AudioAssetPrompts(TypedDict): + mary_had_lamb: str + winning_call: str + + +class AudioTestAssets(list[AudioAsset]): + def __init__(self) -> None: + super().__init__( + [ + AudioAsset("mary_had_lamb"), + AudioAsset("winning_call"), + ] + ) + + def prompts(self, prompts: AudioAssetPrompts) -> list[str]: + return [prompts["mary_had_lamb"], prompts["winning_call"]] + + +IMAGE_ASSETS = ImageTestAssets() +"""Singleton instance of {class}`ImageTestAssets`.""" +VIDEO_ASSETS = VideoTestAssets() +"""Singleton instance of {class}`VideoTestAssets`.""" +AUDIO_ASSETS = AudioTestAssets() +"""Singleton instance of {class}`AudioTestAssets`.""" + + +@pytest.fixture(scope="function", autouse=True) +def cleanup_VLLM_USE_V1(monkeypatch): + """ + The V1 oracle sets "VLLM_USE_V1" during loading. This means + that each invocation of a test change the env variable. + + If we touch "VLLM_USE_V1" with monkeypatch, then any changes + made during the test run by vLLM will be cleaned up. + + This fixture is used by every test. + """ + + # If VLLM_USE_V1 is not set, set then delete. This will + # cause monkeypatch to clean up VLLM_USE_V1 upon exit + # if VLLM modifies the value of envs.VLLM_USE_V1. + if "VLLM_USE_V1" not in os.environ: + monkeypatch.setenv("VLLM_USE_V1", "") + monkeypatch.delenv("VLLM_USE_V1") + + +@pytest.fixture(autouse=True) +def init_test_http_connection(): + # pytest_asyncio may use a different event loop per test + # so we need to make sure the async client is created anew + global_http_connection.reuse_client = False + + +@pytest.fixture +def dist_init(): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield + cleanup_dist_env_and_memory() + + +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + """Allow subdirectories to skip global cleanup by overriding this fixture. + This can provide a ~10x speedup for non-GPU unit tests since they don't need + to initialize torch. 
+ """ + + return not request.node.get_closest_marker("skip_global_cleanup") + + +@pytest.fixture(autouse=True) +def cleanup_fixture(should_do_global_cleanup_after_test: bool): + yield + if should_do_global_cleanup_after_test: + cleanup_dist_env_and_memory() + + +@pytest.fixture(autouse=True) +def dynamo_reset(): + yield + torch._dynamo.reset() + + +@pytest.fixture +def example_prompts() -> list[str]: + prompts = [] + for filename in _TEST_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture +def example_system_message() -> str: + with open(_SYS_MSG) as f: + return f.read() + + +class DecoderPromptType(Enum): + """For encoder/decoder models only.""" + + CUSTOM = 1 + NONE = 2 + EMPTY_STR = 3 + + +@pytest.fixture +def example_long_prompts() -> list[str]: + prompts = [] + for filename in _LONG_PROMPTS: + prompts += _read_prompts(filename) + return prompts + + +@pytest.fixture(scope="session") +def image_assets() -> ImageTestAssets: + return IMAGE_ASSETS + + +@pytest.fixture(scope="session") +def video_assets() -> VideoTestAssets: + return VIDEO_ASSETS + + +@pytest.fixture(scope="session") +def audio_assets() -> AudioTestAssets: + return AUDIO_ASSETS + + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict) +_R = TypeVar("_R") + + +class HfRunner: + def get_default_device(self): + from vllm.platforms import current_platform + + return "cpu" if current_platform.is_cpu() else current_platform.device_type + + def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: + if x is None or isinstance(x, (bool,)): + return x + + if device is None: + device = self.device + + if isinstance(x, dict): + return {k: self.wrap_device(v, device) for k, v in x.items()} + + if hasattr(x, "device") and x.device.type == device: + return x + + return x.to(device) + + def __init__( + self, + model_name: str, + dtype: str = "auto", + *, + model_kwargs: Optional[dict[str, Any]] = None, + trust_remote_code: bool = True, + is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, + skip_tokenizer_init: bool = False, + auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, + # Set this to avoid hanging issue + default_torch_num_threads: Optional[int] = None, + ) -> None: + init_ctx = ( + nullcontext() + if default_torch_num_threads is None + else set_default_torch_num_threads(default_torch_num_threads) + ) + + with init_ctx: + self._init( + model_name=model_name, + dtype=dtype, + model_kwargs=model_kwargs, + trust_remote_code=trust_remote_code, + is_sentence_transformer=is_sentence_transformer, + is_cross_encoder=is_cross_encoder, + skip_tokenizer_init=skip_tokenizer_init, + auto_cls=auto_cls, + ) + + def _init( + self, + model_name: str, + dtype: str = "auto", + *, + model_kwargs: Optional[dict[str, Any]] = None, + trust_remote_code: bool = True, + is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, + skip_tokenizer_init: bool = False, + auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, + ) -> None: + model_name = maybe_model_redirect(model_name) + self.model_name = model_name + + self.config = AutoConfig.from_pretrained( + model_name, + trust_remote_code=trust_remote_code, + ) + self.device = self.get_default_device() + self.dtype = torch_dtype = _get_and_verify_dtype( + self.model_name, + self.config, + dtype=dtype, + is_pooling_model=is_sentence_transformer or is_cross_encoder, + ) + + model_kwargs = model_kwargs if model_kwargs is not None else {} + model_kwargs.setdefault("torch_dtype", torch_dtype) + + 
if is_sentence_transformer: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + + self.model = SentenceTransformer( + model_name, + device=self.device, + model_kwargs=model_kwargs, + trust_remote_code=trust_remote_code, + ) + elif is_cross_encoder: + # Lazy init required for AMD CI + from sentence_transformers import CrossEncoder + + self.model = CrossEncoder( + model_name, + device=self.device, + automodel_args=model_kwargs, + trust_remote_code=trust_remote_code, + ) + else: + model = auto_cls.from_pretrained( + model_name, + trust_remote_code=trust_remote_code, + **model_kwargs, + ) + + # in case some unquantized custom models are not in same dtype + if getattr(model, "quantization_method", None) is None and any( + p.dtype != self.dtype for p in model.parameters() + ): + model = model.to(dtype=self.dtype) + + if ( + getattr(model, "quantization_method", None) != "bitsandbytes" + and len({p.device for p in model.parameters()}) < 2 + ): + model = model.to(device=self.device) + + self.model = model + + if not skip_tokenizer_init: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor # noqa: F401 + + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=trust_remote_code, + ) + if skip_tokenizer_init: + self.tokenizer = self.processor.tokenizer + + def get_inputs( + self, + prompts: Union[list[str], list[list[int]]], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]: + if images is not None: + assert len(prompts) == len(images) + + if videos is not None: + assert len(prompts) == len(videos) + + if audios is not None: + assert len(prompts) == len(audios) + + all_inputs: list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]] = [] + for i, prompt in enumerate(prompts): + if isinstance(prompt, str): + processor_kwargs: dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and (image := images[i]) is not None: + processor_kwargs["images"] = image + if videos is not None and (video := videos[i]) is not None: + processor_kwargs["videos"] = video + if audios is not None and (audio_inputs := audios[i]) is not None: + # HACK - not all processors take sampling_rate; we should + # clean this up in the future. 
+ if len(audio_inputs) == 2: + audio, sr = audio_inputs + processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr + else: + processor_kwargs["audio"] = audio_inputs + + inputs = self.processor(**processor_kwargs) + if isinstance(inputs, BatchFeature): + inputs = inputs.to(dtype=self.dtype) + all_inputs.append(inputs) + else: + # check that prompt is (batched) list of integers (token ids) + if not is_list_of(prompt, typ=int, check="all"): + raise ValueError("Prompt must be a list of ints corresponding to the prompt token ids.") + # check that no multimodal input is provided + if images or videos or audios: + raise ValueError("When providing prompt token ids multimodal inputs are not supported.") + input_dict = { + "input_ids": torch.tensor(prompt, dtype=torch.long).unsqueeze(0), + } + all_inputs.append(input_dict) + + return all_inputs + + def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]: + all_inputs = self.get_inputs(prompts) + embeddings = [] + for inputs in all_inputs: + input_ids = self.wrap_device(inputs)["input_ids"] + embedding = self.model.get_input_embeddings()(input_ids).squeeze(0) + embeddings.append(embedding) + return embeddings + + def classify(self, prompts: list[str]) -> list[str]: + # output is final logits + all_inputs = self.get_inputs(prompts) + outputs = [] + problem_type = getattr(self.config, "problem_type", "") + + for inputs in all_inputs: + output = self.model(**self.wrap_device(inputs)) + if problem_type == "regression": + logits = output.logits[0].tolist() + elif problem_type == "multi_label_classification": + logits = output.logits.sigmoid()[0].tolist() + else: + logits = output.logits.softmax(dim=-1)[0].tolist() + outputs.append(logits) + + return outputs + + def generate( + self, + prompts: Union[list[str], list[list[int]]], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + **kwargs: Any, + ) -> list[tuple[list[list[int]], list[str]]]: + all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + outputs: list[tuple[list[list[int]], list[str]]] = [] + for inputs in all_inputs: + output_ids = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + **kwargs, + ) + output_str = self.processor.batch_decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output_ids = output_ids.cpu().tolist() + outputs.append((output_ids, output_str)) + return outputs + + def generate_greedy( + self, + prompts: Union[list[str], list[list[int]]], + max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + **kwargs: Any, + ) -> list[tuple[list[int], str]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + images=images, + videos=videos, + audios=audios, + **kwargs, + ) + + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] + + def generate_beam_search( + self, + prompts: list[str], + beam_width: int, + max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> list[tuple[list[list[int]], list[str]]]: + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width, + images=images, + videos=videos, + audios=audios, + ) + + for i in 
range(len(outputs)): + output_ids, output_str = outputs[i] + for j in range(len(output_ids)): + output_ids[j] = [x for x in output_ids[j] if x != self.tokenizer.pad_token_id] + outputs[i] = (output_ids, output_str) + return outputs + + def generate_greedy_logprobs( + self, + prompts: list[str], + max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + **kwargs: Any, + ) -> list[list[torch.Tensor]]: + all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + all_logprobs: list[list[torch.Tensor]] = [] + for inputs in all_inputs: + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states) + all_logprobs.append(seq_logprobs) + return all_logprobs + + def _hidden_states_to_seq_logprobs( + self, + hidden_states: tuple[tuple[torch.Tensor, ...], ...], + ) -> list[torch.Tensor]: + output_embeddings = self.model.get_output_embeddings() + + seq_logprobs: list[torch.Tensor] = [] + for _, hidden_state in enumerate(hidden_states): + last_hidden_states = hidden_state[-1][0] + logits = torch.matmul( + last_hidden_states.to( + device=output_embeddings.weight.device, + dtype=output_embeddings.weight.dtype, + ), + output_embeddings.weight.t(), + ) + if getattr(output_embeddings, "bias", None) is not None: + logits += output_embeddings.bias.unsqueeze(0) + logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + seq_logprobs.append(logprobs) + + return seq_logprobs + + def _hidden_states_to_logprobs( + self, + hidden_states: tuple[tuple[torch.Tensor, ...], ...], + num_logprobs: Optional[int], + ) -> tuple[list[dict[int, float]], int]: + seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) + output_len = len(hidden_states) + + # convert to dict + seq_logprobs_lst: list[dict[int, float]] = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + return ( + seq_logprobs_lst, + output_len, + ) + + def generate_greedy_logprobs_limit( + self, + prompts: list[str], + max_tokens: int, + num_logprobs: Optional[int], + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, + videos: Optional[PromptVideoInput] = None, + **kwargs: Any, + ) -> list[TokensTextLogprobs]: + all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + all_logprobs: list[list[dict[int, float]]] = [] + all_output_ids: list[list[int]] = [] + all_output_strs: list[str] = [] + + for inputs in all_inputs: + output = self.model.generate( + **self.wrap_device(inputs), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + **kwargs, + ) + + ( + seq_logprobs_lst, + output_len, + ) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = len(seq_logprobs_lst) + output_ids = seq_ids[-output_len:] + 
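            # output.sequences holds prompt + generated ids, so the slice above
            # keeps only the newly generated tokens (one id per decode step for
            # which logprobs were collected).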
all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + + def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]: + return self.model.encode(prompts, *args, **kwargs) + + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: + return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup_dist_env_and_memory() + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRunner + + +class VllmRunner: + """ + The default value of some arguments have been modified from + {class}`~vllm.LLM` as follows: + + - `trust_remote_code`: Set to `True` instead of `False` for convenience. + - `seed`: Set to `0` instead of `None` for test reproducibility. + - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage. + - `block_size`: To reduce memory usage, set default to `64` if on XPU + devices, otherwise default to `16`. + - `enable_chunked_prefill`: Set to `False` instead of `None` for + test reproducibility. + - `enforce_eager`: Set to `False` to test CUDA graph. + """ + + def __init__( + self, + model_name: str, + runner: RunnerOption = "auto", + convert: ConvertOption = "auto", + tokenizer_name: Optional[str] = None, + tokenizer_mode: str = "auto", + trust_remote_code: bool = True, + seed: Optional[int] = 0, + max_model_len: Optional[int] = 1024, + dtype: str = "auto", + disable_log_stats: bool = True, + tensor_parallel_size: int = 1, + block_size: int = 16 if not torch.xpu.is_available() else 64, + enable_chunked_prefill: Optional[bool] = False, + swap_space: int = 4, + enforce_eager: Optional[bool] = False, + # Set this to avoid hanging issue + default_torch_num_threads: Optional[int] = None, + **kwargs, + ) -> None: + init_ctx = ( + nullcontext() + if default_torch_num_threads is None + else set_default_torch_num_threads(default_torch_num_threads) + ) + + if not kwargs.get("compilation_config", None): + # Note(@tdoublep): This is set to 4 because some tests (e.g., hybrid + # model tests) may set max_num_seqs=4. If min cudagraph_capture_size is + # set to larger than max_num_seqs, then it will lead to *no* graphs + # being captured which can trigger edge cases that we don't handle yet. 
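            # The default below is only applied when the caller does not pass a
            # compilation_config of its own; tests that need different capture
            # sizes can supply e.g. {"cudagraph_capture_sizes": [1, 2, 4, 8]}.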
+ kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]} + + with init_ctx: + self.llm = LLM( + model=model_name, + runner=runner, + convert=convert, + tokenizer=tokenizer_name, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + dtype=dtype, + seed=seed, + swap_space=swap_space, + enforce_eager=enforce_eager, + disable_log_stats=disable_log_stats, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, + **kwargs, + ) + + def get_inputs( + self, + prompts: Union[list[str], list[torch.Tensor], list[list[int]]], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> list[dict[str, Any]]: + if any(x is not None and len(x) != len(prompts) for x in [images, videos, audios]): + raise ValueError("All non-None multimodal inputs must have the same length as prompts") + + inputs = list[dict[str, Any]]() + for i, prompt in enumerate(prompts): + prompt_dict = dict[str, Any]() + if isinstance(prompt, str): + prompt_dict["prompt"] = prompt + elif isinstance(prompt, list): + prompt_dict["prompt_token_ids"] = prompt + else: + prompt_dict["prompt_embeds"] = prompt + + multi_modal_data = dict[str, Any]() + if images is not None and (image := images[i]) is not None: + multi_modal_data["image"] = image + if videos is not None and (video := videos[i]) is not None: + multi_modal_data["video"] = video + if audios is not None and (audio := audios[i]) is not None: + multi_modal_data["audio"] = audio + + if multi_modal_data: + prompt_dict["multi_modal_data"] = multi_modal_data + + inputs.append(prompt_dict) + + return inputs + + def generate( + self, + prompts: Union[list[str], list[torch.Tensor], list[list[int]]], + sampling_params: SamplingParams, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + **kwargs: Any, + ) -> list[tuple[list[list[int]], list[str]]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + req_outputs = self.llm.generate(inputs, sampling_params=sampling_params, **kwargs) + + outputs: list[tuple[list[list[int]], list[str]]] = [] + for req_output in req_outputs: + prompt_str = req_output.prompt + prompt_ids = req_output.prompt_token_ids + req_sample_output_ids: list[list[int]] = [] + req_sample_output_strs: list[str] = [] + for sample in req_output.outputs: + output_str = sample.text + output_ids = list(sample.token_ids) + req_sample_output_ids.append(prompt_ids + output_ids) + req_sample_output_strs.append((prompt_str or "") + output_str) + outputs.append((req_sample_output_ids, req_sample_output_strs)) + return outputs + + @staticmethod + def _final_steps_generate_w_logprobs( + req_outputs: list[RequestOutput], + ) -> list[TokensTextLogprobsPromptLogprobs]: + outputs: list[TokensTextLogprobsPromptLogprobs] = [] + for req_output in req_outputs: + assert len(req_output.outputs) > 0 + for sample in req_output.outputs: + output_str = sample.text + output_ids = list(sample.token_ids) + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs, req_output.prompt_logprobs)) + return outputs + + def generate_w_logprobs( + self, + prompts: list[str], + sampling_params: SamplingParams, + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, + videos: Optional[PromptVideoInput] = None, + **kwargs: Any, + ) -> 
Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + req_outputs = self.llm.generate(inputs, sampling_params=sampling_params, **kwargs) + + toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(req_outputs) + # Omit prompt logprobs if not required by sampling params + return ( + [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None + else toks_str_logsprobs_prompt_logprobs + ) + + def generate_greedy( + self, + prompts: Union[list[str], list[torch.Tensor], list[list[int]]], + max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + **kwargs: Any, + ) -> list[tuple[list[int], str]]: + greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + outputs = self.generate( + prompts, + greedy_params, + images=images, + videos=videos, + audios=audios, + **kwargs, + ) + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] + + def generate_greedy_logprobs( + self, + prompts: list[str], + max_tokens: int, + num_logprobs: Optional[int], + num_prompt_logprobs: Optional[int] = None, + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, + videos: Optional[PromptVideoInput] = None, + stop_token_ids: Optional[list[int]] = None, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + greedy_logprobs_params = SamplingParams( + temperature=0.0, + max_tokens=max_tokens, + logprobs=num_logprobs, + prompt_logprobs=num_prompt_logprobs, + stop_token_ids=stop_token_ids, + stop=stop, + ) + + return self.generate_w_logprobs( + prompts, + greedy_logprobs_params, + images=images, + audios=audios, + videos=videos, + **kwargs, + ) + + def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: + """ + Return the perplexity score associated with generating the prompts + + :param prompts: list of prompts to score + :return: perplexity score of each prompt + """ + outputs = self.generate_greedy_logprobs(prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0) + + perplexities = [] + for output in outputs: + output = cast(TokensTextLogprobsPromptLogprobs, output) + token_data = cast(list[Optional[dict[int, Logprob]]], output[3]) + assert token_data[0] is None + token_log_probs = [] + for token_data in token_data[1:]: + assert token_data is not None + assert len(token_data) == 1 + token_log_prob = list(token_data.values())[0].logprob + token_log_probs.append(token_log_prob) + + perplexity = math.exp(-sum(token_log_probs) / len(token_log_probs)) + perplexities.append(perplexity) + + return perplexities + + def generate_beam_search( + self, + prompts: list[str], + beam_width: int, + max_tokens: int, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + concurrency_limit: Optional[int] = None, + ) -> list[tuple[list[list[int]], list[str]]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + outputs = self.llm.beam_search( + inputs, + BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens), + concurrency_limit=concurrency_limit, + ) + returned_outputs = [] + for output in outputs: + token_ids = [x.tokens for x in output.sequences] + texts = [x.text for x in 
output.sequences] + returned_outputs.append((token_ids, texts)) + return returned_outputs + + def classify(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.classify(prompts) + return [req_output.outputs.probs for req_output in req_outputs] + + def embed( + self, + prompts: list[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + *args, + **kwargs, + ) -> list[list[float]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + req_outputs = self.llm.embed(inputs, *args, **kwargs) + return [req_output.outputs.embedding for req_output in req_outputs] + + def encode(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.encode(prompts) + return [req_output.outputs.data for req_output in req_outputs] + + def reward(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.reward(prompts) + return [req_output.outputs.data for req_output in req_outputs] + + def score( + self, + text_1: Union[str, list[str]], + text_2: Union[str, list[str]], + *args, + **kwargs, + ) -> list[float]: + req_outputs = self.llm.score(text_1, text_2, *args, **kwargs) + return [req_output.outputs.score for req_output in req_outputs] + + def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: + return self.llm.apply_model(func) + + def get_llm(self) -> LLM: + return self.llm + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.llm + cleanup_dist_env_and_memory() + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. 
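    # Typical use: request `caplog_vllm` in a test, run code that logs through
    # the "vllm" logger, then assert on `caplog_vllm.text` / `caplog_vllm.records`
    # just as with pytest's built-in caplog fixture.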
+ yield caplog + + +@pytest.fixture(scope="session") +def num_gpus_available(): + """Get number of GPUs without initializing the CUDA context + in current process.""" + + from vllm.platforms import current_platform + + return current_platform.device_count() + + +temp_dir = tempfile.gettempdir() +_dummy_opt_path = os.path.join(temp_dir, "dummy_opt") +_dummy_llava_path = os.path.join(temp_dir, "dummy_llava") +_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding") + + +@pytest.fixture +def dummy_opt_path(): + json_path = os.path.join(_dummy_opt_path, "config.json") + if not os.path.exists(_dummy_opt_path): + snapshot_download( + repo_id="facebook/opt-125m", + local_dir=_dummy_opt_path, + ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"], + ) + assert os.path.exists(json_path) + with open(json_path) as f: + config = json.load(f) + config["architectures"] = ["MyOPTForCausalLM"] + with open(json_path, "w") as f: + json.dump(config, f) + return _dummy_opt_path + + +@pytest.fixture +def dummy_llava_path(): + json_path = os.path.join(_dummy_llava_path, "config.json") + if not os.path.exists(_dummy_llava_path): + snapshot_download( + repo_id="llava-hf/llava-1.5-7b-hf", + local_dir=_dummy_llava_path, + ignore_patterns=[ + "*.bin", + "*.bin.index.json", + "*.pt", + "*.h5", + "*.msgpack", + "*.safetensors", + ], + ) + assert os.path.exists(json_path) + with open(json_path) as f: + config = json.load(f) + config["architectures"] = ["MyLlava"] + with open(json_path, "w") as f: + json.dump(config, f) + return _dummy_llava_path + + +@pytest.fixture +def dummy_gemma2_embedding_path(): + json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json") + if not os.path.exists(_dummy_gemma2_embedding_path): + snapshot_download( + repo_id="BAAI/bge-multilingual-gemma2", + local_dir=_dummy_gemma2_embedding_path, + ignore_patterns=[ + "*.bin", + "*.bin.index.json", + "*.pt", + "*.h5", + "*.msgpack", + "*.safetensors", + ], + ) + assert os.path.exists(json_path) + with open(json_path) as f: + config = json.load(f) + config["architectures"] = ["MyGemma2Embedding"] + with open(json_path, "w") as f: + json.dump(config, f) + return _dummy_gemma2_embedding_path + + +# Add the flag `--optional` to allow run tests +# that are marked with @pytest.mark.optional +def pytest_addoption(parser): + parser.addoption("--optional", action="store_true", default=False, help="run optional test") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--optional"): + # --optional given in cli: do not skip optional tests + return + skip_optional = pytest.mark.skip(reason="need --optional option to run") + for item in items: + if "optional" in item.keywords: + item.add_marker(skip_optional) + + +@pytest.fixture(scope="session") +def cli_config_file(): + """Return the path to the CLI config file.""" + return os.path.join(_TEST_DIR, "config", "test_config.yaml") + + +@pytest.fixture(scope="session") +def cli_config_file_with_model(): + """Return the path to the CLI config file with model.""" + return os.path.join(_TEST_DIR, "config", "test_config_with_model.yaml") + + +class AssetHandler(http.server.BaseHTTPRequestHandler): + # _IMAGE_CACHE : Dict[str, bytes] = {} + + def log_message(self, *args, **kwargs): + pass + + def do_GET(self): + # Accepts paths like: /1280px-Venn_diagram_rgb.jpg + filename = self.path.lstrip("/") + if not filename or "." 
not in filename: + self.send_error(404, "Missing filename (expected /.)") + return + + base, ext = filename.rsplit(".", 1) + ext = ext.lower() + + if ext not in ["jpg", "png"]: + self.send_error(404, f"Unsupported extension: .{ext}") + return + + try: + data = ImageAsset(base).read_bytes(ext=ext) + except Exception as e: + self.send_error(500, f"Failed to load asset: {ext} {base} {e} ") + return + + ctype, _ = mimetypes.guess_type(filename) + if ctype is None: + ctype = {"jpg": "image/jpg", "png": "image/png"}[ext] + self.send_response(200) + self.send_header("Content-Type", ctype) + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + + +def _find_free_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +class LocalAssetServer: + address: str + port: int + server: Optional[http.server.ThreadingHTTPServer] + thread: Optional[threading.Thread] + + def __init__(self, address: str = "127.0.0.1") -> None: + self.address = address + self.port = -1 + self.server = None + self.thread = None + + def __enter__(self): + self.port = _find_free_port() + self.server = http.server.ThreadingHTTPServer((self.address, self.port), AssetHandler) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.server: + self.server.shutdown() + del self.server + + if self.thread: + self.thread.join() + del self.thread + + if exc_type is None: + return None + + return False + + @property + def base_url(self) -> str: + assert self.port is not None + return f"http://{self.address}:{self.port}" + + def url_for(self, name: str) -> str: + """e.g., name='RGBA_comp.png' -> 'http://127.0.0.1:PORT/RGBA_comp.png'""" + return f"{self.base_url}/{name}" + + def get_image_asset(self, name: str) -> Image.Image: + return fetch_image(self.url_for(name)) + + +@pytest.fixture(scope="session") +def local_asset_server() -> Generator[LocalAssetServer, None, None]: + """ + Starts a thread based HTTP server bound to 127.0.0.1 on a random free port. + The server currently servers images at: + http://127.0.0.1:/. + """ + with LocalAssetServer() as srv: + yield srv + + +@pytest.fixture +def image_url(request, local_asset_server) -> str: + # request.param is one of the IMAGE_ASSETS filenames + name = request.param + return local_asset_server.url_for(name) + + +@pytest.fixture +def image_urls(request, local_asset_server) -> list[str]: + """Indirect fixture: takes a list of names, returns list of full URLs.""" + names: list[str] = request.param + return [local_asset_server.url_for(name) for name in names] diff --git a/auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py b/auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py new file mode 100644 index 000000000..9581a5dbe --- /dev/null +++ b/auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test model set-up and inference for quantized HF models supported +on the AutoRound. + +Validating the configuration and printing results for manual checking. + +Run `pytest tests/quantization/test_auto_round.py`. 
+""" + +import pytest +from vllm.platforms import current_platform + +MODELS = ["/data5/yliu7/HF_HOME/unsloth-gpt-oss-20b-BF16-ar-MXFP4/"] + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="only supports CUDA backend.", +) +@pytest.mark.parametrize("model", MODELS) +def test_auto_round(vllm_runner, model): + with vllm_runner(model, enforce_eager=True) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=8) + assert output + print(f"{output[0][1]}") From 418e6a042dfde7fb242ebc1c12cf374185eae7da Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 22 Oct 2025 22:45:27 -0400 Subject: [PATCH 03/52] fix import Signed-off-by: yiliu30 --- auto_round/experimental/vllm_ext/sitecustomize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/experimental/vllm_ext/sitecustomize.py b/auto_round/experimental/vllm_ext/sitecustomize.py index fcb2bc5a8..321673de8 100644 --- a/auto_round/experimental/vllm_ext/sitecustomize.py +++ b/auto_round/experimental/vllm_ext/sitecustomize.py @@ -14,9 +14,9 @@ print("*****************************************************************************") print(f"* !!! VLLM_ENABLE_AR_EXT is set to {VLLM_ENABLE_AR_EXT}, applying auto_round_vllm_extension *") print("*****************************************************************************") - from vllm.model_executor.layers.quantization import auto_round_vllm_extension as auto_round_ext + from .auto_round_vllm_extension import apply as auto_round_ext_apply - auto_round_ext.apply() + auto_round_ext_apply() else: print("*****************************************************************************") print( From 184783f60ecd3f5cf8f52fd6edaff2f0ec50a245 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 22 Oct 2025 22:47:03 -0400 Subject: [PATCH 04/52] clean envs Signed-off-by: yiliu30 --- auto_round/experimental/vllm_ext/envs_ext.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/auto_round/experimental/vllm_ext/envs_ext.py b/auto_round/experimental/vllm_ext/envs_ext.py index 964272ea1..3c053f6c4 100644 --- a/auto_round/experimental/vllm_ext/envs_ext.py +++ b/auto_round/experimental/vllm_ext/envs_ext.py @@ -10,10 +10,6 @@ # Define extra environment variables extra_environment_variables: dict[str, Callable[[], Any]] = { - "VLLM_AR_MXFP8_DISABLE_INPUT_QDQ": lambda: os.getenv("VLLM_AR_MXFP8_DISABLE_INPUT_QDQ", "0") - in ("1", "true", "True"), - "VLLM_AR_MXFP4_DISABLE_INPUT_QDQ": lambda: os.getenv("VLLM_AR_MXFP4_DISABLE_INPUT_QDQ", "0") - in ("1", "true", "True"), "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "0") in ("1", "true", "True"), "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "0") in ("1", "true", "True"), "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "1") in ("1", "true", "True"), From b9da06fbd734738ef91c6ab16516d561e1deaee2 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 22 Oct 2025 22:55:55 -0400 Subject: [PATCH 05/52] add script for apply ext Signed-off-by: yiliu30 --- auto_round/experimental/vllm_ext/apply_ext.sh | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 auto_round/experimental/vllm_ext/apply_ext.sh diff --git a/auto_round/experimental/vllm_ext/apply_ext.sh b/auto_round/experimental/vllm_ext/apply_ext.sh new file mode 100644 index 000000000..646c82660 --- /dev/null +++ b/auto_round/experimental/vllm_ext/apply_ext.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache 
License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Define the approximate path for the `auto-round` installation +AUTO_ROUND_PATH="auto_round/experimental/vllm_ext/sitecustomize.py" + +# Use pip to find the installation path of the `auto-round` package +PIP_LOCATION=$(pip show auto-round 2>/dev/null | grep "Location:" | awk '{print $2}') + +if [ -n "$PIP_LOCATION" ]; then + # Construct the full path to `sitecustomize.py` + SITE_CUSTOMIZE_PATH="$PIP_LOCATION/$AUTO_ROUND_PATH" + + if [ -f "$SITE_CUSTOMIZE_PATH" ]; then + echo "Found sitecustomize.py at: $SITE_CUSTOMIZE_PATH" + export PYTHONPATH=$(dirname "$SITE_CUSTOMIZE_PATH"):$PYTHONPATH + echo "PYTHONPATH set to: $PYTHONPATH" + else + echo "Error: sitecustomize.py not found in the auto-round installation path." + exit 1 + fi +else + echo "Error: auto-round package not found via pip." + exit 1 +fi \ No newline at end of file From 187f38df072ff9c58a56a7445dbe21320d6ddefb Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 22 Oct 2025 22:57:43 -0400 Subject: [PATCH 06/52] clean docs Signed-off-by: yiliu30 --- .../experimental/vllm_ext/moe_impl_mxfp4.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py b/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py index 426fc535e..bc231ae26 100644 --- a/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py @@ -255,30 +255,6 @@ def revert_interleaved_bias(bias): ) def revert_interleaved_w1(w1): - """ - w1[0,:8,:4] - tensor([[ 0.0000, 0.0000, 0.0000, -0.0625], - [ 0.0234, 0.0156, 0.0234, 0.0312], - [-0.0312, 0.0156, 0.0312, -0.0156], - [ 0.0078, 0.0078, 0.0078, 0.0156], - [ 0.0078, 0.0000, 0.0000, 0.0234], - [ 0.0000, -0.0078, -0.0156, -0.0469], - [-0.0469, 0.0078, -0.0625, 0.0234], - [ 0.0156, 0.0000, -0.0312, 0.0156]], device='cuda:0', - dtype=torch.bfloat16) - - - tensor([[ 0.0000, 0.0000, 0.0000, -0.0625], - [ 0.0156, -0.0000, -0.0000, -0.0938], - [ 0.0234, 0.0156, 0.0234, 0.0312], - [ 0.0312, 0.0000, -0.0156, 0.0000], - [-0.0312, 0.0156, 0.0312, -0.0156], - [-0.0234, 0.0469, -0.0312, -0.0312], - [ 0.0078, 0.0078, 0.0078, 0.0156], - [ 0.0469, -0.0312, -0.0469, 0.0625]], device='cuda:0', - dtype=torch.bfloat16) - - """ new_w1 = torch.zeros_like(w1) E, N, H = w1.shape new_w1[:, ::2, :] = w1[:, : N // 2, :] From 4031724d9f45669ee48cbb3540e3a6664c4e5439 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 23 Oct 2025 02:35:57 -0400 Subject: [PATCH 07/52] fix license Signed-off-by: yiliu30 --- auto_round/experimental/vllm_ext/__init__.py | 16 +++++++++++--- .../experimental/vllm_ext/auto_round_ext.py | 16 ++++++++++++-- auto_round/experimental/vllm_ext/envs_ext.py | 15 +++++++++++-- .../experimental/vllm_ext/moe_impl_mxfp4.py | 20 +++++++++++------ .../experimental/vllm_ext/quant_method_moe.py | 16 ++++++++++++-- .../experimental/vllm_ext/sitecustomize.py | 16 +++++++++++--- .../vllm_ext/tests/test_mxfp4_moe.py | 22 +++++++++++-------- 7 files changed, 93 insertions(+), 28 deletions(-) diff --git 
a/auto_round/experimental/vllm_ext/__init__.py b/auto_round/experimental/vllm_ext/__init__.py index aaadf910c..181ab399b 100644 --- a/auto_round/experimental/vllm_ext/__init__.py +++ b/auto_round/experimental/vllm_ext/__init__.py @@ -1,6 +1,16 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ==---------------------------------------------------------------------------== # Apply the extension diff --git a/auto_round/experimental/vllm_ext/auto_round_ext.py b/auto_round/experimental/vllm_ext/auto_round_ext.py index cc75eb3b4..f1f6402ac 100644 --- a/auto_round/experimental/vllm_ext/auto_round_ext.py +++ b/auto_round/experimental/vllm_ext/auto_round_ext.py @@ -1,5 +1,17 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Any import torch diff --git a/auto_round/experimental/vllm_ext/envs_ext.py b/auto_round/experimental/vllm_ext/envs_ext.py index 3c053f6c4..0132ad62e 100644 --- a/auto_round/experimental/vllm_ext/envs_ext.py +++ b/auto_round/experimental/vllm_ext/envs_ext.py @@ -1,5 +1,16 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os from typing import Any, Callable diff --git a/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py b/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py index bc231ae26..f4e1dbbb0 100644 --- a/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py @@ -1,10 +1,16 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# ==-------------------------------------------------------------------------== -# MOE MXFP4 -# ==------------------------------------------------------------------------== - +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, Optional, Union diff --git a/auto_round/experimental/vllm_ext/quant_method_moe.py b/auto_round/experimental/vllm_ext/quant_method_moe.py index 60206e947..2f152203a 100644 --- a/auto_round/experimental/vllm_ext/quant_method_moe.py +++ b/auto_round/experimental/vllm_ext/quant_method_moe.py @@ -1,5 +1,17 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional import torch diff --git a/auto_round/experimental/vllm_ext/sitecustomize.py b/auto_round/experimental/vllm_ext/sitecustomize.py index 321673de8..3d34e2f4e 100644 --- a/auto_round/experimental/vllm_ext/sitecustomize.py +++ b/auto_round/experimental/vllm_ext/sitecustomize.py @@ -1,6 +1,16 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# PYTHONPATH=/home/yliu7/workspace/inc/3rd-party/vllm/vllm/model_executor/layers/quantization/auto_round_vllm_extension/:$PYTHONPATH +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os diff --git a/auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py b/auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py index 9581a5dbe..587c12628 100644 --- a/auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py +++ b/auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py @@ -1,12 +1,16 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Test model set-up and inference for quantized HF models supported -on the AutoRound. - -Validating the configuration and printing results for manual checking. - -Run `pytest tests/quantization/test_auto_round.py`. -""" +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from vllm.platforms import current_platform From 5fe01ef2396b12eea0bc4bf0717967d8be2dce4b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 23 Oct 2025 02:43:50 -0400 Subject: [PATCH 08/52] fix Signed-off-by: yiliu30 --- auto_round/experimental/vllm_ext/fp4_utils.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/auto_round/experimental/vllm_ext/fp4_utils.py b/auto_round/experimental/vllm_ext/fp4_utils.py index 26e97e633..cf8148992 100644 --- a/auto_round/experimental/vllm_ext/fp4_utils.py +++ b/auto_round/experimental/vllm_ext/fp4_utils.py @@ -16,18 +16,7 @@ import torch -FLOAT_TO_E2M1 = [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, -] - -# Module-level device tensor cache +# Module-level device tensor cache to fix cuda graph issue _DEVICE_E2M1_TENSORS = {} # Constants for FP4 values (E2M1 format) From 73f1e9bc4df7d313f024e5aca5249f2aca2d3072 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 23 Oct 2025 06:35:49 -0400 Subject: [PATCH 09/52] fix import and sitecustomize Signed-off-by: yiliu30 --- auto_round/experimental/vllm_ext/apply_ext.sh | 28 ++++++++++++------- .../experimental/vllm_ext/auto_round_ext.py | 3 +- .../experimental/vllm_ext/moe_impl_mxfp4.py | 4 +-- .../experimental/vllm_ext/mxfp4_qdq_utils.py | 5 ++-- .../experimental/vllm_ext/quant_method_moe.py | 2 +- .../experimental/vllm_ext/sitecustomize.py | 10 +++++-- 6 files changed, 32 insertions(+), 20 deletions(-) diff --git a/auto_round/experimental/vllm_ext/apply_ext.sh b/auto_round/experimental/vllm_ext/apply_ext.sh index 646c82660..f3fa5dcb0 100644 --- a/auto_round/experimental/vllm_ext/apply_ext.sh +++ b/auto_round/experimental/vllm_ext/apply_ext.sh @@ -14,25 +14,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
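# Note on how this takes effect: Python's `site` module imports any
# `sitecustomize` module it can find on PYTHONPATH at interpreter start-up,
# so exporting the directory that contains sitecustomize.py is enough for the
# extension to be applied to later `python` / `vllm` runs in this shell
# (provided VLLM_ENABLE_AR_EXT is set when they start).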
-# Define the approximate path for the `auto-round` installation +# Define the relative path for the `auto-round` installation AUTO_ROUND_PATH="auto_round/experimental/vllm_ext/sitecustomize.py" -# Use pip to find the installation path of the `auto-round` package +# Try to find the pip installation location PIP_LOCATION=$(pip show auto-round 2>/dev/null | grep "Location:" | awk '{print $2}') if [ -n "$PIP_LOCATION" ]; then - # Construct the full path to `sitecustomize.py` SITE_CUSTOMIZE_PATH="$PIP_LOCATION/$AUTO_ROUND_PATH" + echo "Checking for sitecustomize.py at: $SITE_CUSTOMIZE_PATH" if [ -f "$SITE_CUSTOMIZE_PATH" ]; then echo "Found sitecustomize.py at: $SITE_CUSTOMIZE_PATH" export PYTHONPATH=$(dirname "$SITE_CUSTOMIZE_PATH"):$PYTHONPATH echo "PYTHONPATH set to: $PYTHONPATH" - else - echo "Error: sitecustomize.py not found in the auto-round installation path." - exit 1 + return 0 2>/dev/null || true fi -else - echo "Error: auto-round package not found via pip." - exit 1 -fi \ No newline at end of file +fi + +# Fallback: check current directory +LOCAL_SITE_CUSTOMIZE="./sitecustomize.py" +if [ -f "$LOCAL_SITE_CUSTOMIZE" ]; then + echo "Found sitecustomize.py at current directory." + export PYTHONPATH=$(pwd):$PYTHONPATH + echo "PYTHONPATH set to: $PYTHONPATH" + return 0 2>/dev/null || true +fi + +echo "Warning: sitecustomize.py not found in pip installation or current directory." +# Do not exit the shell +return 1 2>/dev/null || true \ No newline at end of file diff --git a/auto_round/experimental/vllm_ext/auto_round_ext.py b/auto_round/experimental/vllm_ext/auto_round_ext.py index f1f6402ac..29fed4aa0 100644 --- a/auto_round/experimental/vllm_ext/auto_round_ext.py +++ b/auto_round/experimental/vllm_ext/auto_round_ext.py @@ -15,14 +15,13 @@ from typing import Any import torch +from quant_method_moe import AutoRoundMoEMethod from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme -from .quant_method_moe import AutoRoundMoEMethod - logger = init_logger(__name__) diff --git a/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py b/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py index f4e1dbbb0..fb994b2fb 100644 --- a/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py @@ -28,7 +28,7 @@ from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) -from .quant_method_moe import AutoRoundMoEMethod +from quant_method_moe import AutoRoundMoEMethod def apply_act(local_w1_out: torch.Tensor, local_w3_out: torch.Tensor, activation: str) -> torch.Tensor: @@ -248,7 +248,7 @@ def revert_interleaved_bias(bias): layer.w13_bias.data.copy_(w13_bias_swapped) if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: - import vllm.model_executor.layers.quantization.auto_round_vllm_extension.mxfp4_qdq_utils as mxfp4_utils + import mxfp4_qdq_utils as mxfp4_utils w1 = layer.w13_weight_packed w1_scale = layer.w13_weight_scale diff --git a/auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py b/auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py index d90272c55..ad43ac5e2 100644 --- a/auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py +++ b/auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py @@ -15,9 +15,8 @@ from typing import Union import torch - -from .fp4_utils import cast_to_fp4, pack_fp4_to_uint8, unpack_fp4_from_uint8 -from .utils import _to_mx_rceil, get_fp_scale +from fp4_utils 
import cast_to_fp4, pack_fp4_to_uint8, unpack_fp4_from_uint8 +from utils import _to_mx_rceil, get_fp_scale F4_E2M1_MAX = 6.0 F32_EXP_BIAS = 127 diff --git a/auto_round/experimental/vllm_ext/quant_method_moe.py b/auto_round/experimental/vllm_ext/quant_method_moe.py index 2f152203a..b3394f004 100644 --- a/auto_round/experimental/vllm_ext/quant_method_moe.py +++ b/auto_round/experimental/vllm_ext/quant_method_moe.py @@ -69,7 +69,7 @@ def get_impl(scheme: QuantizationScheme): return UnquantizedFusedMoEMethod(layer.moe_config) elif _is_mxfp4_w4a4(scheme): - from .moe_impl_mxfp4 import AutoRoundMoEMethodMXFp4Impl + from moe_impl_mxfp4 import AutoRoundMoEMethodMXFp4Impl return AutoRoundMoEMethodMXFp4Impl(quant_config, layer.moe_config) diff --git a/auto_round/experimental/vllm_ext/sitecustomize.py b/auto_round/experimental/vllm_ext/sitecustomize.py index 3d34e2f4e..63263b634 100644 --- a/auto_round/experimental/vllm_ext/sitecustomize.py +++ b/auto_round/experimental/vllm_ext/sitecustomize.py @@ -24,9 +24,15 @@ print("*****************************************************************************") print(f"* !!! VLLM_ENABLE_AR_EXT is set to {VLLM_ENABLE_AR_EXT}, applying auto_round_vllm_extension *") print("*****************************************************************************") - from .auto_round_vllm_extension import apply as auto_round_ext_apply - auto_round_ext_apply() + import vllm.model_executor.layers.quantization.auto_round as auto_round_module + + from auto_round.experimental.vllm_ext.auto_round_ext import AutoRoundExtensionConfig + + auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig + from auto_round.experimental.vllm_ext.envs_ext import extra_environment_variables + + else: print("*****************************************************************************") print( From 849585412a460fe31fea4e6047f0f338b72b7fcf Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 24 Oct 2025 01:03:01 -0400 Subject: [PATCH 10/52] move to ext Signed-off-by: yiliu30 --- .../experimental => auto_round_extension}/vllm_ext/__init__.py | 0 .../experimental => auto_round_extension}/vllm_ext/apply_ext.sh | 2 +- .../vllm_ext/auto_round_ext.py | 0 .../experimental => auto_round_extension}/vllm_ext/envs_ext.py | 0 .../experimental => auto_round_extension}/vllm_ext/fp4_utils.py | 0 .../vllm_ext/moe_impl_mxfp4.py | 0 .../vllm_ext/mxfp4_qdq_utils.py | 0 .../vllm_ext/quant_method_moe.py | 0 .../vllm_ext/sitecustomize.py | 0 .../vllm_ext/tests/conftest.py | 0 .../vllm_ext/tests/test_mxfp4_moe.py | 0 .../experimental => auto_round_extension}/vllm_ext/utils.py | 0 12 files changed, 1 insertion(+), 1 deletion(-) rename {auto_round/experimental => auto_round_extension}/vllm_ext/__init__.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/apply_ext.sh (95%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/auto_round_ext.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/envs_ext.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/fp4_utils.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/moe_impl_mxfp4.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/mxfp4_qdq_utils.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/quant_method_moe.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/sitecustomize.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/tests/conftest.py (100%) rename 
{auto_round/experimental => auto_round_extension}/vllm_ext/tests/test_mxfp4_moe.py (100%) rename {auto_round/experimental => auto_round_extension}/vllm_ext/utils.py (100%) diff --git a/auto_round/experimental/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py similarity index 100% rename from auto_round/experimental/vllm_ext/__init__.py rename to auto_round_extension/vllm_ext/__init__.py diff --git a/auto_round/experimental/vllm_ext/apply_ext.sh b/auto_round_extension/vllm_ext/apply_ext.sh similarity index 95% rename from auto_round/experimental/vllm_ext/apply_ext.sh rename to auto_round_extension/vllm_ext/apply_ext.sh index f3fa5dcb0..b327ec22c 100644 --- a/auto_round/experimental/vllm_ext/apply_ext.sh +++ b/auto_round_extension/vllm_ext/apply_ext.sh @@ -15,7 +15,7 @@ # limitations under the License. # Define the relative path for the `auto-round` installation -AUTO_ROUND_PATH="auto_round/experimental/vllm_ext/sitecustomize.py" +AUTO_ROUND_PATH="auto_round/../auto_round_extension/vllm_ext/sitecustomize.py" # Try to find the pip installation location PIP_LOCATION=$(pip show auto-round 2>/dev/null | grep "Location:" | awk '{print $2}') diff --git a/auto_round/experimental/vllm_ext/auto_round_ext.py b/auto_round_extension/vllm_ext/auto_round_ext.py similarity index 100% rename from auto_round/experimental/vllm_ext/auto_round_ext.py rename to auto_round_extension/vllm_ext/auto_round_ext.py diff --git a/auto_round/experimental/vllm_ext/envs_ext.py b/auto_round_extension/vllm_ext/envs_ext.py similarity index 100% rename from auto_round/experimental/vllm_ext/envs_ext.py rename to auto_round_extension/vllm_ext/envs_ext.py diff --git a/auto_round/experimental/vllm_ext/fp4_utils.py b/auto_round_extension/vllm_ext/fp4_utils.py similarity index 100% rename from auto_round/experimental/vllm_ext/fp4_utils.py rename to auto_round_extension/vllm_ext/fp4_utils.py diff --git a/auto_round/experimental/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py similarity index 100% rename from auto_round/experimental/vllm_ext/moe_impl_mxfp4.py rename to auto_round_extension/vllm_ext/moe_impl_mxfp4.py diff --git a/auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py similarity index 100% rename from auto_round/experimental/vllm_ext/mxfp4_qdq_utils.py rename to auto_round_extension/vllm_ext/mxfp4_qdq_utils.py diff --git a/auto_round/experimental/vllm_ext/quant_method_moe.py b/auto_round_extension/vllm_ext/quant_method_moe.py similarity index 100% rename from auto_round/experimental/vllm_ext/quant_method_moe.py rename to auto_round_extension/vllm_ext/quant_method_moe.py diff --git a/auto_round/experimental/vllm_ext/sitecustomize.py b/auto_round_extension/vllm_ext/sitecustomize.py similarity index 100% rename from auto_round/experimental/vllm_ext/sitecustomize.py rename to auto_round_extension/vllm_ext/sitecustomize.py diff --git a/auto_round/experimental/vllm_ext/tests/conftest.py b/auto_round_extension/vllm_ext/tests/conftest.py similarity index 100% rename from auto_round/experimental/vllm_ext/tests/conftest.py rename to auto_round_extension/vllm_ext/tests/conftest.py diff --git a/auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py b/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py similarity index 100% rename from auto_round/experimental/vllm_ext/tests/test_mxfp4_moe.py rename to auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py diff --git a/auto_round/experimental/vllm_ext/utils.py 
b/auto_round_extension/vllm_ext/utils.py similarity index 100% rename from auto_round/experimental/vllm_ext/utils.py rename to auto_round_extension/vllm_ext/utils.py From c47393448a09e146dfda1439a95ed2d99ce56470 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 24 Oct 2025 04:04:14 -0400 Subject: [PATCH 11/52] update mxfp4 Signed-off-by: yiliu30 --- .../vllm_ext/auto_round_ext.py | 5 +- auto_round_extension/vllm_ext/envs_ext.py | 6 +- .../vllm_ext/moe_impl_mxfp4.py | 6 +- .../vllm_ext/sitecustomize.py | 4 +- .../vllm_ext/tests/conftest.py | 323 +++++++++++------- .../vllm_ext/tests/test_mxfp4_moe.py | 8 +- 6 files changed, 225 insertions(+), 127 deletions(-) diff --git a/auto_round_extension/vllm_ext/auto_round_ext.py b/auto_round_extension/vllm_ext/auto_round_ext.py index 29fed4aa0..69d1d3f92 100644 --- a/auto_round_extension/vllm_ext/auto_round_ext.py +++ b/auto_round_extension/vllm_ext/auto_round_ext.py @@ -18,6 +18,7 @@ from quant_method_moe import AutoRoundMoEMethod from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme @@ -34,8 +35,8 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str): if isinstance(layer, FusedMoE): quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix) return quant_method - - return super().get_quant_method(layer, prefix) + elif isinstance(layer, LinearBase): + return UnquantizedLinearMethod() @staticmethod def _parse_quant_scheme(config: dict): diff --git a/auto_round_extension/vllm_ext/envs_ext.py b/auto_round_extension/vllm_ext/envs_ext.py index 0132ad62e..325b9de7c 100644 --- a/auto_round_extension/vllm_ext/envs_ext.py +++ b/auto_round_extension/vllm_ext/envs_ext.py @@ -21,9 +21,9 @@ # Define extra environment variables extra_environment_variables: dict[str, Callable[[], Any]] = { - "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "0") in ("1", "true", "True"), - "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "0") in ("1", "true", "True"), - "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "1") in ("1", "true", "True"), + "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": lambda: os.getenv("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1") in ("1", "true", "True"), + "VLLM_ENABLE_STATIC_MOE": lambda: os.getenv("VLLM_ENABLE_STATIC_MOE", "1") in ("1", "true", "True"), + "VLLM_AR_MXFP4_MODULAR_MOE": lambda: os.getenv("VLLM_AR_MXFP4_MODULAR_MOE", "0") in ("1", "true", "True"), } # Add the extra environment variables to vllm.envs import vllm.envs as envs diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index fb994b2fb..1bfccb623 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -173,7 +173,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if envs.VLLM_ENABLE_STATIC_MOE: if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: weight_name_lst = ["w13_weight", "w2_weight"] - from .mxfp4_qdq_utils import dequant_mxfp4_to_fp8 + from mxfp4_qdq_utils import dequant_mxfp4_to_fp8 for weight_name_prefix in weight_name_lst: weight_name = f"{weight_name_prefix}_packed" @@ -436,7 +436,7 @@ def apply( local_w3_packed = local_w13_packed[intermediate_size_per_partition:, ...] 
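                # w13 stacks the two expert input projections along the output
                # dim: rows [:intermediate_size_per_partition] hold w1 and the
                # remaining rows hold w3; packed weights and scales are split
                # the same way before each half goes through the MXFP4 path.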
local_w3_scale = local_w13_scale[intermediate_size_per_partition:, ...] - from .mxfp4_qdq_utils import run_mxfp4_emulations + from mxfp4_qdq_utils import run_mxfp4_emulations local_w1_bias = None local_w2_bias = None @@ -504,7 +504,7 @@ def apply( local_w3_bias = local_w13_bias[intermediate_size_per_partition:] local_w2_bias = layer.w2_bias[expert_index] - from .mxfp4_qdq_utils import mxfp4_gemm_with_unpacked_weight + from mxfp4_qdq_utils import mxfp4_gemm_with_unpacked_weight local_w1_out = mxfp4_gemm_with_unpacked_weight( x=current_state_static, diff --git a/auto_round_extension/vllm_ext/sitecustomize.py b/auto_round_extension/vllm_ext/sitecustomize.py index 63263b634..8c22d1ab2 100644 --- a/auto_round_extension/vllm_ext/sitecustomize.py +++ b/auto_round_extension/vllm_ext/sitecustomize.py @@ -27,10 +27,10 @@ import vllm.model_executor.layers.quantization.auto_round as auto_round_module - from auto_round.experimental.vllm_ext.auto_round_ext import AutoRoundExtensionConfig + from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig - from auto_round.experimental.vllm_ext.envs_ext import extra_environment_variables + from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables else: diff --git a/auto_round_extension/vllm_ext/tests/conftest.py b/auto_round_extension/vllm_ext/tests/conftest.py index a73fe7e74..fcdd97a49 100644 --- a/auto_round_extension/vllm_ext/tests/conftest.py +++ b/auto_round_extension/vllm_ext/tests/conftest.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Note: Copied from vllm/tests/conftest.py - -# ruff: noqa +import contextlib +import pathlib +from copy import deepcopy from tblib import pickling_support +# ruff: noqa + # Install support for pickling exceptions so that we can nicely propagate # failures from tests running in a subprocess. # This should be run before any custom exception subclasses are defined. 
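For context on the comment above: tblib's pickling_support registers reducers so that exceptions, together with their tracebacks, survive a pickle round-trip between processes. A minimal sketch of the behaviour it enables; the exception class and message are illustrative only, not part of this conftest:

import pickle

from tblib import pickling_support

pickling_support.install()  # install before custom exception subclasses exist


class WorkerFailure(RuntimeError):
    """Illustrative subclass; not part of this conftest."""


try:
    raise WorkerFailure("boom")
except WorkerFailure as exc:
    payload = pickle.dumps(exc)      # traceback is carried along
    restored = pickle.loads(payload)
    assert isinstance(restored, WorkerFailure)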
@@ -19,12 +21,10 @@ import socket import tempfile import threading -import warnings -from collections.abc import Generator, Sequence +from collections.abc import Generator from contextlib import nullcontext -from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Optional, TypedDict, TypeVar, Union, cast +from typing import Any, Callable, TypedDict, TypeVar, cast import numpy as np import pytest @@ -39,13 +39,29 @@ AutoTokenizer, BatchEncoding, BatchFeature, - PretrainedConfig, ) from transformers.models.auto.auto_factory import _BaseAutoModelClass -from vllm.config.model import ModelConfig, ModelDType, RunnerOption + +# from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs +from vllm import LLM, SamplingParams, envs +from vllm.assets.audio import AudioAsset +from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset +from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype +from vllm.connections import global_http_connection +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs -from vllm.multimodal.processing import InputProcessingContext -from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.multimodal.utils import fetch_image +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams +from vllm.transformers_utils.utils import maybe_model_redirect +from vllm.utils.collection_utils import is_list_of +from vllm.utils.torch_utils import set_default_torch_num_threads # Representation of generated sequence as a tuple of # * Token ID list @@ -53,7 +69,7 @@ # * List of top sample logprobs for each sampled token # # Assumes prompt logprobs were not requested. -TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]]] +TokensTextLogprobs = tuple[list[int], str, list[dict[int, float]] | SampleLogprobs | None] # Allow for tokens to be represented as str's rather than IDs; # tuple of @@ -62,7 +78,7 @@ # * Optional list of top sample logprobs for each sampled token # # Assumes prompt logprobs were not requested. 
-TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]], list[dict[str, Logprob]]]]] +TextTextLogprobs = tuple[list[str], str, list[dict[str, float]] | list[dict[str, Logprob]] | None] # Representation of generated sequence as a tuple of # * Token ID list @@ -74,33 +90,11 @@ TokensTextLogprobsPromptLogprobs = tuple[ list[int], str, - Optional[Union[list[dict[int, float]], SampleLogprobs]], - Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]], + list[dict[int, float]] | SampleLogprobs | None, + list[dict[int, float] | None] | PromptLogprobs | None, ] -# from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs - - -from vllm import LLM, SamplingParams -from vllm.assets.audio import AudioAsset -from vllm.assets.image import ImageAsset -from vllm.assets.video import VideoAsset -from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype -from vllm.connections import global_http_connection -from vllm.distributed import ( - cleanup_dist_env_and_memory, - init_distributed_environment, - initialize_model_parallel, -) -from vllm.logger import init_logger -from vllm.logprobs import Logprob -from vllm.multimodal.utils import fetch_image -from vllm.outputs import RequestOutput -from vllm.sampling_params import BeamSearchParams -from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import is_list_of, set_default_torch_num_threads - logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) @@ -110,7 +104,7 @@ _M = TypeVar("_M") -_PromptMultiModalInput = Union[list[_M], list[list[_M]]] +_PromptMultiModalInput = list[_M] | list[list[_M]] PromptImageInput = _PromptMultiModalInput[Image.Image] PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]] @@ -309,7 +303,7 @@ def get_default_device(self): return "cpu" if current_platform.is_cpu() else current_platform.device_type - def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: + def wrap_device(self, x: _T, device: str | None = None) -> _T: if x is None or isinstance(x, (bool,)): return x @@ -329,14 +323,14 @@ def __init__( model_name: str, dtype: str = "auto", *, - model_kwargs: Optional[dict[str, Any]] = None, + model_kwargs: dict[str, Any] | None = None, trust_remote_code: bool = True, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM, # Set this to avoid hanging issue - default_torch_num_threads: Optional[int] = None, + default_torch_num_threads: int | None = None, ) -> None: init_ctx = ( nullcontext() @@ -361,7 +355,7 @@ def _init( model_name: str, dtype: str = "auto", *, - model_kwargs: Optional[dict[str, Any]] = None, + model_kwargs: dict[str, Any] | None = None, trust_remote_code: bool = True, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, @@ -376,7 +370,7 @@ def _init( trust_remote_code=trust_remote_code, ) self.device = self.get_default_device() - self.dtype = torch_dtype = _get_and_verify_dtype( + self.dtype = dtype = _get_and_verify_dtype( self.model_name, self.config, dtype=dtype, @@ -384,7 +378,7 @@ def _init( ) model_kwargs = model_kwargs if model_kwargs is not None else {} - model_kwargs.setdefault("torch_dtype", torch_dtype) + model_kwargs.setdefault("dtype", dtype) if is_sentence_transformer: # Lazy init required for AMD CI @@ -430,7 +424,7 @@ def _init( if not skip_tokenizer_init: self.tokenizer = AutoTokenizer.from_pretrained( model_name, - 
torch_dtype=torch_dtype, + dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -440,7 +434,7 @@ def _init( self.processor = AutoProcessor.from_pretrained( model_name, - torch_dtype=torch_dtype, + dtype=dtype, trust_remote_code=trust_remote_code, ) if skip_tokenizer_init: @@ -448,11 +442,11 @@ def _init( def get_inputs( self, - prompts: Union[list[str], list[list[int]]], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - ) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]: + prompts: list[str] | list[list[int]], + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, + ) -> list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]]: if images is not None: assert len(prompts) == len(images) @@ -462,7 +456,7 @@ def get_inputs( if audios is not None: assert len(prompts) == len(audios) - all_inputs: list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]] = [] + all_inputs: list[BatchFeature | BatchEncoding | dict[str, torch.Tensor]] = [] for i, prompt in enumerate(prompts): if isinstance(prompt, str): processor_kwargs: dict[str, Any] = { @@ -530,10 +524,10 @@ def classify(self, prompts: list[str]) -> list[str]: def generate( self, - prompts: Union[list[str], list[list[int]]], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + prompts: list[str] | list[list[int]], + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) @@ -556,11 +550,11 @@ def generate( def generate_greedy( self, - prompts: Union[list[str], list[list[int]]], + prompts: list[str] | list[list[int]], max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[tuple[list[int], str]]: outputs = self.generate( @@ -580,9 +574,9 @@ def generate_beam_search( prompts: list[str], beam_width: int, max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, ) -> list[tuple[list[list[int]], list[str]]]: outputs = self.generate( prompts, @@ -606,9 +600,9 @@ def generate_greedy_logprobs( self, prompts: list[str], max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[list[torch.Tensor]]: all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) @@ -654,7 +648,7 @@ def _hidden_states_to_seq_logprobs( def _hidden_states_to_logprobs( self, hidden_states: tuple[tuple[torch.Tensor, ...], ...], - num_logprobs: Optional[int], + num_logprobs: int | None, ) -> tuple[list[dict[int, float]], int]: seq_logprobs = 
self._hidden_states_to_seq_logprobs(hidden_states) output_len = len(hidden_states) @@ -682,10 +676,10 @@ def generate_greedy_logprobs_limit( self, prompts: list[str], max_tokens: int, - num_logprobs: Optional[int], - images: Optional[PromptImageInput] = None, - audios: Optional[PromptAudioInput] = None, - videos: Optional[PromptVideoInput] = None, + num_logprobs: int | None, + images: PromptImageInput | None = None, + audios: PromptAudioInput | None = None, + videos: PromptVideoInput | None = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: all_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) @@ -759,20 +753,20 @@ def __init__( model_name: str, runner: RunnerOption = "auto", convert: ConvertOption = "auto", - tokenizer_name: Optional[str] = None, + tokenizer_name: str | None = None, tokenizer_mode: str = "auto", trust_remote_code: bool = True, - seed: Optional[int] = 0, - max_model_len: Optional[int] = 1024, + seed: int | None = 0, + max_model_len: int | None = 1024, dtype: str = "auto", disable_log_stats: bool = True, tensor_parallel_size: int = 1, block_size: int = 16 if not torch.xpu.is_available() else 64, - enable_chunked_prefill: Optional[bool] = False, + enable_chunked_prefill: bool | None = False, swap_space: int = 4, - enforce_eager: Optional[bool] = False, + enforce_eager: bool | None = False, # Set this to avoid hanging issue - default_torch_num_threads: Optional[int] = None, + default_torch_num_threads: int | None = None, **kwargs, ) -> None: init_ctx = ( @@ -810,10 +804,10 @@ def __init__( def get_inputs( self, - prompts: Union[list[str], list[torch.Tensor], list[list[int]]], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + prompts: list[str] | list[torch.Tensor] | list[list[int]], + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, ) -> list[dict[str, Any]]: if any(x is not None and len(x) != len(prompts) for x in [images, videos, audios]): raise ValueError("All non-None multimodal inputs must have the same length as prompts") @@ -845,11 +839,11 @@ def get_inputs( def generate( self, - prompts: Union[list[str], list[torch.Tensor], list[list[int]]], + prompts: list[str] | list[torch.Tensor] | list[list[int]], sampling_params: SamplingParams, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) @@ -888,11 +882,11 @@ def generate_w_logprobs( self, prompts: list[str], sampling_params: SamplingParams, - images: Optional[PromptImageInput] = None, - audios: Optional[PromptAudioInput] = None, - videos: Optional[PromptVideoInput] = None, + images: PromptImageInput | None = None, + audios: PromptAudioInput | None = None, + videos: PromptVideoInput | None = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) req_outputs = self.llm.generate(inputs, sampling_params=sampling_params, **kwargs) @@ -907,11 +901,11 @@ def generate_w_logprobs( def generate_greedy( 
self, - prompts: Union[list[str], list[torch.Tensor], list[list[int]]], + prompts: list[str] | list[torch.Tensor] | list[list[int]], max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, **kwargs: Any, ) -> list[tuple[list[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) @@ -929,15 +923,15 @@ def generate_greedy_logprobs( self, prompts: list[str], max_tokens: int, - num_logprobs: Optional[int], - num_prompt_logprobs: Optional[int] = None, - images: Optional[PromptImageInput] = None, - audios: Optional[PromptAudioInput] = None, - videos: Optional[PromptVideoInput] = None, - stop_token_ids: Optional[list[int]] = None, - stop: Optional[list[str]] = None, + num_logprobs: int | None, + num_prompt_logprobs: int | None = None, + images: PromptImageInput | None = None, + audios: PromptAudioInput | None = None, + videos: PromptVideoInput | None = None, + stop_token_ids: list[int] | None = None, + stop: list[str] | None = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]: greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, @@ -968,7 +962,7 @@ def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: perplexities = [] for output in outputs: output = cast(TokensTextLogprobsPromptLogprobs, output) - token_data = cast(list[Optional[dict[int, Logprob]]], output[3]) + token_data = cast(list[dict[int, Logprob] | None], output[3]) assert token_data[0] is None token_log_probs = [] for token_data in token_data[1:]: @@ -987,10 +981,10 @@ def generate_beam_search( prompts: list[str], beam_width: int, max_tokens: int, - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - concurrency_limit: Optional[int] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, + concurrency_limit: int | None = None, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) @@ -1013,9 +1007,9 @@ def classify(self, prompts: list[str]) -> list[list[float]]: def embed( self, prompts: list[str], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, *args, **kwargs, ) -> list[list[float]]: @@ -1024,8 +1018,12 @@ def embed( req_outputs = self.llm.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] - def encode(self, prompts: list[str]) -> list[list[float]]: - req_outputs = self.llm.encode(prompts) + def token_embed(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.encode(prompts, pooling_task="token_embed") + return [req_output.outputs.data for req_output in req_outputs] + + def token_classify(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.llm.encode(prompts, pooling_task="token_classify") return [req_output.outputs.data for req_output in req_outputs] def reward(self, prompts: list[str]) -> 
list[list[float]]: @@ -1034,8 +1032,8 @@ def reward(self, prompts: list[str]) -> list[list[float]]: def score( self, - text_1: Union[str, list[str]], - text_2: Union[str, list[str]], + text_1: list[str] | str, + text_2: list[str] | str, *args, **kwargs, ) -> list[float]: @@ -1078,6 +1076,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): yield caplog +@pytest.fixture() +def caplog_mp_fork(): + """ + This fixture enables capturing logs from a forked MP subprocess. + It should be used in conjunction with caplog_vllm. + + By default, subprocess logs do not go through the parent process. + We instead create a queue listener in the parent process which + forwards logs to the logger's other handlers, and add a QueueHandler + to the root logger. Forked subprocesses will inherit the root logger + and pass their messages to the queue, which the listener will forward + to the root logger, which can be captured by caplog. + + Note that this workaround only works for fork; with spawn, the subprocess + reinitializes logging and does not automatically inherit the queue. + We'd have to manually pass the queue to the subprocess at the spawn point. + See caplog_mp_spawn below. + """ + + @contextlib.contextmanager + def ctx(): + import logging.handlers + import multiprocessing as mp + + logger_queue: mp.Queue[logging.LogRecord] = mp.Queue() + logger = logging.getLogger() + handlers = logger.handlers + + # The listener works on a background thread, not inherited by the child. + queue_listener = logging.handlers.QueueListener(logger_queue, *handlers) + queue_listener.start() + + # Add queue handler after creating the listener to avoid cycle + logger.addHandler(logging.handlers.QueueHandler(logger_queue)) + yield + queue_listener.stop() + + return ctx + + +class LogHolder: + def __init__(self): + self.text = None + + +@pytest.fixture() +def caplog_mp_spawn(tmp_path, monkeypatch): + """ + This fixture enables capturing logs from a forked MP subprocess. + It does not require caplog_vllm (but it only contains logs from the child). + + By default, subprocess logs do not go through the parent process. + We instead add a FileHandler to the config so the spawned child process + writes its logs to a temp file. + In the parent, we read the file and return the contents. 
+ + Note: this method could be extended to fork by either reconfiguring logging + in the parent or using a SocketHandler: + https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501 + """ + + @contextlib.contextmanager + def ctx(level: int | str): + from vllm.logger import DEFAULT_LOGGING_CONFIG + + config_path = tmp_path / "vllm_logging_config.json" + log_path = tmp_path / "vllm.log" + log_holder = LogHolder() + + config = deepcopy(DEFAULT_LOGGING_CONFIG) + if envs.VLLM_LOGGING_CONFIG_PATH: + path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH) + assert path.exists() + config = json.loads(path.read_text()) + + config["loggers"]["vllm"]["handlers"] += ["vllm_file"] + config["handlers"]["vllm_file"] = { + "class": "logging.FileHandler", + "formatter": "vllm", + "level": level, + "filename": log_path.as_posix(), + } + + config_path.write_text(json.dumps(config)) + + with monkeypatch.context() as monkeypatch_ctx: + monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix()) + monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1") + yield log_holder + + log_holder.text = log_path.read_text() + + return ctx + + @pytest.fixture(scope="session") def num_gpus_available(): """Get number of GPUs without initializing the CUDA context @@ -1235,8 +1328,8 @@ def _find_free_port() -> int: class LocalAssetServer: address: str port: int - server: Optional[http.server.ThreadingHTTPServer] - thread: Optional[threading.Thread] + server: http.server.ThreadingHTTPServer | None + thread: threading.Thread | None def __init__(self, address: str = "127.0.0.1") -> None: self.address = address diff --git a/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py b/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py index 587c12628..a2b192357 100644 --- a/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py +++ b/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py @@ -15,7 +15,11 @@ import pytest from vllm.platforms import current_platform -MODELS = ["/data5/yliu7/HF_HOME/unsloth-gpt-oss-20b-BF16-ar-MXFP4/"] +MODELS = [ + # "/data5/yliu7/HF_HOME/unsloth-gpt-oss-20b-BF16-ar-MXFP4/" + # "/data5/yliu7/HF_HOME/Qwen2.5-0.5B-Instruct-test-FP8_STATIC-fp8kv/" + "/data6/yiliu4/Qwen3-15B-A2B-Base-MXFP4-fp8attention" +] @pytest.mark.skipif( @@ -27,4 +31,4 @@ def test_auto_round(vllm_runner, model): with vllm_runner(model, enforce_eager=True) as llm: output = llm.generate_greedy(["The capital of France is"], max_tokens=8) assert output - print(f"{output[0][1]}") + print(f"output is: {output[0][1]}") From 9f65bd1ed6b6d72dc6b8ecb04bb5c41b6f3e09c5 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 24 Oct 2025 04:04:46 -0400 Subject: [PATCH 12/52] fix Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/auto_round_ext.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auto_round_extension/vllm_ext/auto_round_ext.py b/auto_round_extension/vllm_ext/auto_round_ext.py index 69d1d3f92..93cbec42f 100644 --- a/auto_round_extension/vllm_ext/auto_round_ext.py +++ b/auto_round_extension/vllm_ext/auto_round_ext.py @@ -37,6 +37,8 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str): return quant_method elif isinstance(layer, LinearBase): return UnquantizedLinearMethod() + else: + return None @staticmethod def _parse_quant_scheme(config: dict): From 8038a5fabc3daa57c3ec5863a44ba1be10d3f158 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 24 Oct 2025 04:06:44 -0400 Subject: [PATCH 13/52] fix model name Signed-off-by: yiliu30 --- 
auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py b/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py index a2b192357..a2de32832 100644 --- a/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py +++ b/auto_round_extension/vllm_ext/tests/test_mxfp4_moe.py @@ -18,7 +18,7 @@ MODELS = [ # "/data5/yliu7/HF_HOME/unsloth-gpt-oss-20b-BF16-ar-MXFP4/" # "/data5/yliu7/HF_HOME/Qwen2.5-0.5B-Instruct-test-FP8_STATIC-fp8kv/" - "/data6/yiliu4/Qwen3-15B-A2B-Base-MXFP4-fp8attention" + "/data6/yiliu4/Qwen3-15B-A2B-Base-MXFP4" ] From c82bce10d6b66d6cb4ae1004d008ca43b22c4400 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 27 Oct 2025 02:03:16 -0400 Subject: [PATCH 14/52] fix Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/mxfp4_qdq_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py index ad43ac5e2..81f19881e 100644 --- a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py @@ -82,7 +82,7 @@ def to_dtype( last_dim = orig_shape[-1] data_lp = data_lp.reshape(-1, last_dim) result_shape = orig_shape[:-1] + (last_dim * 2,) - assert data_lp.is_contiguous, f"Data must be contiguous, got {data_lp.stride()}" + assert data_lp.is_contiguous(), f"Data must be contiguous, got {data_lp.stride()}" assert elem_dtype == "fp4_e2m1", f"Expected 'fp4_e2m1', got {elem_dtype}" From adf7ebf177dd5cff9b05fbe5ce80bfee22221a2c Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 27 Oct 2025 02:22:02 -0400 Subject: [PATCH 15/52] use absolute path Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/__init__.py | 4 ++-- auto_round_extension/vllm_ext/auto_round_ext.py | 2 +- auto_round_extension/vllm_ext/moe_impl_mxfp4.py | 15 ++++++++------- auto_round_extension/vllm_ext/mxfp4_qdq_utils.py | 5 +++-- auto_round_extension/vllm_ext/quant_method_moe.py | 2 +- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py index 181ab399b..3e145334f 100644 --- a/auto_round_extension/vllm_ext/__init__.py +++ b/auto_round_extension/vllm_ext/__init__.py @@ -20,7 +20,7 @@ def apply(): import vllm.model_executor.layers.quantization.auto_round as auto_round_module - from .auto_round_ext import AutoRoundExtensionConfig + from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig - from .envs_ext import extra_environment_variables + from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables diff --git a/auto_round_extension/vllm_ext/auto_round_ext.py b/auto_round_extension/vllm_ext/auto_round_ext.py index 93cbec42f..cce2c912d 100644 --- a/auto_round_extension/vllm_ext/auto_round_ext.py +++ b/auto_round_extension/vllm_ext/auto_round_ext.py @@ -15,13 +15,13 @@ from typing import Any import torch -from quant_method_moe import AutoRoundMoEMethod from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig from auto_round.schemes import QuantizationScheme +from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod logger = init_logger(__name__) diff 
--git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index 1bfccb623..5938e878e 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -27,8 +27,15 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.utils import set_weight_attrs +import auto_round_extension.vllm_ext.mxfp4_qdq_utils as mxfp4_utils +from auto_round_extension.vllm_ext.mxfp4_qdq_utils import ( + dequant_mxfp4_to_fp8, + mxfp4_gemm_with_unpacked_weight, + run_mxfp4_emulations, +) +from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod + logger = init_logger(__name__) -from quant_method_moe import AutoRoundMoEMethod def apply_act(local_w1_out: torch.Tensor, local_w3_out: torch.Tensor, activation: str) -> torch.Tensor: @@ -173,7 +180,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if envs.VLLM_ENABLE_STATIC_MOE: if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: weight_name_lst = ["w13_weight", "w2_weight"] - from mxfp4_qdq_utils import dequant_mxfp4_to_fp8 for weight_name_prefix in weight_name_lst: weight_name = f"{weight_name_prefix}_packed" @@ -248,7 +254,6 @@ def revert_interleaved_bias(bias): layer.w13_bias.data.copy_(w13_bias_swapped) if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS: - import mxfp4_qdq_utils as mxfp4_utils w1 = layer.w13_weight_packed w1_scale = layer.w13_weight_scale @@ -436,8 +441,6 @@ def apply( local_w3_packed = local_w13_packed[intermediate_size_per_partition:, ...] local_w3_scale = local_w13_scale[intermediate_size_per_partition:, ...] - from mxfp4_qdq_utils import run_mxfp4_emulations - local_w1_bias = None local_w2_bias = None local_w3_bias = None @@ -504,8 +507,6 @@ def apply( local_w3_bias = local_w13_bias[intermediate_size_per_partition:] local_w2_bias = layer.w2_bias[expert_index] - from mxfp4_qdq_utils import mxfp4_gemm_with_unpacked_weight - local_w1_out = mxfp4_gemm_with_unpacked_weight( x=current_state_static, weight_fp8=local_unpacked_w1, diff --git a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py index 81f19881e..60cf02c62 100644 --- a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py @@ -15,8 +15,9 @@ from typing import Union import torch -from fp4_utils import cast_to_fp4, pack_fp4_to_uint8, unpack_fp4_from_uint8 -from utils import _to_mx_rceil, get_fp_scale + +from auto_round_extension.vllm_ext.fp4_utils import cast_to_fp4, pack_fp4_to_uint8, unpack_fp4_from_uint8 +from auto_round_extension.vllm_ext.utils import _to_mx_rceil, get_fp_scale F4_E2M1_MAX = 6.0 F32_EXP_BIAS = 127 diff --git a/auto_round_extension/vllm_ext/quant_method_moe.py b/auto_round_extension/vllm_ext/quant_method_moe.py index b3394f004..ae7b81775 100644 --- a/auto_round_extension/vllm_ext/quant_method_moe.py +++ b/auto_round_extension/vllm_ext/quant_method_moe.py @@ -69,7 +69,7 @@ def get_impl(scheme: QuantizationScheme): return UnquantizedFusedMoEMethod(layer.moe_config) elif _is_mxfp4_w4a4(scheme): - from moe_impl_mxfp4 import AutoRoundMoEMethodMXFp4Impl + from auto_round_extension.vllm_ext.moe_impl_mxfp4 import AutoRoundMoEMethodMXFp4Impl return AutoRoundMoEMethodMXFp4Impl(quant_config, layer.moe_config) From ad8537c61c0f4fc840b2c8da8177bf2136d54f99 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 29 Oct 2025 20:51:00 -0400 Subject: [PATCH 16/52] fix Signed-off-by: yiliu30 --- 
auto_round_extension/vllm_ext/fp4_utils.py | 3 +-- auto_round_extension/vllm_ext/moe_impl_mxfp4.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/auto_round_extension/vllm_ext/fp4_utils.py b/auto_round_extension/vllm_ext/fp4_utils.py index cf8148992..adb67d190 100644 --- a/auto_round_extension/vllm_ext/fp4_utils.py +++ b/auto_round_extension/vllm_ext/fp4_utils.py @@ -51,8 +51,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: indices = indices.reshape(-1) # Handle odd length by padding if necessary - if indices.numel() % 2 != 0: - indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)]) + assert indices.numel() % 2 == 0, f"Expected even number of elements, got {indices.numel()}" # Reshape to pair consecutive elements indices = indices.reshape(-1, 2) diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index 5938e878e..77ec6570d 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -149,7 +149,7 @@ def create_weights( set_weight_attrs(w13_bias, extra_weight_attrs) w2_bias = torch.nn.Parameter( - torch.zeros(num_experts, hidden_size, dtype=bias_dtype), + torch.zeros(E, H, dtype=bias_dtype), requires_grad=False, ) layer.register_parameter("w2_bias", w2_bias) From 77844f6bbbdd7bae1087ac66a3bc242e2fe03b2d Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 29 Oct 2025 23:37:31 -0400 Subject: [PATCH 17/52] mark round method as todo Signed-off-by: yiliu30 --- auto_round_extension/vllm_ext/mxfp4_qdq_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py index 60cf02c62..197d757fa 100644 --- a/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py +++ b/auto_round_extension/vllm_ext/mxfp4_qdq_utils.py @@ -111,6 +111,7 @@ def run_mxfp4_emulations( weight_scale: torch.Tensor, bias: torch.Tensor | None = None, ): + # TODO: select the rounding mode based on config group_size = 32 # quantize input to (FP4 and interleaved block scale) input_scale, x_q = to_mxfp4_rceil( @@ -207,6 +208,7 @@ def fp4_121_scaled_even_rounding(x: torch.Tensor) -> torch.Tensor: def qdq_mxfp4( x: torch.Tensor, ) -> torch.Tensor: + # TODO: select the rounding mode based on config block_size = 32 shape = x.shape x = x.reshape(-1, block_size) From ce985efce355d634cfa30b22504d325f9275f4fa Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 30 Oct 2025 22:41:58 -0400 Subject: [PATCH 18/52] tmp wa for llmc Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 47 ++++++++++++++++++---------------- auto_round/utils/model.py | 3 ++- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 869fc74de..73eb606ba 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2439,20 +2439,21 @@ def _quantize_block( new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device) set_module(block, n, new_layer) - if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): - set_auto_device_map_for_block_with_tuning( - block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale - ) - - if self.device_map is not None: - for n, m in block.named_modules(): - if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): - continue - from accelerate.hooks import AlignDevicesHook, add_hook_to_module - - hook = 
AlignDevicesHook(m.tuning_device, io_same_device=True) - add_hook_to_module(m, hook, True) - + # >>>>>>>>>>>>> @yi + # if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): + # set_auto_device_map_for_block_with_tuning( + # block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale + # ) + + # if self.device_map is not None: + # for n, m in block.named_modules(): + # if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): + # continue + # from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + # hook = AlignDevicesHook(m.tuning_device, io_same_device=True) + # add_hook_to_module(m, hook, True) + # <<<<<<<<<<<< @yi if q_input is None: hook_handles = self._register_act_max_hook(block) @@ -2460,8 +2461,8 @@ def _quantize_block( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device ) - for handle in hook_handles: - handle.remove() + # for handle in hook_handles: + # handle.remove() else: output = self._get_block_outputs( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device @@ -2640,17 +2641,19 @@ def _quantize_block( device, cache_device=self.cache_device, ) - if self.device_map is not None: - accelerate.hooks.remove_hook_from_submodules(block) - mv_module_from_gpu(block) + # @yi + # if self.device_map is not None: + # accelerate.hooks.remove_hook_from_submodules(block) + # mv_module_from_gpu(block) clear_memory(input_ids) return q_outputs, output else: - if self.device_map is not None: - accelerate.hooks.remove_hook_from_submodules(block) - mv_module_from_gpu(block) + # @yi + # if self.device_map is not None: + # accelerate.hooks.remove_hook_from_submodules(block) + # mv_module_from_gpu(block) clear_memory(input_ids) return None, output diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 7a1c66de8..59fe91cd4 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -443,7 +443,8 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]): "pre_mm_projector_norm", "vision", ] - + # FIXME: yi, fix it later + return False model_path = model_or_path if isinstance(model_or_path, str) else model_or_path.name_or_path if not os.path.isdir(model_path): model_path = download_hf_model(model_path) From 8832530c783e307cc1b96a8b6e4226779e213890 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 30 Oct 2025 22:41:58 -0400 Subject: [PATCH 19/52] tmp wa for llmc Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 47 ++++++++++++++++++---------------- auto_round/utils/model.py | 3 ++- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index fe0939c1f..234b4d129 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2446,20 +2446,21 @@ def _quantize_block( new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device) set_module(block, n, new_layer) - if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): - set_auto_device_map_for_block_with_tuning( - block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale - ) - - if self.device_map is not None: - for n, m in block.named_modules(): - if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): - continue - from accelerate.hooks import AlignDevicesHook, add_hook_to_module - - hook = AlignDevicesHook(m.tuning_device, io_same_device=True) - 
add_hook_to_module(m, hook, True) - + # >>>>>>>>>>>>> @yi + # if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): + # set_auto_device_map_for_block_with_tuning( + # block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale + # ) + + # if self.device_map is not None: + # for n, m in block.named_modules(): + # if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): + # continue + # from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + # hook = AlignDevicesHook(m.tuning_device, io_same_device=True) + # add_hook_to_module(m, hook, True) + # <<<<<<<<<<<< @yi if q_input is None: hook_handles = self._register_act_max_hook(block) @@ -2467,8 +2468,8 @@ def _quantize_block( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device ) - for handle in hook_handles: - handle.remove() + # for handle in hook_handles: + # handle.remove() else: output = self._get_block_outputs( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device @@ -2647,17 +2648,19 @@ def _quantize_block( device, cache_device=self.cache_device, ) - if self.device_map is not None: - accelerate.hooks.remove_hook_from_submodules(block) - mv_module_from_gpu(block) + # @yi + # if self.device_map is not None: + # accelerate.hooks.remove_hook_from_submodules(block) + # mv_module_from_gpu(block) clear_memory(input_ids) return q_outputs, output else: - if self.device_map is not None: - accelerate.hooks.remove_hook_from_submodules(block) - mv_module_from_gpu(block) + # @yi + # if self.device_map is not None: + # accelerate.hooks.remove_hook_from_submodules(block) + # mv_module_from_gpu(block) clear_memory(input_ids) return None, output diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 7a1c66de8..59fe91cd4 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -443,7 +443,8 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]): "pre_mm_projector_norm", "vision", ] - + # FIXME: yi, fix it later + return False model_path = model_or_path if isinstance(model_or_path, str) else model_or_path.name_or_path if not os.path.isdir(model_path): model_path = download_hf_model(model_path) From 361491f7655bdfcfef5ee61d0d8de5b9a981237f Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 2 Nov 2025 01:56:13 -0400 Subject: [PATCH 20/52] return ds Signed-off-by: yiliu30 --- auto_round/calib_dataset.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py index a219aa787..f6fa1a237 100644 --- a/auto_round/calib_dataset.py +++ b/auto_round/calib_dataset.py @@ -632,14 +632,7 @@ def select_dataset(dataset, indices): return dataset -def get_dataloader( - tokenizer, - seqlen, - dataset_name="NeelNanda/pile-10k", - seed=42, - bs=8, - nsamples=512, -): +def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512, return_ds=False): """Generate a DataLoader for calibration using specified parameters. 
Args: @@ -851,6 +844,7 @@ def collate_batch(batch): if len(dataset_final) > nsamples: dataset_final = select_dataset(dataset_final, range(nsamples)) - + if return_ds: + return dataset_final calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch) return calib_dataloader From db65d74bd581cfe047e1d1d2145747df7910b11d Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 2 Nov 2025 01:56:46 -0400 Subject: [PATCH 21/52] add more log Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 234b4d129..51a880f51 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -99,6 +99,27 @@ from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block +# function decorator to dump the func time +def time_logger(func: Callable) -> Callable: + """Decorator to log the execution time of a function. + + Args: + func (Callable): The function to be decorated. + + Returns: + Callable: The wrapped function with time logging. + """ + + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + logger.info(f"Function '{func.__name__}' executed in {end_time - start_time:.4f} seconds.") + return result + + return wrapper + + class BaseCompressor(object): """Base compressor for LLM quantization @@ -2420,6 +2441,7 @@ def _get_current_num_elm( current_input_ids = [input_ids[i] for i in indices] return sum(id.numel() for id in current_input_ids) + @time_logger def _quantize_block( self, block: torch.nn.Module, @@ -2566,6 +2588,7 @@ def _quantize_block( best_params = {} total_loss = 0 for i in range(self.iters): + logger.trace(f"Quant block iteration {i}/{self.iters}, best loss so far: {best_loss}") total_loss = 0 if self.sampler == "rand": whole_indices = torch.randperm(nsamples)[:pick_samples] From 60a00232c29467d7dd6a67b4ad2fb4fb67ab8c4a Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 2 Nov 2025 20:37:53 -0500 Subject: [PATCH 22/52] refine code Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 51a880f51..099cc4bfb 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1516,6 +1516,21 @@ def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tens q_inputs = q_inputs.pop(input_id_str[0], None) return inputs, q_inputs + def configure_layer_config(self, enable_gguf_official_mixed: None | bool = False): + self.layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config( + self.model, + self.layer_config, + self.scheme, + self.scale_dtype, + self.supported_types, + self.inner_supported_types, + self.quant_block_list, + self.fp_layers, + self.quant_lm_head, + enable_gguf_official_mixed=enable_gguf_official_mixed, + is_mllm=self.mllm, + ) + def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: """Quantize the model and return the quantized model along with layer configurations.The entry of AutoRound. 
Returns: @@ -1534,20 +1549,8 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: enable_gguf_official_mixed = True else: enable_gguf_official_mixed = False - self.layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config( - self.model, - self.layer_config, - self.scheme, - self.scale_dtype, - self.supported_types, - self.inner_supported_types, - self.quant_block_list, - self.fp_layers, - self.quant_lm_head, - enable_gguf_official_mixed=enable_gguf_official_mixed, - is_mllm=self.mllm, - ) + self.configure_layer_config(enable_gguf_official_mixed=enable_gguf_official_mixed) if not hasattr(self, "formats"): logger.warning("this API is deprecated, please use `quantize_and_save` instead") else: From 7a1716e00e698b33b05ab54758ed47fc97bcef84 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 2 Nov 2025 19:31:02 -0800 Subject: [PATCH 23/52] refactor Signed-off-by: yi --- auto_round/calib_dataset.py | 35 ++++++++++++++++++++++++++-------- auto_round/compressors/base.py | 9 +++++---- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py index f6fa1a237..aeda07fb0 100644 --- a/auto_round/calib_dataset.py +++ b/auto_round/calib_dataset.py @@ -632,8 +632,8 @@ def select_dataset(dataset, indices): return dataset -def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512, return_ds=False): - """Generate a DataLoader for calibration using specified parameters. +def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512): + """Generate a dataset for calibration. Args: tokenizer (Tokenizer): The tokenizer to use for tokenization. @@ -648,7 +648,7 @@ def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42 apply_chat_template: Whether to apply chat template in tokenization. Returns: - DataLoader: The DataLoader for the calibrated dataset. + Dataset: The processed dataset ready for calibration. """ dataset_names = dataset_name.split(",") @@ -816,7 +816,29 @@ def concat_dataset_element(dataset): else: dataset_final = datasets[0] - # dataset_final = datasets[0] + if len(dataset_final) > nsamples: + dataset_final = select_dataset(dataset_final, range(nsamples)) + return dataset_final + + +def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512): + """Generate a DataLoader for calibration using specified parameters. + + Args: + tokenizer (Tokenizer): The tokenizer to use for tokenization. + seqlen (int): The exact sequence length. samples < seqlen will be dropped, + samples longer than seqlen will be truncated + dataset_name (str, optional): The name of the dataset or datasets separated by commas. + Defaults to "NeelNanda/pile-10k". + split (str, optional): The data split to use. Defaults to None. + seed (int, optional): The random seed for reproducibility. Defaults to 42. + bs (int, optional): The batch size. Defaults to 4. + nsamples (int, optional): The total number of samples to include. Defaults to 512. + apply_chat_template: Whether to apply chat template in tokenization. + + Returns: + DataLoader: The DataLoader for the calibrated dataset. 
+ """ @torch.no_grad() def collate_batch(batch): @@ -842,9 +864,6 @@ def collate_batch(batch): res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new} return res - if len(dataset_final) > nsamples: - dataset_final = select_dataset(dataset_final, range(nsamples)) - if return_ds: - return dataset_final + dataset_final = get_dataset(tokenizer, seqlen, dataset_name, seed, bs, nsamples) calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch) return calib_dataloader diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 099cc4bfb..8feb8aa07 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -372,7 +372,8 @@ def __init__( # Some helpers if "hpu" in str(self.device): self.inner_supported_types = tuple(x for x in INNER_SUPPORTED_LAYER_TYPES if x != "FP8Linear") - self.batch_dim = None + # TODO: check with heng/weiwei + self.batch_dim = 0 self.infer_bs_coeff = 1 self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward @@ -2445,7 +2446,7 @@ def _get_current_num_elm( return sum(id.numel() for id in current_input_ids) @time_logger - def _quantize_block( + def quantize_block( self, block: torch.nn.Module, input_ids: Union[list[torch.Tensor], dict], @@ -2762,9 +2763,9 @@ def _quantize_blocks( else: logger.info("using algorithm extension for quantization.") except (ImportError, ModuleNotFoundError): - quantize_block = self._quantize_block + quantize_block = self.quantize_block else: - quantize_block = self._quantize_block + quantize_block = self.quantize_block if pbar is None: pbar = tqdm(range(0, len(block_names), nblocks)) From a20f9df7394663e103b0a35f4bab408fc492d079 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 3 Nov 2025 19:22:39 -0800 Subject: [PATCH 24/52] refactor Signed-off-by: root --- auto_round/compressors/base.py | 17 ++++++++++------- auto_round/compressors/utils.py | 2 +- auto_round/utils/common.py | 25 +++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 8feb8aa07..fa27ed18a 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -20,7 +20,7 @@ import traceback from collections import defaultdict from dataclasses import asdict, fields -from typing import Any, Callable, Union +from typing import Any, Callable, Union, Optional import accelerate import torch @@ -90,6 +90,7 @@ to_device, to_dtype, unsupported_meta_device, + normalize_input, ) from auto_round.utils.device import ( get_major_device, @@ -2449,9 +2450,9 @@ def _get_current_num_elm( def quantize_block( self, block: torch.nn.Module, - input_ids: Union[list[torch.Tensor], dict], - input_others: dict, + inputs: tuple[Union[list[torch.Tensor], dict, Any], Optional[dict]], q_input: Union[torch.Tensor, dict, None] = None, + normalize_inputs: bool = False, device: Union[str, torch.device] = "cpu", ): """Quantize the weights of a given block of the model. 
@@ -2471,7 +2472,10 @@ def quantize_block( if is_fp8_linear(m): new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device) set_module(block, n, new_layer) - + if normalize_inputs: + input_ids, input_others = normalize_input(inputs) + else: + input_ids, input_others = inputs # >>>>>>>>>>>>> @yi # if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): # set_auto_device_map_for_block_with_tuning( @@ -2615,7 +2619,7 @@ def quantize_block( else: tmp_attention_mask = 1.0 if self.amp: - with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): + with autocast(device_type=str(device).split(":")[0], dtype=self.amp_dtype): loss = mse_loss( # pylint: disable=not-callable output_q * tmp_attention_mask, current_output * tmp_attention_mask ) @@ -2786,8 +2790,7 @@ def _quantize_blocks( m = m.to(device) q_input, input_ids = quantize_block( m, - input_ids, - input_others, + (input_ids,input_others), q_input=q_input, device=device, ) diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 8a8fe602c..5b5fe12dc 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -111,7 +111,7 @@ def block_forward( alibi = input_others["alibi"] input_others["alibi"] = alibi.reshape(-1, alibi.shape[2], alibi.shape[3]) if amp: - with autocast(device_type=device.split(":")[0], dtype=amp_dtype): # pragma: no cover + with autocast(device_type=str(device).split(":")[0], dtype=amp_dtype): # pragma: no cover output = block(input_ids, *input_tuple, **input_others) else: output = block(input_ids, *input_tuple, **input_others) diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index dea3f8c81..ef2a2c939 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -297,3 +297,28 @@ def get_reciprocal(tensor): else: tensor = torch.where(torch.abs(tensor) < 1e-30, 0, tensor) return torch.where(tensor != 0, 1 / tensor, torch.zeros_like(tensor)) + +def normalize_input(cur_inputs): + # TODO: move it to auto-round + input_ids = [] + input_others = {} + positional_inputs = [] + attention_mask = None + position_ids = None + cache_position = None + position_embeddings = (None, None) + for cur_inp in cur_inputs: + input_ids.append(cur_inp[0][0][0]) + for key, val in cur_inp[0][1].items(): + if key == "position_ids": + position_ids = val + elif key == "position_embeddings": + position_embeddings = val + elif key == "cache_position": + cache_position = val + input_others["position_ids"] = position_ids + input_others["positional_inputs"] = positional_inputs + input_others["attention_mask"] = attention_mask + input_others["position_embeddings"] = position_embeddings + input_others["cache_position"] = cache_position + return input_ids, input_others From 553ee5c834e92004be087144f978ca20d73304ae Mon Sep 17 00:00:00 2001 From: root Date: Tue, 4 Nov 2025 17:20:08 -0800 Subject: [PATCH 25/52] fix offloaf Signed-off-by: root --- auto_round/compressors/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index fa27ed18a..6b0a617c7 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2454,6 +2454,7 @@ def quantize_block( q_input: Union[torch.Tensor, dict, None] = None, normalize_inputs: bool = False, device: Union[str, torch.device] = "cpu", + auto_offload=True ): """Quantize the weights of a given block of the model. 
From 2bd3c4b196949af92cacfcbe2ce639bafab5b890 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 4 Nov 2025 19:38:38 -0800 Subject: [PATCH 26/52] fix Signed-off-by: root --- auto_round/compressors/base.py | 58 +++++++++++++++------------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 6b0a617c7..dc1f96be1 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2477,30 +2477,30 @@ def quantize_block( input_ids, input_others = normalize_input(inputs) else: input_ids, input_others = inputs - # >>>>>>>>>>>>> @yi - # if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): - # set_auto_device_map_for_block_with_tuning( - # block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale - # ) - - # if self.device_map is not None: - # for n, m in block.named_modules(): - # if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): - # continue - # from accelerate.hooks import AlignDevicesHook, add_hook_to_module - - # hook = AlignDevicesHook(m.tuning_device, io_same_device=True) - # add_hook_to_module(m, hook, True) - # <<<<<<<<<<<< @yi + if auto_offload: + if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): + set_auto_device_map_for_block_with_tuning( + block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale + ) + + if self.device_map is not None: + for n, m in block.named_modules(): + if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): + continue + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + hook = AlignDevicesHook(m.tuning_device, io_same_device=True) + add_hook_to_module(m, hook, True) + if q_input is None: hook_handles = self._register_act_max_hook(block) output = self._get_block_outputs( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device ) - - # for handle in hook_handles: - # handle.remove() + if auto_offload: + for handle in hook_handles: + handle.remove() else: output = self._get_block_outputs( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device @@ -2669,7 +2669,7 @@ def quantize_block( if is_nv_fp(self.act_data_type): # enable moe experts act_max automatic generation for WrapperWALayer set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max") - + q_outputs = None if self.enable_quanted_input: clear_memory() q_outputs = self._get_block_outputs( @@ -2680,21 +2680,13 @@ def quantize_block( device, cache_device=self.cache_device, ) - # @yi - # if self.device_map is not None: - # accelerate.hooks.remove_hook_from_submodules(block) - # mv_module_from_gpu(block) - clear_memory(input_ids) - - return q_outputs, output + if auto_offload: + if self.device_map is not None: + accelerate.hooks.remove_hook_from_submodules(block) + mv_module_from_gpu(block) + clear_memory(input_ids) - else: - # @yi - # if self.device_map is not None: - # accelerate.hooks.remove_hook_from_submodules(block) - # mv_module_from_gpu(block) - clear_memory(input_ids) - return None, output + return q_outputs, output def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]: input_ids = inputs["input_ids"] From b992c3196a9acdfa08f14fa27360dbee56e0d239 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 5 Nov 2025 00:10:39 -0500 Subject: [PATCH 27/52] remove time Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 30 
++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index dc1f96be1..fc5e61871 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -20,7 +20,7 @@ import traceback from collections import defaultdict from dataclasses import asdict, fields -from typing import Any, Callable, Union, Optional +from typing import Any, Callable, Optional, Union import accelerate import torch @@ -85,12 +85,12 @@ is_hpex_available, llm_load_model, mv_module_from_gpu, + normalize_input, set_amax_for_all_moe_layers, set_module, to_device, to_dtype, unsupported_meta_device, - normalize_input, ) from auto_round.utils.device import ( get_major_device, @@ -100,27 +100,6 @@ from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block -# function decorator to dump the func time -def time_logger(func: Callable) -> Callable: - """Decorator to log the execution time of a function. - - Args: - func (Callable): The function to be decorated. - - Returns: - Callable: The wrapped function with time logging. - """ - - def wrapper(*args, **kwargs): - start_time = time.time() - result = func(*args, **kwargs) - end_time = time.time() - logger.info(f"Function '{func.__name__}' executed in {end_time - start_time:.4f} seconds.") - return result - - return wrapper - - class BaseCompressor(object): """Base compressor for LLM quantization @@ -2446,7 +2425,6 @@ def _get_current_num_elm( current_input_ids = [input_ids[i] for i in indices] return sum(id.numel() for id in current_input_ids) - @time_logger def quantize_block( self, block: torch.nn.Module, @@ -2454,7 +2432,7 @@ def quantize_block( q_input: Union[torch.Tensor, dict, None] = None, normalize_inputs: bool = False, device: Union[str, torch.device] = "cpu", - auto_offload=True + auto_offload=True, ): """Quantize the weights of a given block of the model. 
@@ -2783,7 +2761,7 @@ def _quantize_blocks( m = m.to(device) q_input, input_ids = quantize_block( m, - (input_ids,input_others), + (input_ids, input_others), q_input=q_input, device=device, ) From 0354c2ba5b1e7aed80005007f4a0530e9d06ea96 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 5 Nov 2025 00:11:37 -0500 Subject: [PATCH 28/52] update Signed-off-by: yiliu30 --- auto_round/utils/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index ef2a2c939..cd495e4a5 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -298,8 +298,8 @@ def get_reciprocal(tensor): tensor = torch.where(torch.abs(tensor) < 1e-30, 0, tensor) return torch.where(tensor != 0, 1 / tensor, torch.zeros_like(tensor)) + def normalize_input(cur_inputs): - # TODO: move it to auto-round input_ids = [] input_others = {} positional_inputs = [] From e2e2d42b69bb39f23c418ef74194eccf5ad7db2a Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Thu, 6 Nov 2025 00:06:27 -0800 Subject: [PATCH 29/52] add batch dim Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index fc5e61871..3b9b8a332 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -353,7 +353,7 @@ def __init__( if "hpu" in str(self.device): self.inner_supported_types = tuple(x for x in INNER_SUPPORTED_LAYER_TYPES if x != "FP8Linear") # TODO: check with heng/weiwei - self.batch_dim = 0 + self.batch_dim = kwargs.pop("batch_dim", None) self.infer_bs_coeff = 1 self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward From b01c91ff8d3177ba3c3e30c9ff5add8c7f483f40 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 7 Nov 2025 19:53:37 -0800 Subject: [PATCH 30/52] update Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 3b9b8a332..1bcd2a4cd 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -238,7 +238,7 @@ def __init__( self.supported_types = SUPPORTED_LAYER_TYPES self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES self.scale_dtype = convert_dtype_str2torch(scale_dtype) - + self.batch_dim = kwargs.pop("batch_dim", None) if kwargs: logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. 
Please check them.") if "CUBLAS_WORKSPACE_CONFIG" not in os.environ: @@ -352,8 +352,6 @@ def __init__( # Some helpers if "hpu" in str(self.device): self.inner_supported_types = tuple(x for x in INNER_SUPPORTED_LAYER_TYPES if x != "FP8Linear") - # TODO: check with heng/weiwei - self.batch_dim = kwargs.pop("batch_dim", None) self.infer_bs_coeff = 1 self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward From 6800cfe06158f8fbe5035ed3e2698969b169c2f5 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 17:04:28 -0800 Subject: [PATCH 31/52] add hints Signed-off-by: yiliu30 --- auto_round/utils/common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index cd495e4a5..2567b0c1f 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -299,7 +299,8 @@ def get_reciprocal(tensor): return torch.where(tensor != 0, 1 / tensor, torch.zeros_like(tensor)) -def normalize_input(cur_inputs): +def normalize_input(decoding_layer_inputs: list[tuple[Any]]) -> Tuple[List[torch.Tensor], Dict[str, Any]]: + """Normalize the decoding layer inputs into input_ids and other inputs.""" input_ids = [] input_others = {} positional_inputs = [] @@ -307,7 +308,7 @@ def normalize_input(cur_inputs): position_ids = None cache_position = None position_embeddings = (None, None) - for cur_inp in cur_inputs: + for cur_inp in decoding_layer_inputs: input_ids.append(cur_inp[0][0][0]) for key, val in cur_inp[0][1].items(): if key == "position_ids": From fe53fb3abcae5881b6d07946e6d805e2d3957f36 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 18:53:13 -0800 Subject: [PATCH 32/52] clean Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 25 ++++++++++++++++--------- auto_round/utils/common.py | 19 +++---------------- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 1bcd2a4cd..6b458b053 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2428,7 +2428,18 @@ def quantize_block( block: torch.nn.Module, inputs: tuple[Union[list[torch.Tensor], dict, Any], Optional[dict]], q_input: Union[torch.Tensor, dict, None] = None, - normalize_inputs: bool = False, + device: Union[str, torch.device] = "cpu", + auto_offload=True, + ): + input_ids, input_others = normalize_input(inputs) + return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload) + + def _quantize_block( + self, + block: torch.nn.Module, + input_ids: Union[list[torch.Tensor], dict], + input_others: dict, + q_input: Union[torch.Tensor, dict, None] = None, device: Union[str, torch.device] = "cpu", auto_offload=True, ): @@ -2449,10 +2460,6 @@ def quantize_block( if is_fp8_linear(m): new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device) set_module(block, n, new_layer) - if normalize_inputs: - input_ids, input_others = normalize_input(inputs) - else: - input_ids, input_others = inputs if auto_offload: if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map): set_auto_device_map_for_block_with_tuning( @@ -2573,7 +2580,6 @@ def quantize_block( best_params = {} total_loss = 0 for i in range(self.iters): - logger.trace(f"Quant block iteration {i}/{self.iters}, best loss so far: {best_loss}") total_loss = 0 if self.sampler == "rand": whole_indices = torch.randperm(nsamples)[:pick_samples] @@ -2736,9 +2742,9 @@ def 
_quantize_blocks( else: logger.info("using algorithm extension for quantization.") except (ImportError, ModuleNotFoundError): - quantize_block = self.quantize_block + quantize_block = self._quantize_block else: - quantize_block = self.quantize_block + quantize_block = self._quantize_block if pbar is None: pbar = tqdm(range(0, len(block_names), nblocks)) @@ -2759,7 +2765,8 @@ def _quantize_blocks( m = m.to(device) q_input, input_ids = quantize_block( m, - (input_ids, input_others), + input_ids, + input_others, q_input=q_input, device=device, ) diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index 2567b0c1f..a39364288 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -303,23 +303,10 @@ def normalize_input(decoding_layer_inputs: list[tuple[Any]]) -> Tuple[List[torch """Normalize the decoding layer inputs into input_ids and other inputs.""" input_ids = [] input_others = {} - positional_inputs = [] - attention_mask = None - position_ids = None - cache_position = None - position_embeddings = (None, None) + input_others["positional_inputs"] = [] for cur_inp in decoding_layer_inputs: input_ids.append(cur_inp[0][0][0]) for key, val in cur_inp[0][1].items(): - if key == "position_ids": - position_ids = val - elif key == "position_embeddings": - position_embeddings = val - elif key == "cache_position": - cache_position = val - input_others["position_ids"] = position_ids - input_others["positional_inputs"] = positional_inputs - input_others["attention_mask"] = attention_mask - input_others["position_embeddings"] = position_embeddings - input_others["cache_position"] = cache_position + input_others[key] = val + return input_ids, input_others From 4e72f6f0f53d6c9693f95a796428374b86dc0a8a Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 18:55:25 -0800 Subject: [PATCH 33/52] refine Signed-off-by: yiliu30 --- auto_round/calib_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py index aeda07fb0..25c4cd21e 100644 --- a/auto_round/calib_dataset.py +++ b/auto_round/calib_dataset.py @@ -632,7 +632,7 @@ def select_dataset(dataset, indices): return dataset -def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512): +def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512): """Generate a dataset for calibration. Args: @@ -643,7 +643,6 @@ def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, b Defaults to "NeelNanda/pile-10k". split (str, optional): The data split to use. Defaults to None. seed (int, optional): The random seed for reproducibility. Defaults to 42. - bs (int, optional): The batch size. Defaults to 4. nsamples (int, optional): The total number of samples to include. Defaults to 512. apply_chat_template: Whether to apply chat template in tokenization. 
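
The `normalize_input` cleanup above stops whitelisting `position_ids`, `attention_mask`, and friends and instead forwards every keyword argument captured from the decoder-layer call. A condensed, self-contained sketch of that behavior follows; the `((args, kwargs),)` entry shape is inferred from the `cur_inp[0][0][0]` / `cur_inp[0][1]` indexing in the diff rather than spelled out there.

    from typing import Any

    import torch


    def normalize_input(decoding_layer_inputs: list) -> tuple[list[torch.Tensor], dict[str, Any]]:
        """Flatten captured decoder-layer calls into hidden states plus shared keyword inputs."""
        input_ids: list[torch.Tensor] = []
        input_others: dict[str, Any] = {"positional_inputs": []}
        for cur_inp in decoding_layer_inputs:
            args, kwargs = cur_inp[0]      # each entry wraps one forward call as (args, kwargs)
            input_ids.append(args[0])      # hidden states are the first positional argument
            input_others.update(kwargs)    # keep every keyword input (attention_mask, position_ids, ...)
        return input_ids, input_others


    # Example with a single captured call:
    hidden = torch.randn(1, 8, 16)
    captured = [(((hidden,), {"attention_mask": None, "use_cache": False}),)]
    ids, others = normalize_input(captured)
    assert ids[0] is hidden and others["use_cache"] is False
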
From b9dd50f1f152907ba9d3f0afb55cb62261a1397e Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 19:23:59 -0800 Subject: [PATCH 34/52] add auto offload back Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 55 ++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index da5ec02ca..deca482ea 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2512,24 +2512,26 @@ def _quantize_block( if is_fp8_linear(m): new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device) set_module(block, n, new_layer) - # card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights - # loss_device is used to calculate loss on the second device if available and card_0_in_high_risk - if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)): - card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( - block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device - ) + if auto_offload: + # card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights + # loss_device is used to calculate loss on the second device if available and card_0_in_high_risk + if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)): + card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( + block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device + ) + else: + block = block.to(device) + card_0_in_high_risk, loss_device = False, device else: - block = block.to(device) card_0_in_high_risk, loss_device = False, device - if len(self.device_list) > 1: + if len(self.device_list) > 1 and auto_offload: for n, m in block.named_modules(): if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): continue from accelerate.hooks import AlignDevicesHook, add_hook_to_module - - hook = AlignDevicesHook(m.tuning_device, io_same_device=True) - add_hook_to_module(m, hook, True) + hook = AlignDevicesHook(m.tuning_device, io_same_device=True) + add_hook_to_module(m, hook, True) if q_input is None: hook_handles = self._register_act_max_hook(block) @@ -2537,9 +2539,8 @@ def _quantize_block( output = self._get_block_outputs( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device ) - if auto_offload: - for handle in hook_handles: - handle.remove() + for handle in hook_handles: + handle.remove() else: output = self._get_block_outputs( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device @@ -2664,7 +2665,7 @@ def _quantize_block( else: tmp_attention_mask = 1.0 if self.amp: - with autocast(device_type=loss_device.split(":")[0], dtype=self.amp_dtype): + with autocast(device_type=str(loss_device).split(":")[0], dtype=self.amp_dtype): loss = mse_loss( # pylint: disable=not-callable output_q * tmp_attention_mask, current_output * tmp_attention_mask ) @@ -2673,9 +2674,11 @@ def _quantize_block( output_q.to(torch.float32) * tmp_attention_mask, current_output.to(torch.float32) * tmp_attention_mask, ) + total_loss += loss.item() / num_elm + if self.low_gpu_mem_usage and card_0_in_high_risk: # clear memory to avoid OOM due to memory fragmentation - + clear_memory_if_reached_threshold(threshold=0.5, device_list=self.device_list) self._scale_loss_and_backward(scaler, loss) if 
self.low_gpu_mem_usage and card_0_in_high_risk: @@ -2731,19 +2734,25 @@ def _quantize_block( cache_device=self.cache_device, ) - if len(self.device_list) > 1: + if len(self.device_list) > 1 and auto_offload: accelerate.hooks.remove_hook_from_submodules(block) - mv_module_from_gpu(block) - clear_memory(input_ids) + mv_module_from_gpu(block) + clear_memory(input_ids) + return q_outputs, output else: - if len(self.device_list) > 1: + if len(self.device_list) > 1 and auto_offload: accelerate.hooks.remove_hook_from_submodules(block) - mv_module_from_gpu(block) + mv_module_from_gpu(block) clear_memory(input_ids) return None, output - def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]: + input_ids = inputs["input_ids"] + inputs.pop("input_ids", None) + input_others = inputs + return input_ids, input_others + + def _quantize_blocks( self, model: torch.nn.Module, inputs: dict, @@ -2752,7 +2761,7 @@ def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]: nblocks: int = 1, device: str = "cpu", pbar: tqdm = None, - ): + ) -> tuple[torch.Tensor, dict]: """Quantize and dequantize the weights of the specified blocks in the model. Args: From e566783d9b39fd8b7948248f78ca8dc66ae41aba Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 19:26:06 -0800 Subject: [PATCH 35/52] clean Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index deca482ea..a8b87e327 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2674,11 +2674,13 @@ def _quantize_block( output_q.to(torch.float32) * tmp_attention_mask, current_output.to(torch.float32) * tmp_attention_mask, ) + total_loss += loss.item() / num_elm if self.low_gpu_mem_usage and card_0_in_high_risk: # clear memory to avoid OOM due to memory fragmentation clear_memory_if_reached_threshold(threshold=0.5, device_list=self.device_list) + self._scale_loss_and_backward(scaler, loss) if self.low_gpu_mem_usage and card_0_in_high_risk: @@ -2746,7 +2748,7 @@ def _quantize_block( clear_memory(input_ids) return None, output - def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]: + def _split_inputs(self, inputs: dict): input_ids = inputs["input_ids"] inputs.pop("input_ids", None) input_others = inputs From 329f67de39e4e7175eec033292879b584b367fea Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 19:27:41 -0800 Subject: [PATCH 36/52] fix Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index a8b87e327..dd43c5eb3 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2530,6 +2530,7 @@ def _quantize_block( if len(list(m.children())) != 0 or not hasattr(m, "tuning_device"): continue from accelerate.hooks import AlignDevicesHook, add_hook_to_module + hook = AlignDevicesHook(m.tuning_device, io_same_device=True) add_hook_to_module(m, hook, True) @@ -2539,6 +2540,7 @@ def _quantize_block( output = self._get_block_outputs( block, input_ids, input_others, self.batch_size * self.infer_bs_coeff, device, self.cache_device ) + for handle in hook_handles: handle.remove() else: @@ -2725,7 +2727,7 @@ def _quantize_block( if is_nv_fp(self.act_data_type): # enable moe experts act_max automatic generation for WrapperWALayer set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max") - q_outputs 
= None + if self.enable_quanted_input: q_outputs = self._get_block_outputs( block, From da6fe1654aff71e2ba9d3e6acd8f44dd9dd0d017 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 19:28:55 -0800 Subject: [PATCH 37/52] clean Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index dd43c5eb3..bc2f52cae 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2741,7 +2741,9 @@ def _quantize_block( if len(self.device_list) > 1 and auto_offload: accelerate.hooks.remove_hook_from_submodules(block) mv_module_from_gpu(block) + clear_memory(input_ids) + return q_outputs, output else: if len(self.device_list) > 1 and auto_offload: From 70f54fe00fda51e12a95174c70ecffbefdebb37a Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 19:30:20 -0800 Subject: [PATCH 38/52] revert Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index bc2f52cae..ff81b4e36 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2752,7 +2752,7 @@ def _quantize_block( clear_memory(input_ids) return None, output - def _split_inputs(self, inputs: dict): + def _split_inputs(self, inputs: dict) -> tuple[torch.Tensor, dict]: input_ids = inputs["input_ids"] inputs.pop("input_ids", None) input_others = inputs From a7be1dd35c945235952b6ffee10bbd12eedd810e Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 19:32:36 -0800 Subject: [PATCH 39/52] revert Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index ff81b4e36..ebebcfa68 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2767,7 +2767,7 @@ def _quantize_blocks( nblocks: int = 1, device: str = "cpu", pbar: tqdm = None, - ) -> tuple[torch.Tensor, dict]: + ): """Quantize and dequantize the weights of the specified blocks in the model. 
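
Among the smaller fixes folded into the hunks above, the `autocast` call now derives its `device_type` from `str(loss_device).split(":")[0]`, so it works whether `loss_device` is a string such as `"cuda:0"` or a `torch.device`. A minimal runnable illustration of the same normalization, using `torch.device(...).type` as an equivalent alternative:

    import torch


    def autocast_device_type(device) -> str:
        """Return the bare device type ('cuda', 'cpu', ...) accepted by torch.autocast."""
        # Equivalent to str(device).split(":")[0], but also validates the input.
        return torch.device(device).type


    assert autocast_device_type("cuda:0") == "cuda"
    assert autocast_device_type(torch.device("cpu")) == "cpu"

    with torch.autocast(device_type=autocast_device_type("cpu"), dtype=torch.bfloat16):
        _ = torch.randn(4, 4) @ torch.randn(4, 4)
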
Args: From 69ff60bc3e07634785f7c4b115a77d0256d7fb94 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 19:49:56 -0800 Subject: [PATCH 40/52] remove hard model name path Signed-off-by: yiliu30 --- auto_round/utils/model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index 9d3e5301e..5f9e1941b 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -514,8 +514,6 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module], platform: str = No "pre_mm_projector_norm", "vision", ] - # FIXME: yi, fix it later - return False model_path = model_or_path if isinstance(model_or_path, str) else model_or_path.name_or_path if not os.path.isdir(model_path): model_path = download_or_get_path(model_path, platform=platform) From f3f41a7c8bfe7c48d289420f9331ba17c11a5277 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 20:57:48 -0800 Subject: [PATCH 41/52] fix Signed-off-by: yiliu30 --- auto_round/calib_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py index 25c4cd21e..96d80f7ea 100644 --- a/auto_round/calib_dataset.py +++ b/auto_round/calib_dataset.py @@ -863,6 +863,6 @@ def collate_batch(batch): res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new} return res - dataset_final = get_dataset(tokenizer, seqlen, dataset_name, seed, bs, nsamples) + dataset_final = get_dataset(tokenizer, seqlen, dataset_name, seed, nsamples) calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch) return calib_dataloader From a4581fe5b11350b338391a99d55e53c83770cd32 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 21:00:08 -0800 Subject: [PATCH 42/52] revert Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index ebebcfa68..83c03cfb1 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -249,7 +249,6 @@ def __init__( self.supported_types = SUPPORTED_LAYER_TYPES self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES self.scale_dtype = convert_dtype_str2torch(scale_dtype) - self.batch_dim = kwargs.pop("batch_dim", None) self.low_cpu_mem_usage = low_cpu_mem_usage if kwargs: From 79added74c109fa2396a76232f397cee414615f7 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 21:02:46 -0800 Subject: [PATCH 43/52] update Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 83c03cfb1..cd8f0b2ff 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1528,7 +1528,7 @@ def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tens q_inputs = q_inputs.pop(input_id_str[0], None) return inputs, q_inputs - def configure_layer_config(self, enable_gguf_official_mixed: None | bool = False): + def configure_layer_config(self, enable_gguf_official_mixed: None | bool = True): self.layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config( self.model, self.layer_config, From 67dba40a487367ec9ab32fd3b6a5cbd14c80faa4 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 21:09:31 -0800 Subject: [PATCH 44/52] revert Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/auto_round/compressors/base.py b/auto_round/compressors/base.py index cd8f0b2ff..cbd70415c 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -367,6 +367,7 @@ def __init__( # Some helpers if "hpu" in str(self.device): self.inner_supported_types = tuple(x for x in INNER_SUPPORTED_LAYER_TYPES if x != "FP8Linear") + self.batch_dim = None self.infer_bs_coeff = 1 self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward From 5ad35e1715a86e1a4157ecadf7d9b9a84522a870 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 21:43:57 -0800 Subject: [PATCH 45/52] fix use cache Signed-off-by: yiliu30 --- auto_round/utils/common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index a39364288..9d4e4c98a 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -308,5 +308,8 @@ def normalize_input(decoding_layer_inputs: list[tuple[Any]]) -> Tuple[List[torch input_ids.append(cur_inp[0][0][0]) for key, val in cur_inp[0][1].items(): input_others[key] = val - + # Force 'use_cache' to be False + if "use_cache" in input_others and input_others["use_cache"] is True: + logger.warning_once("Forcing 'use_cache' to be False during calibration.") + input_others["use_cache"] = False return input_ids, input_others From 021662bdd3c84f1a4e455b348a648f1c720cf1be Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 21:49:14 -0800 Subject: [PATCH 46/52] add more checker Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index cbd70415c..3c40cdeb1 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2483,6 +2483,11 @@ def quantize_block( device: Union[str, torch.device] = "cpu", auto_offload=True, ): + # TODO: relase below assertion after supporting MLLM and diffusion model quantization with quantize_block + assert self.__class__.__name__ not in [ + "DiffusionCompressor", + "MLLMCompressor", + ], f"Currently, {self.__class__.__name__} does not support support quantize block with this function." 
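
The `use_cache` fix above forces caching off while blocks are replayed for calibration, presumably so that unused KV caches neither waste memory nor leak state between tuning iterations. A standalone sketch of the same guard, with the stdlib `logging` module standing in for auto_round's `logger.warning_once`:

    import logging

    logger = logging.getLogger(__name__)


    def disable_kv_cache(input_others: dict) -> dict:
        """Force 'use_cache' off before replaying blocks for calibration."""
        if input_others.get("use_cache") is True:
            logger.warning("Forcing 'use_cache' to be False during calibration.")
            input_others["use_cache"] = False
        return input_others


    assert disable_kv_cache({"use_cache": True})["use_cache"] is False
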
input_ids, input_others = normalize_input(inputs) return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload) From e5b58abb84b781d4c7d1c651a7d1404ad226fbcf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 05:49:38 +0000 Subject: [PATCH 47/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 3c40cdeb1..75eaf0c75 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2483,7 +2483,7 @@ def quantize_block( device: Union[str, torch.device] = "cpu", auto_offload=True, ): - # TODO: relase below assertion after supporting MLLM and diffusion model quantization with quantize_block + # TODO: release below assertion after supporting MLLM and diffusion model quantization with quantize_block assert self.__class__.__name__ not in [ "DiffusionCompressor", "MLLMCompressor", From 5d78028fc5e65ae092fbf049b38ae94ff2f23831 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 22:14:58 -0800 Subject: [PATCH 48/52] add more docs Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 3c40cdeb1..4c798db25 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2483,6 +2483,13 @@ def quantize_block( device: Union[str, torch.device] = "cpu", auto_offload=True, ): + """ + This function quantizes a specific decoded block of a model. + It is primarily used by LLM-Compressor. For more details, please refer to the following PR: + https://github.com/vllm-project/llm-compressor/pull/1994 + """ + + # TODO: relase below assertion after supporting MLLM and diffusion model quantization with quantize_block assert self.__class__.__name__ not in [ "DiffusionCompressor", From 4816a85ac422bb7f18051b19fa4054456483d9fa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 06:18:25 +0000 Subject: [PATCH 49/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 8d0320a8d..551e9f93a 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2488,7 +2488,7 @@ def quantize_block( It is primarily used by LLM-Compressor. 
For more details, please refer to the following PR: https://github.com/vllm-project/llm-compressor/pull/1994 """ - + # TODO: release below assertion after supporting MLLM and diffusion model quantization with quantize_block assert self.__class__.__name__ not in [ "DiffusionCompressor", From 073132da7cb23097d57852daf1d2bffdcc0c1bc6 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 9 Nov 2025 22:23:44 -0800 Subject: [PATCH 50/52] fix Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 8d0320a8d..ff77f456c 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2751,6 +2751,7 @@ def _quantize_block( if len(self.device_list) > 1 and auto_offload: accelerate.hooks.remove_hook_from_submodules(block) + if auto_offload: mv_module_from_gpu(block) clear_memory(input_ids) @@ -2759,6 +2760,7 @@ def _quantize_block( else: if len(self.device_list) > 1 and auto_offload: accelerate.hooks.remove_hook_from_submodules(block) + if auto_offload: mv_module_from_gpu(block) clear_memory(input_ids) return None, output From 1d868767c057b8fed77b6eee87493ae82b493d81 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 08:29:49 +0000 Subject: [PATCH 51/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 5b7543712..879193409 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2757,7 +2757,6 @@ def _quantize_block( clear_memory(input_ids) - return q_outputs, output else: if len(self.device_list) > 1 and auto_offload: From 3334c65f91c9f2fb64737c7cc544b7806bbb0631 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 10 Nov 2025 00:32:21 -0800 Subject: [PATCH 52/52] fix Signed-off-by: yiliu30 --- auto_round/compressors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 5b7543712..b9c1855a3 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2527,7 +2527,7 @@ def _quantize_block( if auto_offload: # card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights # loss_device is used to calculate loss on the second device if available and card_0_in_high_risk - if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)): + if is_auto_device_mapping(self.device_map) and len(self.device_list) > 1: card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning( block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device )
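
The final patch replaces the inline `device_map == "auto" or (isinstance(device_map, str) and "," in device_map)` test with `is_auto_device_mapping(self.device_map)` and additionally requires more than one visible device before spreading a block across cards. Assuming the helper simply mirrors the inline check it replaces, the intended gate can be sketched as:

    def is_auto_device_mapping(device_map) -> bool:
        """Assumed equivalent of the inline check replaced in the final patch."""
        return device_map == "auto" or (isinstance(device_map, str) and "," in device_map)


    def should_shard_block(device_map, device_list: list[str]) -> bool:
        """Only shard a block across cards for automatic mappings with more than one device."""
        return is_auto_device_mapping(device_map) and len(device_list) > 1


    assert should_shard_block("auto", ["cuda:0", "cuda:1"])
    assert not should_shard_block("auto", ["cuda:0"])                 # single card: keep the block local
    assert not should_shard_block("cuda:0", ["cuda:0", "cuda:1"])     # explicit single-device mapping
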