Added column-major storage of weights and scales in INT4 quantization for model load time improvement in TRT-RTX #811
Conversation
…improvement in TRT-RTX Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
📝 Walkthrough

This PR introduces a column-major storage optimization feature for ONNX INT4 quantization targeting the NvTensorRtRtx execution provider. It adds a CLI flag to the quantization script, integrates it through the quantization pipeline, and provides utility functions for applying column-major transformations to GEMM weights and inserting transpose operations in DQ-only quantization modes.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant CLI as quantize.py<br/>(CLI)
    participant API as int4.py<br/>(quantize)
    participant Transform as qdq_utils.py<br/>(apply_column_major)
    participant Graph as Graph<br/>(ONNX)
    User->>CLI: --use_column_major flag
    CLI->>API: quantize(...,<br/>use_column_major=True)
    API->>Transform: apply_column_major_transformation(<br/>weights, scales, ...)
    Transform->>Transform: Transpose weights &<br/>scales in-place
    Transform->>API: Return DQ attributes<br/>(axis=1)
    API->>Graph: Create DQ nodes with<br/>column-major attributes
    API->>Transform: add_transpose_nodes_for_column_major(graph)
    Transform->>Graph: Insert Transpose nodes<br/>after DQ nodes
    Transform->>Graph: Update MatMul/Gemm<br/>inputs
    Graph-->>User: Optimized ONNX model
```
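For orientation, a minimal sketch of how the diagram's API step might look from Python; only the `quantize` function in `int4.py` and the `use_column_major` keyword are named by this PR — the argument names and calibration handling below are assumptions, not the verified signature.

```python
# Hedged sketch only: argument names other than use_column_major are assumptions,
# not the verified signature of modelopt.onnx.quantization.int4.quantize.
import numpy as np
from modelopt.onnx.quantization.int4 import quantize  # function named in this PR

calib = np.random.rand(4, 32, 768).astype(np.float32)  # dummy calibration batch (assumed format)

onnx_model = quantize(
    "model.onnx",              # input model path (assumed positional argument)
    calibration_data=calib,    # assumed keyword
    use_column_major=True,     # new option from this PR: column-major weight/scale storage
)
```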
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes. Pre-merge checks: ✅ 3 passed.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@modelopt/onnx/quantization/qdq_utils.py`:
- Around line 1083-1091: The current Gemm handling in
apply_column_major_transformation (qdq_utils.py) skips nodes with node.op ==
"Gemm" when node.attrs contains transB=1, which breaks semantics for
column-major weights; instead, when encountering a Gemm with transB set, update
the node.attrs transB to 0 (or remove/normalize it to zero) so the graph expects
B^T (matching the earlier weight transpose) and do not skip inserting the
transpose-back; locate the Gemm handling block (check for node.op == "Gemm" and
the transB logic) and replace the early continue with logic that flips
node.attrs["transB"] to 0 (or deletes the attr) so outputs remain correct while
keeping the transpose-back insertion.
```python
# For Gemm nodes, check if transB is already set
if node.op == "Gemm":
    trans_b = False
    if hasattr(node, "attrs") and "transB" in node.attrs:
        trans_b = node.attrs["transB"] > 0
    if trans_b:
        logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
        continue
```
Gemm transB=1 skip breaks correctness with column‑major weights.
apply_column_major_transformation already transposes weights. If a Gemm has transB=1, skipping the transpose‑back makes Gemm consume B instead of B^T, changing outputs. Either always insert the transpose‑back or flip transB to 0 so Gemm consumes B^T directly.
🐛 Proposed fix (flip transB to 0 and keep semantics)

```diff
-# For Gemm nodes, check if transB is already set
-if node.op == "Gemm":
-    trans_b = False
-    if hasattr(node, "attrs") and "transB" in node.attrs:
-        trans_b = node.attrs["transB"] > 0
-    if trans_b:
-        logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
-        continue
+# For Gemm nodes with transB=1, flip to 0 since weights are already transposed
+if node.op == "Gemm":
+    trans_b = bool((node.attrs or {}).get("transB", 0))
+    if trans_b:
+        node.attrs = node.attrs or {}
+        node.attrs["transB"] = 0
+        logger.debug(
+            f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
+        )
+        continue
```
@hthadicherla Please review this and close/resolve as needed.
🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Search for Gemm nodes with transB attribute in ONNX files and code
echo "=== Searching for Gemm with transB in Python code ==="
rg -n "transB" --type=py -C2

echo -e "\n=== Searching for test files related to column_major ==="
fd -t f -e py test | xargs rg -l "column_major" 2>/dev/null || echo "No test files found with column_major"

echo -e "\n=== Searching for Gemm usage in test files ==="
fd -t f -e py test | xargs rg -n "Gemm" -C1 2>/dev/null | head -50 || echo "No Gemm usage in tests"
```

Length of output: 3647

🏁 Script executed:

```bash
#!/bin/bash
# Look at the column_major test to see what cases are covered
echo "=== Column major test content ==="
rg -n "column_major" tests/unit/onnx/test_qdq_utils.py -A 30 -B 5

echo -e "\n=== Check if any test creates Gemm with transB ==="
rg -n "transB" tests/ --type=py -C3
```

Length of output: 5042
@vishalpandya1990, after reviewing the code and test coverage, I can confirm this is a valid correctness bug that should be fixed.
Analysis
The current code has a logic error when handling Gemm nodes with transB=1:
Current buggy behavior:

- `apply_column_major_transformation` transposes weights: W → W^T
- DQ node outputs the transposed weight W^T
- Code skips inserting transpose-back node for Gemm with transB=1
- Gemm with transB=1 receives W^T and applies transpose: (W^T)^T = W
- Result: A @ W (incorrect! Original was A @ W^T)
With the proposed fix:
- Weights are transposed: W → W^T
- DQ outputs W^T
- Code flips transB from 1 to 0 and skips transpose node
- Gemm with transB=0 receives W^T and uses it directly
- Result: A @ W^T (correct! Matches original semantics)
Test Coverage Gap
The existing test in test_qdq_utils.py::TestColumnMajorTransformation only covers MatMul nodes, not Gemm with transB=1, which means this edge case is currently untested.
Recommendation
✅ Accept the CodeRabbit suggestion - the fix is correct and preserves semantics by flipping transB to 0 when weights are already transposed.
However, I'd recommend one clarification to the log message in the fix:
```python
# For Gemm nodes with transB=1, flip to 0 since weights are already transposed
if node.op == "Gemm":
    trans_b = bool((node.attrs or {}).get("transB", 0))
    if trans_b:
        node.attrs = node.attrs or {}
        node.attrs["transB"] = 0
        logger.debug(
            f"Gemm node {node.name}: setting transB=0 to preserve semantics with column-major weights"
        )
        continue
```

The `continue` is intentional - it skips transpose node insertion because flipping transB achieves the same result more efficiently.
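To make the equivalence concrete, here is a small numpy sketch (not from the PR; shapes and values are hypothetical) showing that feeding the pre-transposed weight to a Gemm with transB=0 reproduces the original transB=1 result, while keeping transB=1 does not:

```python
import numpy as np

# Hypothetical shapes; a square weight keeps both paths shape-compatible so
# only the values differ between the buggy and fixed behaviour.
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 4)).astype(np.float32)   # activations (M, K)
W = rng.standard_normal((4, 4)).astype(np.float32)   # original weight (N, K), N == K here

original = A @ W.T          # Gemm(A, W, transB=1), the pre-transformation graph

W_col_major = W.T           # column-major pass stores W^T in the initializer
fixed = A @ W_col_major     # Gemm(A, W^T, transB=0) after flipping transB
buggy = A @ W_col_major.T   # Gemm(A, W^T, transB=1) if the skip were kept

assert np.allclose(original, fixed)      # semantics preserved by the fix
assert not np.allclose(original, buggy)  # skipping changes the output
```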
🧬 Code Graph Analysis Results
File: modelopt/onnx/quantization/graph_utils.py (relevant snippet)

Line range: 283-302

```python
def get_tensor_consumer_nodes(
    graph: onnx.GraphProto,
) -> dict[str, list[onnx.NodeProto]]:
    """Returns a dictionary of tensor name and their consumer node object mapping.

    Args:
        graph: ONNX model graph.

    Returns:
        Dictionary, key is tensor name and value is their consumer node object
    """
    # Create a dictionary to store tensor consumer nodes
    tensor_consumers = defaultdict(list)

    # Traverse the graph to find consumer nodes for each tensor
    for node in graph.node:
        for input_name in node.input:
            tensor_consumers[input_name].append(node)

    return tensor_consumers
```
Line range: 216-234

```python
def get_tensor_from_name(graph: onnx.GraphProto, tensor_name: str) -> onnx.ValueInfoProto | None:
    """Returns a ValueInfoProto given a tensor name.

    Args:
        graph: ONNX model graph
        tensor_name: String with tensor name.

    Returns:
        onnx.ValueInfoProto: actual graph tensor.
    """
    # Search in inputs
    vi = next((vi for vi in graph.input if vi.name == tensor_name), None)

    # If not found, search in outputs
    if vi is None:
        vi = next((vi for vi in graph.output if vi.name == tensor_name), None)

    # If not found, search in value_info (intermediate tensors)
    if vi is None:
        vi = next((vi for vi in graph.value_info if vi.name == tensor_name), None)

    return vi
```
Line range: 237-280

```python
def get_tensor_producer_nodes(
    graph: onnx.GraphProto,
    get_initializer_producers: bool = False,
) -> dict[str, onnx.NodeProto]:
    """Returns a dictionary of tensor name and their producer node object mapping.

    Note. we create a special Root type node as external inputs producer for ease of implementation.

    Args:
        graph: ONNX model graph.

    Returns:
        Dictionary, key is tensor name and value is their producer node object
    """
    # Create a dictionary to store tensor producer nodes
    tensor_producers = defaultdict(None)

    # Special Root type producer node
    root_node = onnx.helper.make_node(
        op_type="Root",
        inputs=[],
        outputs=[i.name for i in graph.input],
        name="root_0",
    )

    input_names = [graph_input.name for graph_input in graph.input]
    initializer_names = [initializer.name for initializer in graph.initializer]
    external_input_names = list(np.setdiff1d(input_names, initializer_names))

    # Note. We are marking external inputs as non-constant by adding a parent,
    # so that we can quantize the first node of the graph if appropriate
    for graph_input in external_input_names:
        tensor_producers[graph_input] = root_node

    # Traverse the graph to find producer nodes for each tensor
    for node in graph.node:
        for output_name in node.output:
            tensor_producers[output_name] = node

    if get_initializer_producers:
        for initializer in graph.initializer:
            tensor_producers[initializer.name] = initializer

    return tensor_producers
```
File: modelopt/onnx/quantization/graph_utils.py (relevant snippet)

Line range: 1800-1882

```python
def remove_redundant_cast_nodes(graph: onnx.GraphProto) -> None:
    """Remove redundant Cast nodes from the ONNX graph to optimize model performance.

    This function identifies and removes two types of redundant Cast nodes:
    1. Cast nodes where input and output types are identical
       - Before: t1 (dtype=fp16) -> cast (to=fp16) -> t2 -> Op
       - After: t1 (dtype=fp16) -> Op
    2. Cast nodes that can be fused with initializers
       - Before: (initializer) t1 (dtype=fp32) -> cast (to=fp16) -> t2 -> Op
       - After: (initializer) t1 (dtype=fp16) -> Op

    The function preserves Cast nodes that:
    - Have outputs that are graph outputs
    - Are necessary for type conversion
    - Have dynamic inputs (not initializers)

    Args:
        graph: ONNX graph to optimize. The graph will be modified in-place.

    Note:
        - This optimization is particularly useful for models with many Cast operations
        - The function modifies the graph in-place
        - All tensor consumers are updated to maintain graph connectivity
        - Initializer data types are converted when possible to eliminate Cast nodes
    """
    initializers = {init.name: init for init in graph.initializer}
    tensor_consumers = get_tensor_consumer_nodes(graph)
    value_info_map = {info.name: info for info in graph.value_info}
    cast_indices = []
    output_names = {output.name for output in graph.output}

    def _get_tensor_type(tensor_name: str) -> int | None:
        """Get the tensor type for a given tensor name."""
        if tensor_name in value_info_map:
            return value_info_map[tensor_name].type.tensor_type.elem_type
        if tensor_name in initializers:
            return initializers[tensor_name].data_type
        return None

    for node_idx, node in enumerate(graph.node):
        if node.op_type != "Cast":
            continue

        # Skip if output is a graph output
        if any(out_name in output_names for out_name in node.output):
            continue

        input_name = node.input[0]
        input_type = _get_tensor_type(input_name)
        if input_type is None:
            continue

        # Get target type from Cast node attributes
        attr = next((attr for attr in node.attribute if attr.name == "to"), None)
        if attr is None:
            continue

        # Pattern 1: Input and output types are the same
        if input_type == attr.i:
            cast_indices.append(node_idx)
        # Pattern 2: Convert and fuse Cast node for initializers
        elif input_name in initializers:
            cast_indices.append(node_idx)
            cast_input = onnx.numpy_helper.to_array(initializers[input_name])
            dtype = onnx.helper.tensor_dtype_to_np_dtype(attr.i)
            converted_tensor = onnx.numpy_helper.from_array(cast_input.astype(dtype), input_name)
            initializers[input_name].CopyFrom(converted_tensor)
        else:
            continue

        # Update consumer nodes
        for consumer in tensor_consumers.get(node.output[0], []):
            for i, input_tensor in enumerate(consumer.input):
                if input_tensor == node.output[0]:
                    consumer.input[i] = input_name
                    break

    # Remove Cast nodes in reverse order
    logger.info(f"Removing {len(cast_indices)} redundant Cast nodes")
    for node_idx in sorted(cast_indices, reverse=True):
        del graph.node[node_idx]
```
File: modelopt/onnx/quantization/quant_utils.py (relevant snippet)

Line range: 189-204

```python
def get_num_bits(layer_info: dict[str, dict] | None = None, name: str | None = None) -> int:
    """Determine the layer configuration for quantization from layer_info.

    Args:
        layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names
            to layer configuration dict.
        name (str | None): Name of the tensor.

    Returns:
        int: Number of bits to use for quantization. Defaults to 4 if not specified.
    """
    if layer_info and name in layer_info:
        num_bits = layer_info[name]["precision"]
    else:
        num_bits = 4
    return num_bits
```
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #811      +/-   ##
==========================================
- Coverage   74.17%   73.80%   -0.38%
==========================================
  Files         192      193       +1
  Lines       19246    19814     +568
==========================================
+ Hits        14276    14623     +347
- Misses       4970     5191     +221
```

☔ View full report in Codecov by Sentry.
@tcherckez-nvidia - do you mind reviewing?
Please add a unit test for this. For reference: https://github.com/NVIDIA/Model-Optimizer/blob/main/tests/unit/onnx
Besides, also check/compare `create_test_model_with_int4_dq_reshape_transpose_matmul()` in https://github.com/NVIDIA/Model-Optimizer/blob/main/tests/unit/onnx/test_qdq_utils.py
The pattern in the test case you mentioned seems to be DequantizeLinear -> Reshape -> Transpose -> MatMul. I'm not sure why this is being tested; I saw that the Reshape and Transpose nodes are removed by the INT4 quant exporter later anyway (see https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/onnx/export/int4_exporter.py#L33-L121). Regardless, it is different from our pattern, which is DequantizeLinear(W^T) -> Transpose -> MatMul. But what would the test case be? We create the pattern, and then what? One test case I'm thinking of: use dummy weight and activation/layernorm values, create both the DequantizeLinear(W^T) -> Transpose -> MatMul pattern and the DequantizeLinear(W) -> MatMul pattern, and check whether the MatMul outputs match, as in the sketch below.
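A hedged sketch of that equivalence check (not the PR's actual test): it builds both graph variants with `onnx.helper` and compares their outputs through onnxruntime. The DequantizeLinear step is omitted for brevity, so float32 weights stand in for the dequantized INT4 weights; names and shapes are illustrative only.

```python
# Hedged sketch: compare MatMul(A, W) against MatMul(A, Transpose(W^T)).
# With INT4 DQ nodes the same comparison would run on the dequantized weights.
import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper, numpy_helper

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 4)).astype(np.float32)   # activations (M, K)
W = rng.standard_normal((4, 3)).astype(np.float32)   # weight (K, N)


def build_model(weight: np.ndarray, insert_transpose: bool) -> onnx.ModelProto:
    """Build Y = A @ W, either directly or from a pre-transposed weight plus a Transpose node."""
    nodes = []
    matmul_w = "W"
    if insert_transpose:
        nodes.append(helper.make_node("Transpose", ["W"], ["W_t"], perm=[1, 0]))
        matmul_w = "W_t"
    nodes.append(helper.make_node("MatMul", ["A", matmul_w], ["Y"]))
    graph = helper.make_graph(
        nodes,
        "col_major_check",
        [helper.make_tensor_value_info("A", TensorProto.FLOAT, list(A.shape))],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [A.shape[0], 3])],
        initializer=[numpy_helper.from_array(weight, "W")],
    )
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])


def run(model: onnx.ModelProto) -> np.ndarray:
    sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
    return sess.run(None, {"A": A})[0]


y_ref = run(build_model(W, insert_transpose=False))          # W -> MatMul
y_col = run(build_model(W.T.copy(), insert_transpose=True))  # W^T -> Transpose -> MatMul
assert np.allclose(y_ref, y_col, atol=1e-6)
```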
…viders that need it), added use_column_major to log output and README, and renamed add_transpose_nodes_for_column_major to insert_transpose_nodes_for_column_major with inline comments. Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
@vishalpandya1990 I addressed most of the comments. Can you look at the new changes I made, and also at the questions I had regarding some of the changes you suggested?
Yes, we can check that the quantized model resulting from this transformation (when enabled) is valid and as we would expect. For instance, we can do a sanity check on the quantized graph/nodes (layout, shapes) and on the output (if feasible). You can also skim through some existing unit tests to get further insight on potential test cases.
Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
```python
    Verifies both produce the same output for the same input.
    """
    import onnxruntime as ort
```
Perhaps we can simplify this a bit: create a simple one-linear (MatMul) model, run the quantize API twice (once with column-major on and once with column-major off), and then compare/validate output1 and output2 using a utility for running model inference. Can be done in a follow-up PR.
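A rough sketch of such an inference utility (onnxruntime-based; the function name, signature, and file paths are placeholders, not an existing modelopt helper):

```python
# Hypothetical helper, not an existing modelopt utility: run an ONNX model on
# CPU with onnxruntime so the column-major-on and column-major-off quantized
# models can be compared.
import numpy as np
import onnxruntime as ort


def run_onnx_model(model_path: str, inputs: dict[str, np.ndarray]) -> list[np.ndarray]:
    """Run an ONNX model and return all outputs."""
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    return session.run(None, inputs)


# Usage idea (paths are placeholders):
# out_cm = run_onnx_model("model_int4_colmajor.onnx", {"input": x})
# out_rm = run_onnx_model("model_int4_rowmajor.onnx", {"input": x})
# np.testing.assert_allclose(out_cm[0], out_rm[0], rtol=1e-3, atol=1e-3)
```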
What does this PR do?
Type of change: New feature
Overview:
TensorRT-RTX requires the weights and scales in ONNX models to be in column-major format. So, whenever the model is loaded, the TRT-RTX JIT transposes the weights and scales at load time, increasing load time.
Proposed feature: after quantization, transpose the weights and scales in the DQ node and add a Transpose node right after it, i.e.,
A × B = A × ((Bᵀ)ᵀ)
The transformation is a post-processing step and is disabled by default. It can be enabled by quantizing with --use_column_major.
Usage
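A hedged example invocation (the entry point and all flags other than `--use_column_major` are assumptions based on the walkthrough, not verified against the script):

```bash
# Sketch only: --use_column_major is the flag added by this PR; the other
# flags and the module entry point are assumptions for illustration.
python -m modelopt.onnx.quantization \
    --onnx_path model.onnx \
    --quantize_mode int4_awq \
    --calibration_method awq_lite \
    --use_column_major \
    --output_path model_int4_colmajor.onnx
```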
Testing
Tested a few LLMs, comparing their MMLU scores with and without this transformation. No degradations were observed.
Summary by CodeRabbit
Release Notes
- `--use_column_major` command-line flag added to the ONNX quantization script, enabling a column-major weight storage optimization compatible with the NvTensorRtRtx execution provider. This optimization applies to DQ-only quantization modes (rtn_dq, awq_lite, awq_clip).