[low precision] More INT4 support#579

Open
yueming-yuan wants to merge 2 commits into main from feat/int4_asymmetric

Conversation

@yueming-yuan (Collaborator) commented Feb 10, 2026

  • asymmetric INT4 quantization
  • convert_hf_to_int4.py and convert_hf_to_int4_direct.py conversion utilities

TODO

  • adapt changes to sglang-miles branch under sglang
  • update Docker to apply the sglang-miles branch
  • add docs
  • polish examples/scripts

@gemini-code-assist (Contributor)

Summary of Changes

Hello @yueming-yuan, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the framework's capabilities for INT4 quantization. It introduces a more sophisticated quantization scheme that supports asymmetric weight representation and provides dedicated utilities for converting Hugging Face models to this new INT4 format. These changes aim to improve efficiency and flexibility when working with low-precision models, particularly in distributed environments, by refining how weights are processed and synchronized.

Highlights

  • Enhanced INT4 Quantization: Introduced a new WQLinear_GEMM class and refactored the core INT4 quantization logic, enabling more flexible and robust low-precision weight handling.
  • Asymmetric Quantization Support: The quantization process now explicitly supports asymmetric INT4 quantization, allowing for better representation of weights that are not symmetrically distributed around zero (a minimal sketch of this scheme follows this list).
  
  • New HF Model Conversion Utilities: Added two new Python scripts: convert_hf_to_int4.py for converting Hugging Face models to INT4 using llmcompressor and convert_hf_to_int4_direct.py for a custom, multi-threaded direct conversion approach.
  • Improved Weight Update Synchronization: Adjusted the synchronization and engine control flow during distributed and tensor-based weight updates, ensuring proper pausing and continuation of generation engines around quantization post-processing.
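
For context, here is a minimal sketch of what per-group asymmetric INT4 quantization looks like in plain PyTorch. It is illustrative only: the PR itself delegates to the fake_int4_quant_cuda kernel, and the function name, return layout, and the divisibility assumption (in_features % group_size == 0) below are not taken from the diff.

import torch

def asym_int4_quant_sketch(weight: torch.Tensor, group_size: int = 32):
    """Per-group asymmetric quantization of a 2-D weight [out, in] to 4-bit levels (0..15)."""
    out_features, in_features = weight.shape  # assumes in_features % group_size == 0
    w = weight.view(out_features, in_features // group_size, group_size)
    w_min = w.amin(dim=-1, keepdim=True)
    w_max = w.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-8) / 15.0          # 2**4 - 1 quantization levels
    zero_point = torch.round(-w_min / scale).clamp(0, 15)   # per-group asymmetric offset
    q = torch.round(w / scale + zero_point).clamp(0, 15).to(torch.uint8)
    dequant = (q.float() - zero_point) * scale              # reconstruction, useful for error checks
    return q.view(weight.shape), scale.squeeze(-1), zero_point.squeeze(-1), dequant.view(weight.shape)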
Changelog
  • miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py
    • Removed Literal import and added torch.nn and fake_int4_quant_cuda imports.
    • Introduced WQLinear_GEMM class for handling quantized linear layers, including packing weights, scales, and zero points.
    • Modified pack_to_int32 to accept a sym parameter for symmetric/asymmetric packing and removed the torch.int8 type check.
    • Removed pack_int4_to_int32 and int4_block_quantize functions, replacing them with new, more generalized quantization helpers.
    • Added round_to_quantized_type_dtype, quantize, if_quant, and pack_layer functions to implement the new quantization logic.
    • Rewrote the quantize_params_compressed_tensors function to utilize the new pack_layer for processing and storing quantized weights, scales, and optional zero points, and expanded ignore rules.
  • miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py
    • Adjusted the order of dist.barrier and ray.get calls to ensure proper synchronization and engine continuation during weight updates.
  • miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py
    • Added a ray.get call to pause generation engines before flushing the cache.
    • Inserted an additional dist.barrier for better synchronization.
    • Reordered ray.get for engine continuation to occur after quantization post-processing.
  • tools/convert_hf_to_int4.py
    • Added a new utility script for converting Hugging Face models to INT4 using llmcompressor and GPTQ quantization.
  • tools/convert_hf_to_int4_direct.py
    • Added a new utility script for direct, multi-threaded conversion of Hugging Face models to INT4, implementing custom packing and quantization functions.
Activity
  • The pull request description includes a 'TODO' list, indicating that further work is planned, such as adapting changes to a specific SGLang branch, updating Docker, adding documentation, and polishing examples/scripts.
  • No human review comments or approvals have been recorded yet, suggesting the PR is still in its early stages or awaiting initial feedback.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for INT4 asymmetric quantization, including new conversion scripts and updates to the core quantization logic. The changes are a valuable addition. My review focuses on improving robustness and maintainability. I've identified a few instances of hardcoded device names that should be dynamic, a significant code duplication issue in one of the new scripts that should be refactored, and a common bug in command-line argument parsing. Addressing these points will make the code more robust and easier to maintain.

awq_linear.bias = linear.bias.clone().half()

pack_num = 32 // awq_linear.w_bit
device = torch.device(f"cuda:{torch.cuda.current_device()}")

high

The device is hardcoded to the current CUDA device. This can cause issues if the model's linear layer is on a different device (e.g., cuda:1 when the current device is cuda:0). It's better to use the device of the input linear layer's weight.

Suggested change
device = torch.device(f"cuda:{torch.cuda.current_device()}")
device = linear.weight.device

qw, s, zp = pack_layer(param, group_size, is_symmetric)
qweight_name = name.replace(".weight", ".weight_packed")
scale_name = name.replace(".weight", ".weight_scale")
weight_shape = torch.tensor(param.shape, dtype=torch.int32, device="cuda")

high

The device for weight_shape is hardcoded to 'cuda'. This will fail if the parameter is on a different device (e.g., CPU). It should use the device of the parameter itself.

Suggested change
weight_shape = torch.tensor(param.shape, dtype=torch.int32, device="cuda")
weight_shape = torch.tensor(param.shape, dtype=torch.int32, device=param.device)

Comment on lines +30 to +151
def pack_to_int32(
    value,
    num_bits,
    packed_dim=1,
    sym=False,
):
    # if value.dtype is not torch.int8:
    #     raise ValueError("Tensor must be quantized to torch.int8 before packing")

    if num_bits > 8:
        raise ValueError("Packing is only supported for less than 8 bits")

    if num_bits < 1:
        raise ValueError(f"num_bits must be at least 1, got {num_bits}")

    # Convert to unsigned range for packing, matching quantization offset
    if sym:
        offset = 1 << (num_bits - 1)
        value = (value + offset).to(torch.uint8)
    device = value.device

    pack_factor = 32 // num_bits

    if packed_dim == 0:
        value = value.transpose(0, 1)

    rows, cols = value.shape
    padded_cols = math.ceil(cols / pack_factor) * pack_factor
    pad_len = padded_cols - cols

    if pad_len > 0:
        value = torch.nn.functional.pad(value, (0, pad_len))

    num_groups = padded_cols // pack_factor

    # Use int32 here
    reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)
    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits
    packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)

    if packed_dim == 0:
        packed = packed.transpose(0, 1)

    return packed


def round_to_quantized_type_dtype(
    tensor,
    dtype,
    cast_to_original_dtype=False,
):
    original_dtype = tensor.dtype
    iinfo = torch.iinfo(dtype)
    rounded = torch.round(torch.clamp(tensor, iinfo.min, iinfo.max)).to(dtype)
    if cast_to_original_dtype:
        return rounded.to(original_dtype)
    return rounded


@torch.no_grad()
def quantize(
    x,
    scale,
    zero_point,
    dtype=torch.int8,
):
    group_size = x.shape[-1] // scale.shape[-1]
    output_dtype = dtype
    output = torch.zeros_like(x).to(output_dtype)

    reshaped_dims = (
        math.ceil(x.shape[-1] / group_size),
        group_size,
    )
    x = x.unflatten(-1, reshaped_dims)

    scaled = x / scale.unsqueeze(-1)

    if zero_point is not None:
        zero_point = zero_point.unsqueeze(-1)
        scaled += zero_point.to(x.dtype)

    # clamp and round
    output = round_to_quantized_type_dtype(tensor=scaled, dtype=dtype)

    output = output.flatten(start_dim=-2)
    output = output.to(output_dtype)

    return output


def pack_layer(weight, group_size, sym=True):
    w, scale, zp = fake_int4_quant_cuda.fake_int4_quant_cuda(weight, (1, group_size), sym)
    w = w.view(weight.shape[0], 1, weight.shape[1] // group_size, group_size)
    scale = scale.view(weight.shape[0], 1, weight.shape[1] // group_size, 1)
    zp = zp.view(weight.shape[0], 1, weight.shape[1] // group_size, 1)
    if sym:
        w = w * scale
    else:
        w = (w - zp) * scale
    w = w.view(weight.shape)
    scale = scale.view(weight.shape[0], -1).contiguous()
    if not sym:
        zp = zp.view(weight.shape[0], -1)
        zeros = zp.t().contiguous().to(torch.float32)
        zeros = zeros.to(dtype=torch.int32, device=w.device)
        zeros = zeros.reshape(-1, zeros.shape[1] // 8, 8)
        new_order_map = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7], device=zeros.device) * 4
        zeros = zeros << new_order_map
        packed_zp = torch.sum(zeros, dim=-1).to(torch.int32)
    else:
        zp = None
        packed_zp = None

    quantized_weight = quantize(
        x=w,
        scale=scale,
        zero_point=zp,
        dtype=torch.int8 if sym else torch.uint8,
    )
    packed_weight = pack_to_int32(quantized_weight, 4, sym=sym)
    return packed_weight, scale, packed_zp

high

There is significant code duplication here. The functions pack_to_int32, round_to_quantized_type_dtype, quantize, and pack_layer are copied from miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_compressed_tensors.py. This will make maintenance difficult. This utility code should be extracted into a shared module and imported in both places to avoid duplication.
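
One way the de-duplication could look (the module path and export list below are hypothetical, not part of the PR):

# Hypothetical shared module, e.g. miles/backends/megatron_utils/int4_packing.py, exporting
# pack_to_int32, round_to_quantized_type_dtype, quantize, and pack_layer. Both
# quantizer_compressed_tensors.py and tools/convert_hf_to_int4_direct.py would then import it:
from miles.backends.megatron_utils.int4_packing import pack_layer, pack_to_int32  # hypothetical path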

qw, s, zp = pack_layer(weight, group_size, is_symmetric)
qweight_name = name.replace(".weight", ".weight_packed")
scale_name = name.replace(".weight", ".weight_scale")
weight_shape = torch.tensor(weight.shape, dtype=torch.int32, device="cuda")

high

The device for weight_shape is hardcoded to 'cuda'. This will fail if the weight tensor is on a different device (e.g., CPU). It should use the device of the weight tensor itself. This issue would be resolved by de-duplicating the quantization logic as mentioned in another comment.

Suggested change
weight_shape = torch.tensor(weight.shape, dtype=torch.int32, device="cuda")
weight_shape = torch.tensor(weight.shape, dtype=torch.int32, device=weight.device)

parser.add_argument("--model-dir", type=str, required=True, help="local BF16 path")
parser.add_argument("--save-dir", type=str, required=True)
parser.add_argument("--group-size", type=int, default=32, help="Group Size")
parser.add_argument("--is-symmetric", type=bool, default=True, help="Is Symmetric")

high

Using type=bool with argparse does not work as expected for command-line flags. For example, --is-symmetric False will be interpreted as True because bool('False') is True. The recommended way to handle boolean flags is with action='store_true' and action='store_false', or action=argparse.BooleanOptionalAction for Python 3.9+.

Suggested change
parser.add_argument("--is-symmetric", type=bool, default=True, help="Is Symmetric")
parser.add_argument("--is-symmetric", action=argparse.BooleanOptionalAction, default=True, help="Is Symmetric")


data_list = []
for _ in range(num_samples):
    i = random.randint(0, encoded.shape[0] - seq_len - 1)

medium

random.randint(a, b) raises a ValueError if a > b. In this case, if the encoded text is shorter than seq_len, encoded.shape[0] - seq_len - 1 can become negative, causing an error. It would be more robust to add a check before this loop to ensure the dataset is large enough.
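
A minimal guard along these lines would cover it (a sketch reusing the names from the snippet above; the error message is illustrative):

# random.randint(0, N - seq_len - 1) needs N >= seq_len + 1, so fail fast otherwise
if encoded.shape[0] <= seq_len:
    raise ValueError(
        f"Calibration corpus has {encoded.shape[0]} tokens but seq_len={seq_len}; "
        "need at least seq_len + 1 tokens to sample a window."
    )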
