-
Notifications
You must be signed in to change notification settings - Fork 247
Added column-major storage of weights and scales in INT4 quantization for model load time improvement in TRT-RTX #811
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
dc4096d
73252a0
ab6316e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1022,6 +1022,140 @@ def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onn | |||||||||||||||||||||||||||||||||||||
| return onnx_model | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # ============================================================================= | ||||||||||||||||||||||||||||||||||||||
| # Column-major weight storage transformation for execution providers that need it | ||||||||||||||||||||||||||||||||||||||
| # ============================================================================= | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _apply_transpose_perm_to_shape(shape, perm): | ||||||||||||||||||||||||||||||||||||||
| """Apply transpose permutation to a shape to get the output shape. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||
| shape: Input shape as a list/tuple | ||||||||||||||||||||||||||||||||||||||
| perm: Permutation indices | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||
| Transposed shape or None if inputs are None | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| if shape is None or perm is None: | ||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||
| return [shape[i] for i in perm] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def insert_transpose_nodes_for_column_major(graph: gs.Graph): | ||||||||||||||||||||||||||||||||||||||
| """Add a single Transpose node after each DequantizeLinear for column-major weights. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| This implements the simple transformation: A @ B = A @ ((B^T)^T) | ||||||||||||||||||||||||||||||||||||||
| where B^T is stored in the DequantizeLinear node, and we add a Transpose | ||||||||||||||||||||||||||||||||||||||
| node after DQ to recover B before the MatMul. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Graph transformation: | ||||||||||||||||||||||||||||||||||||||
| Before: DQ(W) -> MatMul/Gemm | ||||||||||||||||||||||||||||||||||||||
| After: DQ(W^T) -> Transpose -> W -> MatMul/Gemm | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||
| graph: ONNX GraphSurgeon graph to modify in-place | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| nodes_to_add = [] | ||||||||||||||||||||||||||||||||||||||
| dq_nodes_processed = set() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for node in graph.nodes: | ||||||||||||||||||||||||||||||||||||||
| if node.op in ["MatMul", "Gemm"]: | ||||||||||||||||||||||||||||||||||||||
| # Check if second input (weight) is from DequantizeLinear | ||||||||||||||||||||||||||||||||||||||
| weight_input = node.inputs[1] | ||||||||||||||||||||||||||||||||||||||
| if not isinstance(weight_input, gs.Variable): | ||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Find the producer of the weight input | ||||||||||||||||||||||||||||||||||||||
| producer_nodes = [n for n in graph.nodes if weight_input in n.outputs] | ||||||||||||||||||||||||||||||||||||||
| if not producer_nodes: | ||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| producer_node = producer_nodes[0] | ||||||||||||||||||||||||||||||||||||||
| if producer_node.op != DEQUANTIZE_NODE_NAME: | ||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Skip if we already processed this DQ node | ||||||||||||||||||||||||||||||||||||||
| if producer_node.name in dq_nodes_processed: | ||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||
| dq_nodes_processed.add(producer_node.name) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # For Gemm nodes, check if transB is already set | ||||||||||||||||||||||||||||||||||||||
| if node.op == "Gemm": | ||||||||||||||||||||||||||||||||||||||
| trans_b = False | ||||||||||||||||||||||||||||||||||||||
| if hasattr(node, "attrs") and "transB" in node.attrs: | ||||||||||||||||||||||||||||||||||||||
| trans_b = node.attrs["transB"] > 0 | ||||||||||||||||||||||||||||||||||||||
| if trans_b: | ||||||||||||||||||||||||||||||||||||||
| logger.debug(f"Gemm node {node.name} already has transB=1, skipping") | ||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1083
to
+1091
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gemm 🐛 Proposed fix (flip transB to 0 and keep semantics)- # For Gemm nodes, check if transB is already set
- if node.op == "Gemm":
- trans_b = False
- if hasattr(node, "attrs") and "transB" in node.attrs:
- trans_b = node.attrs["transB"] > 0
- if trans_b:
- logger.debug(f"Gemm node {node.name} already has transB=1, skipping")
- continue
+ # For Gemm nodes with transB=1, flip to 0 since weights are already transposed
+ if node.op == "Gemm":
+ trans_b = bool((node.attrs or {}).get("transB", 0))
+ if trans_b:
+ node.attrs = node.attrs or {}
+ node.attrs["transB"] = 0
+ logger.debug(
+ f"Gemm node {node.name} has transB=1; setting transB=0 for column-major weights"
+ )
+ continue📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @hthadicherla Please review this and close/resolve as needed.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: Length of output: 3647 🏁 Script executed: Length of output: 5042
AnalysisThe current code has a logic error when handling Gemm nodes with Current buggy behavior:
With the proposed fix:
Test Coverage GapThe existing test in Recommendation✅ Accept the CodeRabbit suggestion - the fix is correct and preserves semantics by flipping transB to 0 when weights are already transposed. However, I'd recommend one clarification to the log message in the fix: # For Gemm nodes with transB=1, flip to 0 since weights are already transposed
if node.op == "Gemm":
trans_b = bool((node.attrs or {}).get("transB", 0))
if trans_b:
node.attrs = node.attrs or {}
node.attrs["transB"] = 0
logger.debug(
f"Gemm node {node.name}: setting transB=0 to preserve semantics with column-major weights"
)
continueThe 🧬 Code Graph Analysis ResultsFile:
|
||||||||||||||||||||||||||||||||||||||
| # Get weight shape and dtype from DQ output | ||||||||||||||||||||||||||||||||||||||
| # DQ outputs W^T (transposed), shape is [N, K] instead of [K, N] | ||||||||||||||||||||||||||||||||||||||
| weight_shape = weight_input.shape if hasattr(weight_input, "shape") else None | ||||||||||||||||||||||||||||||||||||||
| weight_dtype = weight_input.dtype if hasattr(weight_input, "dtype") else None | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Permutation for 2D weights: [1, 0] to transpose back | ||||||||||||||||||||||||||||||||||||||
| # The stored weight is B^T (transposed), we need to get B back | ||||||||||||||||||||||||||||||||||||||
| # For 2D [N, K] (stored as transposed): perm [1, 0] -> [K, N] (original) | ||||||||||||||||||||||||||||||||||||||
| perm = [1, 0] | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Compute the transposed shape (original weight shape) | ||||||||||||||||||||||||||||||||||||||
| transposed_weight_shape = _apply_transpose_perm_to_shape(weight_shape, perm) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Create output variable for the transpose node | ||||||||||||||||||||||||||||||||||||||
| transpose_out = gs.Variable( | ||||||||||||||||||||||||||||||||||||||
| f"{producer_node.name}_transposed_back", | ||||||||||||||||||||||||||||||||||||||
| dtype=weight_dtype, | ||||||||||||||||||||||||||||||||||||||
| shape=transposed_weight_shape, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Create transpose node: (B^T)^T = B | ||||||||||||||||||||||||||||||||||||||
| transpose_node = gs.Node( | ||||||||||||||||||||||||||||||||||||||
| op="Transpose", | ||||||||||||||||||||||||||||||||||||||
| name=f"{producer_node.name}_transpose_back", | ||||||||||||||||||||||||||||||||||||||
| inputs=[weight_input], | ||||||||||||||||||||||||||||||||||||||
| outputs=[transpose_out], | ||||||||||||||||||||||||||||||||||||||
| attrs={"perm": perm}, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Update MatMul/Gemm to use the transposed weight | ||||||||||||||||||||||||||||||||||||||
| node.inputs[1] = transpose_out | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Add transpose node to list | ||||||||||||||||||||||||||||||||||||||
| nodes_to_add.append(transpose_node) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Add all new nodes to graph | ||||||||||||||||||||||||||||||||||||||
| if nodes_to_add: | ||||||||||||||||||||||||||||||||||||||
| graph.nodes.extend(nodes_to_add) | ||||||||||||||||||||||||||||||||||||||
| logger.info(f"Added {len(nodes_to_add)} transpose nodes for column-major optimization") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Clean up and reorder graph | ||||||||||||||||||||||||||||||||||||||
| graph.cleanup().toposort() | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def apply_column_major_transformation( | ||||||||||||||||||||||||||||||||||||||
| gemm_weights_quantized: dict, | ||||||||||||||||||||||||||||||||||||||
| scales: dict, | ||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||
| """Transpose quantized weights and scales in-place for column-major storage. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Note: After calling this function and inserting DQ nodes with axis=1, | ||||||||||||||||||||||||||||||||||||||
| you should call insert_transpose_nodes_for_column_major() on the graph. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||
| gemm_weights_quantized: Dictionary mapping weight names to quantized weight arrays | ||||||||||||||||||||||||||||||||||||||
| scales: Dictionary mapping weight names to scale arrays | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| logger.info("Applying column-major storage optimization") | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Transpose weights and scales in-place | ||||||||||||||||||||||||||||||||||||||
| for name in list(gemm_weights_quantized.keys()): | ||||||||||||||||||||||||||||||||||||||
| gemm_weights_quantized[name] = gemm_weights_quantized[name].T | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for name in list(scales.keys()): | ||||||||||||||||||||||||||||||||||||||
| scales[name] = scales[name].T | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def cast_initializer_to_dtype( | ||||||||||||||||||||||||||||||||||||||
| node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto] | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.