diff --git a/examples/windows/onnx_ptq/genai_llm/README.md b/examples/windows/onnx_ptq/genai_llm/README.md index b833d44dc..f46606307 100644 --- a/examples/windows/onnx_ptq/genai_llm/README.md +++ b/examples/windows/onnx_ptq/genai_llm/README.md @@ -60,6 +60,7 @@ The table below lists key command-line arguments of the ONNX PTQ example script. | `--layers_8bit` | Default: None | Use this option to Overrides default mixed quant strategy| | `--gather_quantize_axis` | Default: None | Use this option to enable INT4 quantization of Gather nodes - choose 0 or 1| | `--gather_block_size` | Default: 32 | Block-size for Gather node's INT4 quantization (when its enabled using gather_quantize_axis option)| +| `--use_column_major` | Default: disabled | Apply column-major storage optimization for execution providers that need it. Only applicable for DQ-only quantization.| Run the following command to view all available parameters in the script: diff --git a/examples/windows/onnx_ptq/genai_llm/quantize.py b/examples/windows/onnx_ptq/genai_llm/quantize.py index 57021ed4d..cc3d6f216 100644 --- a/examples/windows/onnx_ptq/genai_llm/quantize.py +++ b/examples/windows/onnx_ptq/genai_llm/quantize.py @@ -369,7 +369,7 @@ def main(args): f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, " f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, " f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} enable_mixed_quant={args.enable_mixed_quant}, " - f"layers_8bit={args.layers_8bit}\n" + f"layers_8bit={args.layers_8bit}, use_column_major={args.use_column_major}\n" ) print( @@ -443,6 +443,7 @@ def main(args): layers_8bit=args.layers_8bit, gather_block_size=args.gather_block_size, gather_quantize_axis=args.gather_quantize_axis, + use_column_major=args.use_column_major, ) logging.info(f"\nQuantization process took {time.time() - t} seconds") @@ -629,5 +630,14 @@ def main(args): default="", help=("Overrides default mixed quant strategy. Example: 'layers.0,lm_head'"), ) + parser.add_argument( + "--use_column_major", + default=False, + action="store_true", + help=( + "Apply column-major storage optimization for execution providers that need it. " + "Only applicable for DQ-only quantization (e.g., rtn_dq, awq_lite, awq_clip)." + ), + ) args = parser.parse_args() main(args) diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index a6e98a579..b17431fb9 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -220,7 +220,21 @@ def quantize_rtn( Always selects the first dimension (0) to block over. This is because we must batch over the Cin dimension, and in ONNX, weights are always plugged into the RHS (i.e. y = x @ W). + + Args: + use_column_major: If True, apply column-major storage optimization for execution + providers that need it. Passed via kwargs. """ + use_column_major = kwargs.get("use_column_major", False) + + # Column-major only makes sense for DQ-only mode + if use_column_major and not dq_only: + logger.warning( + "use_column_major=True has no effect in QDQ mode. " + "Column-major optimization only applies to DQ-only quantization." 
+ ) + use_column_major = False + logger.info("Starting RTN quantization") t_start = time.time() @@ -295,8 +309,15 @@ def quantize_rtn( qw = np.asnumpy(qw) scales[name] = np.asnumpy(scales[name]) gemm_weights_quantized[name] = numpy.asarray(qw) + # Apply column-major optimization if flag is set + # Transposes the weights and scales in-place + if use_column_major: + qdq.apply_column_major_transformation(gemm_weights_quantized, scales) + dq_node_attributes = {"axis": 1, "block_size": block_size} + else: + dq_node_attributes = {"axis": 0, "block_size": block_size} + scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) - dq_node_attributes = {"axis": 0, "block_size": block_size} qdq.insert_dq_nodes( graph, scales, @@ -305,6 +326,10 @@ def quantize_rtn( layer_info=layer_info, ) + # Add transpose nodes for column-major if needed + if use_column_major: + qdq.insert_transpose_nodes_for_column_major(graph) + if gather_w_map is not None: gather_dq_node_attributes = { "axis": gather_quantize_axis, @@ -605,7 +630,14 @@ def _quantize_awq_clip( ) t = time.time() - dq_node_attributes = {"axis": 0, "block_size": block_size} + # Apply column-major optimization if flag is set + # Transposes the weights and scales in-place + use_column_major = kwargs.get("use_column_major", False) + if use_column_major: + qdq.apply_column_major_transformation(gemm_weights_quantized, scales) + dq_node_attributes = {"axis": 1, "block_size": block_size} + else: + dq_node_attributes = {"axis": 0, "block_size": block_size} scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) qdq.insert_dq_nodes( graph_gs, @@ -614,6 +646,9 @@ def _quantize_awq_clip( attributes=dq_node_attributes, layer_info=layer_info, ) + # Add transpose nodes for column-major if needed + if use_column_major: + qdq.insert_transpose_nodes_for_column_major(graph_gs) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size} @@ -1308,7 +1343,14 @@ def _quantize_awq_lite( ) t = time.time() - dq_node_attributes = {"axis": 0, "block_size": block_size} + # Apply column-major optimization if flag is set + # Transposes the weights and scales in-place + use_column_major = kwargs.get("use_column_major", False) + if use_column_major: + qdq.apply_column_major_transformation(gemm_weights_quantized, scales) + dq_node_attributes = {"axis": 1, "block_size": block_size} + else: + dq_node_attributes = {"axis": 0, "block_size": block_size} scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) qdq.insert_dq_nodes( graph_gs, @@ -1318,6 +1360,9 @@ def _quantize_awq_lite( zero_points=zero_points if use_zero_point else None, layer_info=layer_info, ) + # Add transpose nodes for column-major if needed + if use_column_major: + qdq.insert_transpose_nodes_for_column_major(graph_gs) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" assert not use_zero_point or gather_zp_map, ( @@ -1420,10 +1465,19 @@ def quantize( Default: False. - **layers_8bit** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4. Default: []. + - **use_column_major** (bool): If True, apply column-major storage optimization for + execution providers that need it. This transposes + weights and adds Transpose nodes around MatMul operations. + Only applies to DQ-only quantization mode. + Default: False. 
**Returns**: A quantized ONNX model in ONNX ModelProto format. """ configure_logging(level=log_level.upper()) logger.info(f"Starting INT4 quantization with method: {calibration_method}") + + # Log if column-major optimization is enabled (works for all methods) + if kwargs.get("use_column_major", False): + logger.info("Column-major storage optimization enabled via use_column_major flag") t_start = time.time() if cupy_warning_msg: diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index 026b8d062..d8b947e72 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -1022,6 +1022,140 @@ def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onn return onnx_model +# ============================================================================= +# Column-major weight storage transformation for execution providers that need it +# ============================================================================= + + +def _apply_transpose_perm_to_shape(shape, perm): + """Apply transpose permutation to a shape to get the output shape. + + Args: + shape: Input shape as a list/tuple + perm: Permutation indices + + Returns: + Transposed shape or None if inputs are None + """ + if shape is None or perm is None: + return None + return [shape[i] for i in perm] + + +def insert_transpose_nodes_for_column_major(graph: gs.Graph): + """Add a single Transpose node after each DequantizeLinear for column-major weights. + + This implements the simple transformation: A @ B = A @ ((B^T)^T) + where B^T is stored in the DequantizeLinear node, and we add a Transpose + node after DQ to recover B before the MatMul. + + Graph transformation: + Before: DQ(W) -> MatMul/Gemm + After: DQ(W^T) -> Transpose -> W -> MatMul/Gemm + + Args: + graph: ONNX GraphSurgeon graph to modify in-place + """ + nodes_to_add = [] + dq_nodes_processed = set() + + for node in graph.nodes: + if node.op in ["MatMul", "Gemm"]: + # Check if second input (weight) is from DequantizeLinear + weight_input = node.inputs[1] + if not isinstance(weight_input, gs.Variable): + continue + + # Find the producer of the weight input + producer_nodes = [n for n in graph.nodes if weight_input in n.outputs] + if not producer_nodes: + continue + + producer_node = producer_nodes[0] + if producer_node.op != DEQUANTIZE_NODE_NAME: + continue + + # Skip if we already processed this DQ node + if producer_node.name in dq_nodes_processed: + continue + dq_nodes_processed.add(producer_node.name) + + # For Gemm nodes, check if transB is already set + if node.op == "Gemm": + trans_b = False + if hasattr(node, "attrs") and "transB" in node.attrs: + trans_b = node.attrs["transB"] > 0 + if trans_b: + logger.debug(f"Gemm node {node.name} already has transB=1, skipping") + continue + + # Get weight shape and dtype from DQ output + # DQ outputs W^T (transposed), shape is [N, K] instead of [K, N] + weight_shape = weight_input.shape if hasattr(weight_input, "shape") else None + weight_dtype = weight_input.dtype if hasattr(weight_input, "dtype") else None + + # Permutation for 2D weights: [1, 0] to transpose back + # The stored weight is B^T (transposed), we need to get B back + # For 2D [N, K] (stored as transposed): perm [1, 0] -> [K, N] (original) + perm = [1, 0] + + # Compute the transposed shape (original weight shape) + transposed_weight_shape = _apply_transpose_perm_to_shape(weight_shape, perm) + + # Create output variable for the transpose node + transpose_out = gs.Variable( + 
f"{producer_node.name}_transposed_back", + dtype=weight_dtype, + shape=transposed_weight_shape, + ) + + # Create transpose node: (B^T)^T = B + transpose_node = gs.Node( + op="Transpose", + name=f"{producer_node.name}_transpose_back", + inputs=[weight_input], + outputs=[transpose_out], + attrs={"perm": perm}, + ) + + # Update MatMul/Gemm to use the transposed weight + node.inputs[1] = transpose_out + + # Add transpose node to list + nodes_to_add.append(transpose_node) + + # Add all new nodes to graph + if nodes_to_add: + graph.nodes.extend(nodes_to_add) + logger.info(f"Added {len(nodes_to_add)} transpose nodes for column-major optimization") + + # Clean up and reorder graph + graph.cleanup().toposort() + + +def apply_column_major_transformation( + gemm_weights_quantized: dict, + scales: dict, +) -> None: + """Transpose quantized weights and scales in-place for column-major storage. + + Note: After calling this function and inserting DQ nodes with axis=1, + you should call insert_transpose_nodes_for_column_major() on the graph. + + Args: + gemm_weights_quantized: Dictionary mapping weight names to quantized weight arrays + scales: Dictionary mapping weight names to scale arrays + """ + logger.info("Applying column-major storage optimization") + + # Transpose weights and scales in-place + for name in list(gemm_weights_quantized.keys()): + gemm_weights_quantized[name] = gemm_weights_quantized[name].T + + for name in list(scales.keys()): + scales[name] = scales[name].T + + def cast_initializer_to_dtype( node: onnx.NodeProto, dtype: str, initializer_map: dict[str, onnx.TensorProto] ): diff --git a/tests/unit/onnx/test_qdq_utils.py b/tests/unit/onnx/test_qdq_utils.py index 2acc4046a..9127f3905 100644 --- a/tests/unit/onnx/test_qdq_utils.py +++ b/tests/unit/onnx/test_qdq_utils.py @@ -630,3 +630,238 @@ def test_fp4qdq_conversion(self, with_transpose): # Verify Cast nodes are added for input type conversion cast_nodes = [node for node in converted_model.graph.node if node.op_type == "Cast"] assert len(cast_nodes) >= 1 # At least one cast node should be added + + +def create_test_model_with_int4_dq_matmul(): + """Create a simple test model with INT4 DequantizeLinear -> MatMul pattern. + + Returns the model and original weight/scale arrays for verification. 
+ """ + from modelopt.onnx.quantization.quant_utils import pack_float32_to_4bit_cpp_based + + # Create INT4 quantized weight tensor (K=32, N=16) + # Using int8 storage for INT4 values in range [-8, 7] + weight_data = np.random.randint(-8, 8, size=(32, 16), dtype=np.int8) + + # Pack INT4 data (2 values per byte) for ORT compatibility + packed_weight = pack_float32_to_4bit_cpp_based(weight_data, signed=True).astype(np.int8) + weight_tensor = helper.make_tensor( + "weight", + TensorProto.INT4, + dims=weight_data.shape, + vals=packed_weight.tobytes(), + raw=True, + ) + + # Create scale tensor for block quantization (block_size=32, so 1 scale per column) + scale_data = np.random.uniform(0.1, 1.0, size=(1, 16)).astype(np.float16) + scale_tensor = numpy_helper.from_array(scale_data, "scale") + + # Create input tensor for MatMul (batch=4, K=32) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT16, [4, 32]) + + # Create DequantizeLinear node with INT4 blocked quantization + dq_node = helper.make_node( + "DequantizeLinear", + inputs=["weight", "scale"], + outputs=["dq_output"], + name="weight_dq", + axis=0, + block_size=32, + ) + + # Create MatMul node: input (4, 32) @ weight (32, 16) -> output (4, 16) + matmul_node = helper.make_node( + "MatMul", + inputs=["input", "dq_output"], + outputs=["output"], + name="matmul", + ) + + graph = helper.make_graph( + nodes=[dq_node, matmul_node], + name="test_graph", + inputs=[input_tensor], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [4, 16])], + initializer=[weight_tensor, scale_tensor], + ) + + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)]) + model.ir_version = 10 # ORT only supports IR version up to 10 + return model, weight_data, scale_data + + +class TestColumnMajorTransformation: + """Test suite for column-major storage transformation functions.""" + + def test_column_major_transformation_graph_structure(self): + """Test that column-major transformation produces correct graph structure. 
+ + Verifies: DQ(W) -> MatMul becomes DQ(W^T) -> Transpose -> MatMul + """ + import onnx_graphsurgeon as gs + + from modelopt.onnx.quantization.qdq_utils import ( + apply_column_major_transformation, + insert_transpose_nodes_for_column_major, + ) + + model, original_weight, original_scale = create_test_model_with_int4_dq_matmul() + + # Get weights and scales as dicts (simulating what int4.py does) + weights_dict = {"weight": original_weight.copy()} + scales_dict = {"scale": original_scale.copy()} + + # Apply column-major transformation (transposes in-place) + apply_column_major_transformation(weights_dict, scales_dict) + + # Verify weights and scales are transposed + assert weights_dict["weight"].shape == (16, 32), ( + f"Expected transposed weight shape (16, 32), got {weights_dict['weight'].shape}" + ) + assert scales_dict["scale"].shape == (16, 1), ( + f"Expected transposed scale shape (16, 1), got {scales_dict['scale'].shape}" + ) + + # Verify the transposed values match + assert np.array_equal(weights_dict["weight"], original_weight.T) + assert np.array_equal(scales_dict["scale"], original_scale.T) + + # Now test insert_transpose_nodes_for_column_major on a graph + # Create a fresh model and apply the full transformation + model2, _, _ = create_test_model_with_int4_dq_matmul() + graph2 = gs.import_onnx(model2) + + # Add transpose nodes for column-major + insert_transpose_nodes_for_column_major(graph2) + + # Export and verify structure + transformed_model = gs.export_onnx(graph2) + + # Check that Transpose node was added + node_types = [node.op_type for node in transformed_model.graph.node] + assert "Transpose" in node_types, "Transpose node should be added after DQ" + assert "DequantizeLinear" in node_types + assert "MatMul" in node_types + + # Verify the order: DQ -> Transpose -> MatMul + dq_node = next(n for n in transformed_model.graph.node if n.op_type == "DequantizeLinear") + transpose_node = next(n for n in transformed_model.graph.node if n.op_type == "Transpose") + matmul_node = next(n for n in transformed_model.graph.node if n.op_type == "MatMul") + + # DQ output should be Transpose input + assert dq_node.output[0] == transpose_node.input[0], "DQ output should feed into Transpose" + # Transpose output should be MatMul weight input + assert transpose_node.output[0] == matmul_node.input[1], ( + "Transpose output should feed into MatMul" + ) + + # Verify transpose permutation is [1, 0] + perm_attr = next((a for a in transpose_node.attribute if a.name == "perm"), None) + assert perm_attr is not None, "Transpose should have perm attribute" + assert list(perm_attr.ints) == [1, 0], "Transpose perm should be [1, 0]" + + def test_column_major_transformation_output_equivalence(self): + """Test that column-major transformed graph produces equivalent output. + + Creates two graphs: + 1. Original: DQ(W) -> MatMul + 2. Transformed: DQ(W^T) -> Transpose -> MatMul + + Verifies both produce the same output for the same input. 
+ """ + import onnxruntime as ort + + from modelopt.onnx.quantization.quant_utils import pack_float32_to_4bit_cpp_based + + # Create original model + original_model, original_weight, original_scale = create_test_model_with_int4_dq_matmul() + + # Create input data + input_data = np.random.randn(4, 32).astype(np.float16) + + # Run original model + original_session = ort.InferenceSession(original_model.SerializeToString()) + original_output = original_session.run(None, {"input": input_data})[0] + + # Create transformed model + # We need to manually create a model with transposed weights + transposed_weight = original_weight.T.copy() # Shape: (16, 32) + transposed_scale = original_scale.T.copy() # Shape: (16, 1) + + # Pack INT4 data (2 values per byte) for ORT compatibility + packed_transposed_weight = pack_float32_to_4bit_cpp_based( + transposed_weight, signed=True + ).astype(np.int8) + weight_tensor = helper.make_tensor( + "weight", + TensorProto.INT4, + dims=transposed_weight.shape, + vals=packed_transposed_weight.tobytes(), + raw=True, + ) + scale_tensor = numpy_helper.from_array(transposed_scale, "scale") + + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT16, [4, 32]) + + # DQ node with axis=1 for column-major (transposed weight) + dq_node = helper.make_node( + "DequantizeLinear", + inputs=["weight", "scale"], + outputs=["dq_output"], + name="weight_dq", + axis=1, + block_size=32, + ) + + # Transpose node to convert back: (16, 32) -> (32, 16) + transpose_node = helper.make_node( + "Transpose", + inputs=["dq_output"], + outputs=["transpose_output"], + name="transpose_back", + perm=[1, 0], + ) + + # MatMul: input (4, 32) @ transposed_back (32, 16) -> output (4, 16) + matmul_node = helper.make_node( + "MatMul", + inputs=["input", "transpose_output"], + outputs=["output"], + name="matmul", + ) + + transformed_graph = helper.make_graph( + nodes=[dq_node, transpose_node, matmul_node], + name="test_graph", + inputs=[input_tensor], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [4, 16])], + initializer=[weight_tensor, scale_tensor], + ) + + transformed_model = helper.make_model( + transformed_graph, opset_imports=[helper.make_opsetid("", 21)] + ) + transformed_model.ir_version = 10 # ORT only supports IR version up to 10 + + # Run transformed model + transformed_session = ort.InferenceSession(transformed_model.SerializeToString()) + transformed_output = transformed_session.run(None, {"input": input_data})[0] + + # Print output values for visibility + print(f"Original model output shape: {original_output.shape}") + print(f"Transformed model output shape: {transformed_output.shape}") + print(f"Original output (first 5): {original_output.flatten()[:5]}") + print(f"Transformed output (first 5): {transformed_output.flatten()[:5]}") + + # Verify outputs are equivalent (allowing small numerical tolerance) + assert original_output.shape == transformed_output.shape, ( + f"Output shapes should match: {original_output.shape} vs {transformed_output.shape}" + ) + np.testing.assert_allclose( + original_output, + transformed_output, + rtol=1e-3, + atol=1e-3, + err_msg="Column-major transformed model should produce equivalent output", + )