1 change: 1 addition & 0 deletions examples/windows/onnx_ptq/genai_llm/README.md
@@ -60,6 +60,7 @@ The table below lists key command-line arguments of the ONNX PTQ example script.
| `--layers_8bit` | Default: None | Use this option to override the default mixed-quant strategy|
| `--gather_quantize_axis` | Default: None | Use this option to enable INT4 quantization of Gather nodes - choose axis 0 or 1|
| `--gather_block_size` | Default: 32 | Block size for the Gather node's INT4 quantization (when it's enabled via the `--gather_quantize_axis` option)|
| `--use_column_major` | Default: disabled | Apply column-major storage optimization for execution providers that need it. Only applicable for DQ-only quantization.|
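
A hypothetical invocation combining the new flag with the Gather options (a sketch only; "..." stands for the script's existing model and calibration arguments documented above, which depend on your setup):

# Sketch: only the flags described in this table are shown explicitly.
python quantize.py ... --use_column_major --gather_quantize_axis 0 --gather_block_size 32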

Run the following command to view all available parameters in the script:

12 changes: 11 additions & 1 deletion examples/windows/onnx_ptq/genai_llm/quantize.py
@@ -369,7 +369,7 @@ def main(args):
f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, "
f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, "
f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} enable_mixed_quant={args.enable_mixed_quant}, "
f"layers_8bit={args.layers_8bit}\n"
f"layers_8bit={args.layers_8bit}, use_column_major={args.use_column_major}\n"
)

print(
@@ -443,6 +443,7 @@ def main(args):
layers_8bit=args.layers_8bit,
gather_block_size=args.gather_block_size,
gather_quantize_axis=args.gather_quantize_axis,
use_column_major=args.use_column_major,
)
logging.info(f"\nQuantization process took {time.time() - t} seconds")

@@ -629,5 +630,14 @@ def main(args):
default="",
help=("Overrides default mixed quant strategy. Example: 'layers.0,lm_head'"),
)
parser.add_argument(
"--use_column_major",
default=False,
action="store_true",
help=(
"Apply column-major storage optimization for execution providers that need it. "
"Only applicable for DQ-only quantization (e.g., rtn_dq, awq_lite, awq_clip)."
),
)
args = parser.parse_args()
main(args)
60 changes: 57 additions & 3 deletions modelopt/onnx/quantization/int4.py
@@ -220,7 +220,21 @@ def quantize_rtn(

Always selects the first dimension (0) to block over. This is because we must batch over the Cin
dimension, and in ONNX, weights are always plugged into the RHS (i.e. y = x @ W).

Args:
use_column_major: If True, apply column-major storage optimization for execution
providers that need it. Passed via kwargs.
"""
use_column_major = kwargs.get("use_column_major", False)

# Column-major only makes sense for DQ-only mode
if use_column_major and not dq_only:
logger.warning(
"use_column_major=True has no effect in QDQ mode. "
"Column-major optimization only applies to DQ-only quantization."
)
use_column_major = False

logger.info("Starting RTN quantization")
t_start = time.time()

@@ -295,8 +309,15 @@ def quantize_rtn(
qw = np.asnumpy(qw)
scales[name] = np.asnumpy(scales[name])
gemm_weights_quantized[name] = numpy.asarray(qw)
# Apply column-major optimization if flag is set
# Transposes the weights and scales in-place
if use_column_major:
qdq.apply_column_major_transformation(gemm_weights_quantized, scales)
dq_node_attributes = {"axis": 1, "block_size": block_size}
else:
dq_node_attributes = {"axis": 0, "block_size": block_size}

scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
dq_node_attributes = {"axis": 0, "block_size": block_size}
qdq.insert_dq_nodes(
graph,
scales,
@@ -305,6 +326,10 @@
layer_info=layer_info,
)

# Add transpose nodes for column-major if needed
if use_column_major:
qdq.insert_transpose_nodes_for_column_major(graph)

if gather_w_map is not None:
gather_dq_node_attributes = {
"axis": gather_quantize_axis,
@@ -605,7 +630,14 @@ def _quantize_awq_clip(
)

t = time.time()
# Apply column-major optimization if flag is set
# Transposes the weights and scales in-place
use_column_major = kwargs.get("use_column_major", False)
if use_column_major:
qdq.apply_column_major_transformation(gemm_weights_quantized, scales)
dq_node_attributes = {"axis": 1, "block_size": block_size}
else:
dq_node_attributes = {"axis": 0, "block_size": block_size}
scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
qdq.insert_dq_nodes(
graph_gs,
@@ -614,6 +646,9 @@
attributes=dq_node_attributes,
layer_info=layer_info,
)
# Add transpose nodes for column-major if needed
if use_column_major:
qdq.insert_transpose_nodes_for_column_major(graph_gs)
if gather_w_map is not None:
assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size}
@@ -1308,7 +1343,14 @@ def _quantize_awq_lite(
)

t = time.time()
# Apply column-major optimization if flag is set
# Transposes the weights and scales in-place
use_column_major = kwargs.get("use_column_major", False)
if use_column_major:
qdq.apply_column_major_transformation(gemm_weights_quantized, scales)
dq_node_attributes = {"axis": 1, "block_size": block_size}
else:
dq_node_attributes = {"axis": 0, "block_size": block_size}
scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info)
qdq.insert_dq_nodes(
graph_gs,
@@ -1318,6 +1360,9 @@
zero_points=zero_points if use_zero_point else None,
layer_info=layer_info,
)
# Add transpose nodes for column-major if needed
if use_column_major:
qdq.insert_transpose_nodes_for_column_major(graph_gs)
if gather_w_map is not None:
assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
assert not use_zero_point or gather_zp_map, (
@@ -1420,10 +1465,19 @@ def quantize(
Default: False.
- **layers_8bit** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
Default: [].
- **use_column_major** (bool): If True, apply column-major storage optimization for
execution providers that need it. This transposes
weights and adds Transpose nodes around MatMul operations.
Only applies to DQ-only quantization mode.
Default: False.
**Returns**: A quantized ONNX model in ONNX ModelProto format.
"""
configure_logging(level=log_level.upper())
logger.info(f"Starting INT4 quantization with method: {calibration_method}")

# Log if column-major optimization is enabled (works for all methods)
if kwargs.get("use_column_major", False):
logger.info("Column-major storage optimization enabled via use_column_major flag")
t_start = time.time()

if cupy_warning_msg:
134 changes: 134 additions & 0 deletions modelopt/onnx/quantization/qdq_utils.py
@@ -1022,6 +1022,140 @@ def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onn
return onnx_model


# =============================================================================
# Column-major weight storage transformation for execution providers that need it
# =============================================================================


def _apply_transpose_perm_to_shape(shape, perm):
"""Apply transpose permutation to a shape to get the output shape.

Args:
shape: Input shape as a list/tuple
perm: Permutation indices

Returns:
Transposed shape or None if inputs are None
"""
if shape is None or perm is None:
return None
return [shape[i] for i in perm]


def insert_transpose_nodes_for_column_major(graph: gs.Graph):
"""Add a single Transpose node after each DequantizeLinear for column-major weights.

This implements the simple transformation: A @ B = A @ ((B^T)^T)
where B^T is stored in the DequantizeLinear node, and we add a Transpose
node after DQ to recover B before the MatMul.

Graph transformation:
Before: DQ(W) -> MatMul/Gemm
After: DQ(W^T) -> Transpose -> W -> MatMul/Gemm

Args:
graph: ONNX GraphSurgeon graph to modify in-place
"""
nodes_to_add = []
dq_nodes_processed = set()

for node in graph.nodes:
if node.op in ["MatMul", "Gemm"]:
# Check if second input (weight) is from DequantizeLinear
weight_input = node.inputs[1]
if not isinstance(weight_input, gs.Variable):
continue

# Find the producer of the weight input
producer_nodes = [n for n in graph.nodes if weight_input in n.outputs]
if not producer_nodes:
continue

producer_node = producer_nodes[0]
if producer_node.op != DEQUANTIZE_NODE_NAME:
continue

# Skip if we already processed this DQ node
if producer_node.name in dq_nodes_processed:
continue
dq_nodes_processed.add(producer_node.name)

# For Gemm nodes, check if transB is already set
if node.op == "Gemm":
trans_b = False
if hasattr(node, "attrs") and "transB" in node.attrs:
trans_b = node.attrs["transB"] > 0
if trans_b:
logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
continue

Comment on lines +1083 to +1091

@coderabbitai coderabbitai bot commented Jan 23, 2026

⚠️ Potential issue | 🔴 Critical

Gemm transB=1 skip breaks correctness with column‑major weights.
apply_column_major_transformation already transposes weights. If a Gemm has transB=1, skipping the transpose‑back makes Gemm consume B instead of B^T, changing outputs. Either always insert the transpose‑back or flip transB to 0 so Gemm consumes B^T directly.

🐛 Proposed fix (flip transB to 0 and keep semantics)
-            # For Gemm nodes, check if transB is already set
-            if node.op == "Gemm":
-                trans_b = False
-                if hasattr(node, "attrs") and "transB" in node.attrs:
-                    trans_b = node.attrs["transB"] > 0
-                if trans_b:
-                    logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
-                    continue
+            # For Gemm nodes with transB=1, flip to 0 since weights are already transposed
+            if node.op == "Gemm":
+                trans_b = bool((node.attrs or {}).get("transB", 0))
+                if trans_b:
+                    node.attrs = node.attrs or {}
+                    node.attrs["transB"] = 0
+                    logger.debug(
+                        f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
+                    )
+                    continue


@vishalpandya1990 commented Jan 30, 2026

@hthadicherla Please review this and close/resolve as needed.


@coderabbitai coderabbitai bot replied

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for Gemm nodes with transB attribute in ONNX files and code
echo "=== Searching for Gemm with transB in Python code ==="
rg -n "transB" --type=py -C2

echo -e "\n=== Searching for test files related to column_major ==="
fd -t f -e py test | xargs rg -l "column_major" 2>/dev/null || echo "No test files found with column_major"

echo -e "\n=== Searching for Gemm usage in test files ==="
fd -t f -e py test | xargs rg -n "Gemm" -C1 2>/dev/null | head -50 || echo "No Gemm usage in tests"

Length of output: 3647


🏁 Script executed:

#!/bin/bash
# Look at the column_major test to see what cases are covered
echo "=== Column major test content ==="
rg -n "column_major" tests/unit/onnx/test_qdq_utils.py -A 30 -B 5

echo -e "\n=== Check if any test creates Gemm with transB ==="
rg -n "transB" tests/ --type=py -C3

Length of output: 5042


@vishalpandya1990, after reviewing the code and test coverage, I can confirm this is a valid correctness bug that should be fixed.

Analysis

The current code has a logic error when handling Gemm nodes with transB=1:

Current buggy behavior:

  1. apply_column_major_transformation transposes weights: W → W^T
  2. DQ node outputs the transposed weight W^T
  3. Code skips inserting transpose-back node for Gemm with transB=1
  4. Gemm with transB=1 receives W^T and applies transpose: (W^T)^T = W
  5. Result: A @ W (incorrect! Original was A @ W^T)

With the proposed fix:

  1. Weights are transposed: W → W^T
  2. DQ outputs W^T
  3. Code flips transB from 1 to 0 and skips transpose node
  4. Gemm with transB=0 receives W^T and uses it directly
  5. Result: A @ W^T (correct! Matches original semantics)

Test Coverage Gap

The existing test in test_qdq_utils.py::TestColumnMajorTransformation only covers MatMul nodes, not Gemm with transB=1, which means this edge case is currently untested.

Recommendation

Accept the CodeRabbit suggestion - the fix is correct and preserves semantics by flipping transB to 0 when weights are already transposed.

However, I'd recommend one clarification to the log message in the fix:

# For Gemm nodes with transB=1, flip to 0 since weights are already transposed
if node.op == "Gemm":
    trans_b = bool((node.attrs or {}).get("transB", 0))
    if trans_b:
        node.attrs = node.attrs or {}
        node.attrs["transB"] = 0
        logger.debug(
            f"Gemm node {node.name}: setting transB=0 to preserve semantics with column-major weights"
        )
        continue

The continue is intentional - it skips transpose node insertion because flipping transB achieves the same result more efficiently.
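
If it helps, a rough shape for a regression test covering this gap (a hypothetical sketch assuming the suggested fix is applied; the graph construction, tensor names, and shapes are my assumptions, not code from this PR):

import numpy as np
import onnx_graphsurgeon as gs

from modelopt.onnx.quantization import qdq_utils as qdq

def _dq_gemm_graph(trans_b: int) -> gs.Graph:
    """Tiny DQ(W^T) -> Gemm graph; names and shapes are arbitrary."""
    a = gs.Variable("A", dtype=np.float32, shape=[2, 4])
    w_q = gs.Constant("W_q", values=np.zeros((4, 4), dtype=np.int8))
    w_s = gs.Constant("W_s", values=np.ones((4, 4), dtype=np.float32))
    w_dq = gs.Variable("W_dq", dtype=np.float32, shape=[4, 4])
    y = gs.Variable("Y", dtype=np.float32, shape=[2, 4])
    dq = gs.Node("DequantizeLinear", "dq_W", inputs=[w_q, w_s], outputs=[w_dq])
    gemm = gs.Node("Gemm", "gemm0", attrs={"transB": trans_b}, inputs=[a, w_dq], outputs=[y])
    return gs.Graph(nodes=[dq, gemm], inputs=[a], outputs=[y])

def test_gemm_transb_flipped_for_column_major():
    graph = _dq_gemm_graph(trans_b=1)
    qdq.insert_transpose_nodes_for_column_major(graph)
    gemm = next(n for n in graph.nodes if n.op == "Gemm")
    # With the fix, transB is flipped to 0 and no transpose-back is inserted.
    assert gemm.attrs.get("transB") == 0
    assert all(n.op != "Transpose" for n in graph.nodes)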


🧬 Code Graph Analysis Results

File: modelopt/onnx/quantization/graph_utils.py (relevant snippet)

Line range: 283-302

def get_tensor_consumer_nodes(
    graph: onnx.GraphProto,
) -> dict[str, list[onnx.NodeProto]]:
    """Returns a dictionary of tensor name and their consumer node object mapping.

    Args:
        graph: ONNX model graph.

    Returns:
        Dictionary, key is tensor name and value is their consumer node object
    """
    # Create a dictionary to store tensor consumer nodes
    tensor_consumers = defaultdict(list)

    # Traverse the graph to find consumer nodes for each tensor
    for node in graph.node:
        for input_name in node.input:
            tensor_consumers[input_name].append(node)

    return tensor_consumers

Line range: 216-234

def get_tensor_from_name(graph: onnx.GraphProto, tensor_name: str) -> onnx.ValueInfoProto | None:
    """Returns a ValueInfoProto given a tensor name.

    Args:
        graph: ONNX model graph
        tensor_name: String with tensor name.

    Returns:
        onnx.ValueInfoProto: actual graph tensor.
    """
    # Search in inputs
    vi = next((vi for vi in graph.input if vi.name == tensor_name), None)
    # If not found, search in outputs
    if vi is None:
        vi = next((vi for vi in graph.output if vi.name == tensor_name), None)
    # If not found, search in value_info (intermediate tensors)
    if vi is None:
        vi = next((vi for vi in graph.value_info if vi.name == tensor_name), None)
    return vi

Line range: 237-280

def get_tensor_producer_nodes(
    graph: onnx.GraphProto,
    get_initializer_producers: bool = False,
) -> dict[str, onnx.NodeProto]:
    """Returns a dictionary of tensor name and their producer node object mapping.

    Note. we create a special Root type node as external inputs producer for ease of implementation.

    Args:
        graph: ONNX model graph.

    Returns:
        Dictionary, key is tensor name and value is their producer node object
    """
    # Create a dictionary to store tensor producer nodes
    tensor_producers = defaultdict(None)

    # Special Root type producer node
    root_node = onnx.helper.make_node(
        op_type="Root",
        inputs=[],
        outputs=[i.name for i in graph.input],
        name="root_0",
    )

    input_names = [graph_input.name for graph_input in graph.input]
    initializer_names = [initializer.name for initializer in graph.initializer]
    external_input_names = list(np.setdiff1d(input_names, initializer_names))

    # Note. We are marking external inputs as non-constant by adding a parent,
    # so that we can quantize the first node of the graph if appropriate
    for graph_input in external_input_names:
        tensor_producers[graph_input] = root_node

    # Traverse the graph to find producer nodes for each tensor
    for node in graph.node:
        for output_name in node.output:
            tensor_producers[output_name] = node

    if get_initializer_producers:
        for initializer in graph.initializer:
            tensor_producers[initializer.name] = initializer

    return tensor_producers

File: modelopt/onnx/quantization/graph_utils.py (relevant snippet)

Line range: 1800-1882

def remove_redundant_cast_nodes(graph: onnx.GraphProto) -> None:
    """Remove redundant Cast nodes from the ONNX graph to optimize model performance.

    This function identifies and removes two types of redundant Cast nodes:

    1. Cast nodes where input and output types are identical
       - Before: t1 (dtype=fp16) -> cast (to=fp16) -> t2 -> Op
       - After:  t1 (dtype=fp16) -> Op

    2. Cast nodes that can be fused with initializers
       - Before: (initializer) t1 (dtype=fp32) -> cast (to=fp16) -> t2 -> Op
       - After:  (initializer) t1 (dtype=fp16) -> Op

    The function preserves Cast nodes that:
    - Have outputs that are graph outputs
    - Are necessary for type conversion
    - Have dynamic inputs (not initializers)

    Args:
        graph: ONNX graph to optimize. The graph will be modified in-place.

    Note:
        - This optimization is particularly useful for models with many Cast operations
        - The function modifies the graph in-place
        - All tensor consumers are updated to maintain graph connectivity
        - Initializer data types are converted when possible to eliminate Cast nodes
    """
    initializers = {init.name: init for init in graph.initializer}
    tensor_consumers = get_tensor_consumer_nodes(graph)
    value_info_map = {info.name: info for info in graph.value_info}
    cast_indices = []
    output_names = {output.name for output in graph.output}

    def _get_tensor_type(tensor_name: str) -> int | None:
        """Get the tensor type for a given tensor name."""
        if tensor_name in value_info_map:
            return value_info_map[tensor_name].type.tensor_type.elem_type
        if tensor_name in initializers:
            return initializers[tensor_name].data_type
        return None

    for node_idx, node in enumerate(graph.node):
        if node.op_type != "Cast":
            continue

        # Skip if output is a graph output
        if any(out_name in output_names for out_name in node.output):
            continue

        input_name = node.input[0]
        input_type = _get_tensor_type(input_name)
        if input_type is None:
            continue

        # Get target type from Cast node attributes
        attr = next((attr for attr in node.attribute if attr.name == "to"), None)
        if attr is None:
            continue

        # Pattern 1: Input and output types are the same
        if input_type == attr.i:
            cast_indices.append(node_idx)
        # Pattern 2: Convert and fuse Cast node for initializers
        elif input_name in initializers:
            cast_indices.append(node_idx)
            cast_input = onnx.numpy_helper.to_array(initializers[input_name])
            dtype = onnx.helper.tensor_dtype_to_np_dtype(attr.i)
            converted_tensor = onnx.numpy_helper.from_array(cast_input.astype(dtype), input_name)
            initializers[input_name].CopyFrom(converted_tensor)
        else:
            continue

        # Update consumer nodes
        for consumer in tensor_consumers.get(node.output[0], []):
            for i, input_tensor in enumerate(consumer.input):
                if input_tensor == node.output[0]:
                    consumer.input[i] = input_name
                    break

    # Remove Cast nodes in reverse order
    logger.info(f"Removing {len(cast_indices)} redundant Cast nodes")
    for node_idx in sorted(cast_indices, reverse=True):
        del graph.node[node_idx]

File: modelopt/onnx/quantization/quant_utils.py (relevant snippet)

Line range: 189-204

def get_num_bits(layer_info: dict[str, dict] | None = None, name: str | None = None) -> int:
    """Determine the layer configuration for quantization from layer_info.

    Args:
        layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names
            to layer configuration dict.
        name (str | None): Name of the tensor.

    Returns:
        int: Number of bits to use for quantization. Defaults to 4 if not specified.
    """
    if layer_info and name in layer_info:
        num_bits = layer_info[name]["precision"]
    else:
        num_bits = 4
    return num_bits

# Get weight shape and dtype from DQ output
# DQ outputs W^T (transposed), shape is [N, K] instead of [K, N]
weight_shape = weight_input.shape if hasattr(weight_input, "shape") else None
weight_dtype = weight_input.dtype if hasattr(weight_input, "dtype") else None

# Permutation for 2D weights: [1, 0] to transpose back
# The stored weight is B^T (transposed), we need to get B back
# For 2D [N, K] (stored as transposed): perm [1, 0] -> [K, N] (original)
perm = [1, 0]

# Compute the transposed shape (original weight shape)
transposed_weight_shape = _apply_transpose_perm_to_shape(weight_shape, perm)

# Create output variable for the transpose node
transpose_out = gs.Variable(
f"{producer_node.name}_transposed_back",
dtype=weight_dtype,
shape=transposed_weight_shape,
)

# Create transpose node: (B^T)^T = B
transpose_node = gs.Node(
op="Transpose",
name=f"{producer_node.name}_transpose_back",
inputs=[weight_input],
outputs=[transpose_out],
attrs={"perm": perm},
)

# Update MatMul/Gemm to use the transposed weight
node.inputs[1] = transpose_out

# Add transpose node to list
nodes_to_add.append(transpose_node)

# Add all new nodes to graph
if nodes_to_add:
graph.nodes.extend(nodes_to_add)
logger.info(f"Added {len(nodes_to_add)} transpose nodes for column-major optimization")

# Clean up and reorder graph
graph.cleanup().toposort()


def apply_column_major_transformation(
gemm_weights_quantized: dict,
scales: dict,
) -> None:
"""Transpose quantized weights and scales in-place for column-major storage.

Note: After calling this function and inserting DQ nodes with axis=1,
you should call insert_transpose_nodes_for_column_major() on the graph.

Args:
gemm_weights_quantized: Dictionary mapping weight names to quantized weight arrays
scales: Dictionary mapping weight names to scale arrays
"""
logger.info("Applying column-major storage optimization")

# Transpose weights and scales in-place
for name in list(gemm_weights_quantized.keys()):
gemm_weights_quantized[name] = gemm_weights_quantized[name].T

for name in list(scales.keys()):
scales[name] = scales[name].T


def cast_initializer_to_dtype(
node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto]
):