From b3483506b737da895489cc718812caaba2fe3562 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 15 Dec 2025 09:08:33 +0000 Subject: [PATCH] Integrate Automated QDQ placement tool - part 3 Signed-off-by: Will Guo --- .../onnx/quantization/autotune/__init__.py | 61 + .../onnx/quantization/autotune/__main__.py | 24 + .../onnx/quantization/autotune/autotuner.py | 1105 +++++++++++++++++ .../onnx/quantization/autotune/benchmark.py | 780 ++++++++++++ modelopt/onnx/quantization/autotune/cli.py | 294 +++++ .../onnx/quantization/autotune/workflows.py | 417 +++++++ modelopt/onnx/quantization/graph_utils.py | 17 + modelopt/onnx/quantization/qdq_utils.py | 44 + .../quantization/autotune/test_autotuner.py | 409 ++++++ .../onnx/quantization/autotune/test_config.py | 144 +++ 10 files changed, 3295 insertions(+) create mode 100644 modelopt/onnx/quantization/autotune/__init__.py create mode 100644 modelopt/onnx/quantization/autotune/__main__.py create mode 100644 modelopt/onnx/quantization/autotune/autotuner.py create mode 100644 modelopt/onnx/quantization/autotune/benchmark.py create mode 100644 modelopt/onnx/quantization/autotune/cli.py create mode 100644 modelopt/onnx/quantization/autotune/workflows.py create mode 100644 tests/unit/onnx/quantization/autotune/test_autotuner.py create mode 100644 tests/unit/onnx/quantization/autotune/test_config.py diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py new file mode 100644 index 000000000..aa520e276 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pattern-based Q/DQ autotuning for ONNX models. + +Optimizes Q/DQ node placement in ONNX graphs to minimize TensorRT inference latency +using hierarchical region analysis and pattern-based scheme reuse. 
+""" + +from modelopt.onnx.quantization.autotune.common import ( + AutotunerError, + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionError, + RegionType, +) + +from .insertion_points import ( + ChildRegionInputInsertionPoint, + NodeInputInsertionPoint, + RegionOutputInsertionPoint, + ResolvedInsertionPoint, +) +from .region_pattern import RegionPattern +from .region_search import CombinedRegionSearch + +__all__ = [ + "AutotunerError", + "AutotunerNotInitializedError", + "ChildRegionInputInsertionPoint", + "CombinedRegionSearch", + "Config", + "InsertionScheme", + "InvalidSchemeError", + "NodeInputInsertionPoint", + "PatternCache", + "PatternSchemes", + "Region", + "RegionError", + "RegionOutputInsertionPoint", + "RegionPattern", + "RegionType", + "ResolvedInsertionPoint", +] diff --git a/modelopt/onnx/quantization/autotune/__main__.py b/modelopt/onnx/quantization/autotune/__main__.py new file mode 100644 index 000000000..9071af99e --- /dev/null +++ b/modelopt/onnx/quantization/autotune/__main__.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Command-line interface for ONNX Q/DQ autotuning.""" + +import sys + +from modelopt.onnx.quantization.autotune.cli import run_autotune + +if __name__ == "__main__": + sys.exit(run_autotune()) diff --git a/modelopt/onnx/quantization/autotune/autotuner.py b/modelopt/onnx/quantization/autotune/autotuner.py new file mode 100644 index 000000000..38fed560b --- /dev/null +++ b/modelopt/onnx/quantization/autotune/autotuner.py @@ -0,0 +1,1105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Automatic Q/DQ insertion optimization for ONNX models via pattern-based profiling.""" + +import copy +import os +import random +from collections import deque +from datetime import datetime, timezone + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import yaml + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import ( + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionType, +) +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + merge_resolved_insertion_points, +) +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch +from modelopt.onnx.quantization.fp8 import int8_to_fp8 +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class QDQAutotunerBase: + """Base class for pattern-based Q/DQ node insertion optimization in ONNX models.""" + + def __init__(self, model: onnx.ModelProto | gs.Graph): + """Initialize the autotuner with an ONNX model.""" + if isinstance(model, onnx.ModelProto): + self.onnx_model = model + elif isinstance(model, gs.Graph): + self.onnx_model = gs.export_onnx(model) + else: + raise TypeError(f"Expected onnx.ModelProto or gs.Graph, got {type(model)}") + + self.graph = self._copy_graph() + self.graph.tensor_users_map = get_tensor_consumer_node_indices(self.graph) + + self.regions: list[Region] = [] + self.current_profile_region: Region | None = None + + self.profiled_patterns: list[PatternSchemes] = [] + self.current_profile_pattern_schemes: PatternSchemes | None = None + + self.current_insertion_scheme_index: int | None = None + + self.config = Config() + self.initialized = False + self.baseline_latency_ms: float | None = None + + self.pattern_cache: PatternCache | None = None + + logger.debug(f"Initialized autotuner with model type: {type(model).__name__}") + + def initialize( + self, config: Config | None = None, pattern_cache: PatternCache | None = None + ) -> None: + """Initialize autotuning session with configuration and pattern cache.""" + if config is not None: + self.config = config + + if pattern_cache is None: + pattern_cache = PatternCache( + minimum_distance=self.config.pattern_cache_minimum_distance, + max_entries_per_pattern=self.config.pattern_cache_max_entries_per_pattern, + ) + self.pattern_cache = pattern_cache + + logger.debug( + f"Loaded pattern cache with {pattern_cache.num_patterns} patterns and " + f"{pattern_cache.total_schemes} schemes" + ) + + self.initialized = False + self.baseline_latency_ms = None + self.profiled_patterns.clear() + self.regions.clear() + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + + logger.info("Initializing autotuner") + logger.debug( + f"Configuration: q_scale={self.config.default_q_scale}, " + f"q_zero_point={self.config.default_q_zero_point}, quant_type={self.config.default_quant_type}" + ) + + self.initialized = True + + def set_profile_region(self, region: Region | None, commit: bool = True) -> None: + """Set the target region for profiling and scheme generation.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + if commit: + if self.current_profile_pattern_schemes is not None: + num_schemes = len(self.current_profile_pattern_schemes.schemes) + best_scheme = self.current_profile_pattern_schemes.best_scheme + best_latency = best_scheme.latency_ms if best_scheme else float("inf") + + samples_before_best, time_to_best = self._compute_convergence_metrics( + self.current_profile_pattern_schemes.schemes, best_scheme + ) + + logger.info( + f"Pattern complete: {num_schemes} schemes tested, best latency {best_latency:.3f} ms" + ) + logger.debug( + f"Pattern signature: {self.current_profile_pattern_schemes.pattern_signature}" + ) + if samples_before_best is not None: + logger.debug(f"Convergence: best found at sample {samples_before_best}") + if time_to_best is not None: + logger.debug(f"Time to best: {time_to_best:.2f}s") + self.profiled_patterns.append(self.current_profile_pattern_schemes) + + if commit or region is None: + self.current_profile_region = None + self.current_profile_pattern_schemes = None + self.current_insertion_scheme_index = None + if region is None: + return + + if region not in self.regions: + raise ValueError(f"Region {region.id} not found in regions") + + region_pattern = RegionPattern.from_region(region, self.graph) + + if self._is_region_profiled(region): + logger.info(f"Skipping region {region.id} (pattern already profiled)") + logger.debug(f"Pattern signature: {region_pattern.signature}") + return + + pattern_schemes = None + num_seeded = 0 + + if self.pattern_cache is not None: + cache_schemes = self.pattern_cache.get_pattern_schemes(region_pattern.signature) + + if cache_schemes is not None and len(cache_schemes.schemes) > 0: + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = region_pattern + + for cached_scheme in cache_schemes.schemes: + scheme_copy = copy.deepcopy(cached_scheme) + scheme_copy.latency_ms = float("inf") + scheme_copy.error = False + pattern_schemes.schemes.append(scheme_copy) + num_seeded += 1 + + logger.debug(f"Seeded {num_seeded} scheme(s) from pattern cache") + else: + logger.debug("No pattern cache entries for this region") + + if pattern_schemes is None: + pattern_schemes = PatternSchemes() + pattern_schemes.pattern = region_pattern + logger.debug("Initialized with empty scheme collection") + + self.current_profile_region = region + self.current_profile_pattern_schemes = pattern_schemes + + mode_info = f"seeded with {num_seeded} schemes" if num_seeded > 0 else "starting fresh" + logger.info( + f"Profiling region {region.id} [pattern mode, level {region.get_level()}, " + f"size {region.get_size()}, {mode_info}]" + ) + logger.debug(f"Pattern signature: {region_pattern.signature}") + + def generate(self) -> int: + """Generate a new Q/DQ insertion scheme for the current pattern or region.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + elif self.current_profile_pattern_schemes is None: + raise InvalidSchemeError("No region selected. 
Call set_profile_region() first.") + + pattern_schemes = self.current_profile_pattern_schemes + cached_schemes = [ + (idx, scheme) + for idx, scheme in enumerate(pattern_schemes.schemes) + if not scheme.is_profiled + ] + + if cached_schemes: + scheme_index, cached_scheme_data = cached_schemes[0] + num_node_points = len(cached_scheme_data.node_inputs) + num_region_composite_points = len(cached_scheme_data.child_region_inputs) + num_region_output_points = len(cached_scheme_data.region_outputs) + total_points = num_node_points + num_region_composite_points + num_region_output_points + + logger.info( + f"Scheme #{scheme_index + 1}: profiling cached scheme ({total_points} Q/DQ points)" + ) + logger.debug( + f"Cached scheme breakdown: {num_node_points} node input, " + f"{num_region_composite_points} region composite, " + f"{num_region_output_points} region output points ({len(cached_schemes)} cached schemes remaining)" + ) + + self.current_insertion_scheme_index = scheme_index + return self.current_insertion_scheme_index + + known_schemes = {scheme.hash for scheme in pattern_schemes.schemes} + max_attempts = getattr(self.config, "maximum_generation_attempts", 100) + + logger.debug(f"Generating new scheme ({len(pattern_schemes.schemes)} schemes exist)") + + for attempts in range(max_attempts): + new_scheme = self._generate_next_insertion_sample() + if new_scheme.hash not in known_schemes and not new_scheme.error: + pattern_schemes.schemes.append(new_scheme) + scheme_index = len(pattern_schemes.schemes) - 1 + num_node_points = len(new_scheme.node_inputs) + num_region_composite_points = len(new_scheme.child_region_inputs) + num_region_output_points = len(new_scheme.region_outputs) + total_points = ( + num_node_points + num_region_composite_points + num_region_output_points + ) + + logger.info( + f"Scheme #{scheme_index + 1}: generated new scheme ({total_points} Q/DQ points)" + ) + logger.debug( + f"Scheme breakdown: {num_node_points} node input, " + f"{num_region_composite_points} region composite, " + f"{num_region_output_points} region output points " + f"(hash: {new_scheme.hash[:16]}..., attempts: {attempts + 1})" + ) + + self.current_insertion_scheme_index = scheme_index + return self.current_insertion_scheme_index + + logger.warning(f"Could not generate unique scheme after {max_attempts} attempts") + return -1 + + def export_onnx( + self, output_path: str | None = None, insert_qdq: bool = True, best: bool = False + ) -> bytes: + """Export ONNX model with Q/DQ nodes inserted according to tested schemes.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." 
+ ) + + output_desc = output_path if output_path is not None else "" + original_quant_type = self.config.default_quant_type + needs_fp8_conversion = insert_qdq and original_quant_type == "fp8" + resolved_insertion_points = set() + + logger.debug( + f"Exporting model to {output_desc} (insert_qdq={insert_qdq}, " + f"regions={len(self.regions)}, profiled_patterns={len(self.profiled_patterns)})" + ) + + # Temporarily set quant type to int8 if FP8 is requested + if needs_fp8_conversion: + logger.debug("FP8 conversion: creating INT8 model first") + self.config.default_quant_type = "int8" + + if insert_qdq: + matched_regions = 0 + + logger.debug(f"Resolving Q/DQ insertion points from {len(self.regions)} regions") + + for region in self.regions: + pattern = RegionPattern.from_region(region, self.graph) + logger.debug(f"Region {region.id} (level {region.level})") + logger.debug(f" → Pattern signature: {pattern.signature}") + + matched = next((ps for ps in self.profiled_patterns if ps.pattern == pattern), None) + current_scheme = matched.best_scheme if matched else None + + if matched: + if current_scheme: + logger.debug( + f" → Matched profiled pattern (latency={current_scheme.latency_ms:.3f} ms)" + ) + else: + logger.debug(" → Matched profiled pattern but no valid schemes") + + if current_scheme is None: + current_scheme = self.current_profile_pattern_schemes + if current_scheme is None or pattern != current_scheme.pattern: + pass + elif best: + current_scheme = current_scheme.best_scheme + else: + scheme_index = self.current_insertion_scheme_index + if scheme_index is not None: + assert scheme_index < len(current_scheme.schemes), ( + f"Invalid scheme index: {scheme_index}" + ) + current_scheme = current_scheme.schemes[scheme_index] + logger.debug(f" → Using current pattern scheme #{scheme_index}") + + if current_scheme is None and self.pattern_cache is not None: + pattern_schemes = self.pattern_cache.get_pattern_schemes(pattern.signature) + if pattern_schemes is not None: + schemes = pattern_schemes.schemes + if schemes is not None and len(schemes) == 1 and not schemes[0].is_profiled: + current_scheme = schemes[0] + logger.debug(" → Using imported pattern from cache") + + if current_scheme is None: + logger.debug(" → No scheme available, skipping") + continue + + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + assert full_insertion_scheme is not None + all_region_ips = pattern.matches(region, self.graph, full_insertion_scheme) + assert isinstance(all_region_ips, set) + resolved_insertion_points.difference_update(all_region_ips) + excluded_tensors = all_region_ips - resolved_insertion_points + if excluded_tensors: + logger.debug( + f" → Excluded {len(excluded_tensors)} overlapping insertion points" + ) + + new_ips = pattern.matches(region, self.graph, current_scheme) + if new_ips: + resolved_insertion_points.update(new_ips) + matched_regions += 1 + logger.debug(f" → Added {len(new_ips)} insertion points") + + logger.debug( + f"Matched {matched_regions}/{len(self.regions)} regions, " + f"total {len(resolved_insertion_points)} unique insertion points" + ) + + graph_copy = self._copy_graph() + unique_tensors = len(resolved_insertion_points) + + logger.debug(f"Inserting {unique_tensors} Q/DQ pairs into graph") + + if insert_qdq and resolved_insertion_points: + self._insert_qdq_at_tensors(graph_copy, resolved_insertion_points) + + logger.debug("Serializing to ONNX format") + model = gs.export_onnx(graph_copy) + + if insert_qdq and resolved_insertion_points: + 
self._fix_zero_point_initializers(model) + + if needs_fp8_conversion: + logger.debug("Converting INT8 to FP8") + model = int8_to_fp8(model) + + self.config.default_quant_type = original_quant_type + model_bytes = model.SerializeToString() + quant_type_str = "baseline" + output_dest = "" + + if insert_qdq: + quant_type_str = f"{original_quant_type.upper()}" if needs_fp8_conversion else "INT8" + + if output_path is not None: + onnx.save(model, output_path) + output_dest = f" → {output_path}" + + logger.info( + f"Exported {quant_type_str} model with {unique_tensors} Q/DQ pairs {output_dest}" + ) + return model_bytes + + def submit(self, latency_ms: float, success: bool = True) -> None: + """Submit performance measurement for the most recently generated scheme.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + if self.baseline_latency_ms is None: + self.baseline_latency_ms = latency_ms + logger.info(f"Baseline latency: {latency_ms:.3f} ms") + return + + if self.current_profile_pattern_schemes is None: + raise InvalidSchemeError( + "No pattern or region selected. Call set_profile_region() first." + ) + + schemes_collection = self.current_profile_pattern_schemes + if not schemes_collection.schemes: + raise InvalidSchemeError("No schemes available. Call generate() first.") + + pattern_schemes = schemes_collection + + if self.current_insertion_scheme_index is not None: + scheme_index = self.current_insertion_scheme_index + if scheme_index >= len(pattern_schemes.schemes): + raise InvalidSchemeError(f"Invalid scheme index: {scheme_index}") + scheme = pattern_schemes.schemes[scheme_index] + else: + scheme = pattern_schemes.schemes[-1] + scheme_index = len(pattern_schemes.schemes) - 1 + + scheme.latency_ms = latency_ms + scheme.error = not success + scheme.profile_timestamp = datetime.now(timezone.utc).isoformat() + display_index = scheme_index + 1 + + if not success: + logger.warning( + f"Scheme #{display_index}: measurement failed (latency={latency_ms:.3f} ms)" + ) + logger.debug("Marking scheme with error flag") + return + + speedup = self.baseline_latency_ms / latency_ms if latency_ms > 0 else 0.0 + + logger.info(f"Scheme #{display_index}: {latency_ms:.3f} ms ({speedup:.2f}x speedup)") + logger.debug(f"Compared to baseline: {self.baseline_latency_ms:.3f} ms") + + old_best = ( + pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf") + ) + pattern_schemes.schemes.sort( + key=lambda s: s.latency_ms if s.latency_ms > 0 else float("inf") + ) + new_best = ( + pattern_schemes.schemes[0].latency_ms if pattern_schemes.schemes else float("inf") + ) + + if new_best < old_best: + new_speedup = self.baseline_latency_ms / new_best if new_best > 0 else 0.0 + logger.info(f" ★ New best: {new_best:.3f} ms ({new_speedup:.2f}x speedup)") + logger.debug(f"Previous best: {old_best:.3f} ms") + + if self.current_profile_pattern_schemes is not None and self.pattern_cache is not None: + self.pattern_cache.add_pattern_schemes(pattern_schemes) + logger.debug( + f"Pattern cache updated: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def save_state(self, output_path: str) -> None: + """Save complete autotuner state to a YAML file for later reuse.""" + current_pattern_sig = None + if self.current_profile_pattern_schemes is not None: + current_pattern_sig = self.current_profile_pattern_schemes.pattern_signature + + state = { + "baseline_latency_ms": 
self.baseline_latency_ms, + "current_profile_pattern_schemes_signature": current_pattern_sig, + "config": { + "default_q_scale": self.config.default_q_scale, + "default_q_zero_point": self.config.default_q_zero_point, + "default_quant_type": self.config.default_quant_type, + "verbose": self.config.verbose, + }, + "patterns": [pattern_schemes.to_dict() for pattern_schemes in self.profiled_patterns], + } + + with open(output_path, "w") as f: + yaml.dump(state, f, default_flow_style=False, sort_keys=False) + + num_patterns = len(self.profiled_patterns) + total_schemes = sum(len(p.schemes) for p in self.profiled_patterns) + + logger.info( + f"Saved state → {output_path} ({num_patterns} patterns, {total_schemes} schemes)" + ) + logger.debug(f"State: baseline={self.baseline_latency_ms:.3f} ms") + + if self.pattern_cache is not None and self.pattern_cache.num_patterns > 0: + base_path, ext = os.path.splitext(output_path) + cache_path = f"{base_path}_pattern_cache{ext}" + self.pattern_cache.save(cache_path) + + logger.info(f"Saved pattern cache → {cache_path}") + logger.debug( + f"Cache: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def load_state(self, input_path: str) -> None: + """Load autotuner state from a previously saved YAML file.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + with open(input_path) as f: + state = yaml.safe_load(f) + + if state.get("baseline_latency_ms") is not None: + self.baseline_latency_ms = state["baseline_latency_ms"] + logger.debug(f"Baseline latency: {self.baseline_latency_ms:.3f} ms") + + if "config" in state: + config_data = state["config"] + if "default_q_scale" in config_data: + self.config.default_q_scale = config_data["default_q_scale"] + if "default_q_zero_point" in config_data: + self.config.default_q_zero_point = config_data["default_q_zero_point"] + if "default_quant_type" in config_data: + self.config.default_quant_type = config_data["default_quant_type"] + if "verbose" in config_data: + self.config.verbose = config_data["verbose"] + logger.debug(f"Config merged: quant_type={self.config.default_quant_type}") + + if "patterns" in state: + num_loaded_patterns = 0 + num_loaded_schemes = 0 + + for pattern_data in state["patterns"]: + try: + pattern_schemes = PatternSchemes.from_dict(pattern_data) + + if pattern_schemes.schemes: + self.profiled_patterns.append(pattern_schemes) + num_loaded_patterns += 1 + num_loaded_schemes += len(pattern_schemes.schemes) + else: + logger.debug( + f"Skipped empty pattern {pattern_schemes.pattern_signature[:16]}..." 
+ ) + + except Exception as e: # noqa: PERF203 + logger.warning(f"Failed to load pattern: {e}") + continue + + logger.info( + f"Loaded state from {input_path} ({num_loaded_patterns} patterns, " + f"{num_loaded_schemes} schemes)" + ) + + base_path, ext = os.path.splitext(input_path) + cache_path = f"{base_path}_pattern_cache{ext}" + + if os.path.exists(cache_path): + try: + loaded_cache = PatternCache.load(cache_path) + + if self.pattern_cache is not None: + for pattern_schemes in loaded_cache.pattern_schemes: + self.pattern_cache.add_pattern_schemes(pattern_schemes) + else: + self.pattern_cache = loaded_cache + logger.info( + f"Loaded pattern cache from {cache_path} ({loaded_cache.num_patterns} patterns, " + f"{loaded_cache.total_schemes} schemes)" + ) + except Exception as e: + logger.warning(f"Failed to load pattern cache: {e}") + else: + logger.debug(f"No pattern cache file at {cache_path}") + + def import_insertion_points(self, quantized_tensors: set[str] | list[str]) -> None: + """Import Q/DQ insertion points from a list of quantized tensors and update pattern cache.""" + if not self.initialized: + raise AutotunerNotInitializedError( + "QDQAutotunerBase not initialized. Call initialize() first." + ) + + if isinstance(quantized_tensors, list): + quantized_tensors = set(quantized_tensors) + + logger.info(f"Importing insertion points from {len(quantized_tensors)} quantized tensors") + logger.debug(f"Processing {len(self.regions)} regions") + + if self.pattern_cache is None: + logger.warning("Pattern cache not initialized, skipping import") + return + + patterns_before = self.pattern_cache.num_patterns + schemes_before = self.pattern_cache.total_schemes + + for region in self.regions: + self.pattern_cache.add_pattern_from_region(region, self.graph, quantized_tensors) + + patterns_added = self.pattern_cache.num_patterns - patterns_before + schemes_added = self.pattern_cache.total_schemes - schemes_before + + logger.info( + f"Import complete: {patterns_added} patterns, {schemes_added} schemes added to cache" + ) + logger.debug( + f"Total cache: {self.pattern_cache.num_patterns} patterns, " + f"{self.pattern_cache.total_schemes} schemes" + ) + + def _compute_convergence_metrics( + self, schemes: list[InsertionScheme], best_scheme: InsertionScheme | None + ) -> tuple[int | None, float | None]: + """Compute convergence metrics for a collection of schemes.""" + samples_before_best = None + time_to_best = None + + if not best_scheme or not best_scheme.profile_timestamp: + return samples_before_best, time_to_best + + schemes_with_time = [s for s in schemes if s.profile_timestamp is not None] + + if not schemes_with_time: + return samples_before_best, time_to_best + + from datetime import datetime + + schemes_with_time.sort(key=lambda s: s.profile_timestamp or "") + + try: + best_position = next( + i for i, s in enumerate(schemes_with_time) if s.hash == best_scheme.hash + ) + samples_before_best = best_position + + first_ts = schemes_with_time[0].profile_timestamp + best_ts = best_scheme.profile_timestamp + assert first_ts is not None and best_ts is not None + first_timestamp = datetime.fromisoformat(first_ts) + best_timestamp = datetime.fromisoformat(best_ts) + time_to_best = (best_timestamp - first_timestamp).total_seconds() + except (StopIteration, ValueError): + pass + + return samples_before_best, time_to_best + + def _is_region_profiled(self, region: Region) -> bool: + """Check if a region's pattern has already been fully profiled.""" + + def match_pattern(pattern: PatternSchemes, region: 
Region) -> bool: + """Check if a pattern matches a region.""" + if pattern.pattern is None or not pattern.pattern.matches(region, self.graph): + return False + return not any(not scheme.is_profiled for scheme in pattern.schemes) + + return any(match_pattern(pattern, region) for pattern in self.profiled_patterns) + + def _mutate_insertion_points( + self, base_points, all_points, point_type: str, max_mutations: int + ) -> list: + """Mutate a set of insertion points by adding, removing, or both.""" + key_fn = { + "node input points": lambda p: (p.node_index, p.input_index), + "region composite points": lambda p: (p.region_index, p.input_index), + "region output points": lambda p: (p.region_index, p.node_index, p.output_index), + }.get(point_type) + + if not key_fn: + return [] + + current_points = set(base_points) + initial_count = len(current_points) + mutation_type = random.choice(["add", "remove", "both"]) + + if mutation_type in ["add", "both"] and len(current_points) < len(all_points): + all_keys = {key_fn(p) for p in all_points} + available_keys = all_keys - current_points + if available_keys: + max_add = min(max_mutations, len(available_keys)) + num_to_add = random.randint(1, max_add) + to_add = random.sample(list(available_keys), num_to_add) + current_points.update(to_add) + + if mutation_type in ["remove", "both"] and current_points: + max_remove = min(max_mutations, len(current_points)) + num_to_remove = random.randint(1, max_remove) if len(current_points) > 1 else 1 + num_to_remove = min(num_to_remove, len(current_points)) + to_remove = random.sample(list(current_points), num_to_remove) + for p in to_remove: + current_points.discard(p) + + logger.debug( + f"Mutated {point_type}: {initial_count} → {len(current_points)} ({mutation_type})" + ) + + return [p for p in all_points if key_fn(p) in current_points] + + def _generate_next_insertion_sample(self) -> InsertionScheme: + """Generate a new insertion scheme by mutating top performers.""" + # Validate current profile region is set + if self.current_profile_region is None: + return InsertionScheme() + + # Determine which schemes collection is active (mutually exclusive) + if self.current_profile_pattern_schemes is not None: + schemes_collection = self.current_profile_pattern_schemes + else: + return InsertionScheme() + + region = self.current_profile_region + pattern_schemes = schemes_collection + + if not isinstance(schemes_collection, PatternSchemes) or schemes_collection.pattern is None: + return InsertionScheme() + pattern = schemes_collection.pattern + full_insertion_scheme = pattern.get_full_insertion_scheme(region, self.graph) + + logger.debug( + f"Available insertion points: {len(full_insertion_scheme.node_inputs)} node input, " + f"{len(full_insertion_scheme.child_region_inputs)} region composite, " + f"{len(full_insertion_scheme.region_outputs)} region output" + ) + + top_percent = getattr(self.config, "top_percent_to_mutate", 0.1) + minimum_schemes = getattr(self.config, "minimum_schemes_to_mutate", 1) + + measured_schemes = [s for s in pattern_schemes.schemes if s.latency_ms > 0 and not s.error] + measured_schemes.sort(key=lambda s: s.latency_ms) + + num_top_schemes = max( + int(len(measured_schemes) * top_percent), min(minimum_schemes, len(measured_schemes)) + ) + top_schemes = measured_schemes[:num_top_schemes] + + if len(top_schemes) == 0: + logger.debug("No measured schemes yet, generating baseline (empty) scheme") + return InsertionScheme() + + base_scheme = random.choice(top_schemes) + total_base_points = ( + 
len(base_scheme.node_inputs) + + len(base_scheme.child_region_inputs) + + len(base_scheme.region_outputs) + ) + logger.debug( + f"Mutating from top {len(top_schemes)} schemes: " + f"selected base with {total_base_points} points (latency={base_scheme.latency_ms:.3f} ms)" + ) + + max_mutations = getattr(self.config, "maximum_mutations", 3) + + scheme = InsertionScheme() + base_node_points = {(p.node_index, p.input_index) for p in base_scheme.node_inputs} + scheme.node_inputs = self._mutate_insertion_points( + base_node_points, full_insertion_scheme.node_inputs, "node input points", max_mutations + ) + + base_region_composite_points = { + (p.region_index, p.input_index) for p in base_scheme.child_region_inputs + } + scheme.child_region_inputs = self._mutate_insertion_points( + base_region_composite_points, + full_insertion_scheme.child_region_inputs, + "region composite points", + max_mutations, + ) + + base_region_output_points = { + (p.region_index, p.node_index, p.output_index) for p in base_scheme.region_outputs + } + scheme.region_outputs = self._mutate_insertion_points( + base_region_output_points, + full_insertion_scheme.region_outputs, + "region output points", + max_mutations, + ) + + return scheme + + def _copy_graph(self) -> gs.Graph: + """Create an independent copy of the computation graph.""" + new_graph = gs.import_onnx(self.onnx_model) + new_graph.toposort() + return new_graph + + def _get_quant_dtype(self, quant_type: str) -> np.dtype: + """Get numpy dtype for quantization type.""" + if quant_type == "fp8": + try: + return np.dtype(np.float8_e4m3fn) + except (AttributeError, TypeError): + logger.warning( + "FP8 dtype not available (requires numpy >= 2.0), " + "using uint8 as placeholder. Note: This may not produce " + "correct results without proper FP8 support." 
+ ) + return np.uint8 + + dtype_map = { + "int8": np.int8, + "uint8": np.uint8, + } + + if quant_type not in dtype_map: + logger.warning(f"Unknown quantization type '{quant_type}', defaulting to int8") + return np.int8 + + return dtype_map[quant_type] + + def _get_dq_output_dtype(self, dtype_str: str) -> np.dtype: + """Convert DQ dtype string to numpy dtype.""" + dtype_map = { + "float16": np.float16, + "float32": np.float32, + } + + if hasattr(np, "bfloat16"): + dtype_map["bfloat16"] = np.bfloat16 + + if dtype_str not in dtype_map: + logger.warning(f"Unknown DQ dtype '{dtype_str}', defaulting to float32") + return np.float32 + + return dtype_map[dtype_str] + + def _build_tensor_map(self, graph: gs.Graph) -> dict[str, gs.Tensor]: + """Build mapping from tensor names to tensor objects.""" + tensor_map = {} + + for node in graph.nodes: + for output in node.outputs: + if hasattr(output, "name") and output.name: + tensor_map[output.name] = output + + for input_tensor in graph.inputs: + if hasattr(input_tensor, "name") and input_tensor.name: + tensor_map[input_tensor.name] = input_tensor + + for node in graph.nodes: + for input_tensor in node.inputs: + if ( + isinstance(input_tensor, gs.Constant) + and hasattr(input_tensor, "name") + and input_tensor.name + ): + tensor_map[input_tensor.name] = input_tensor + + return tensor_map + + def _get_tensor_metadata( + self, tensor: gs.Tensor, is_constant: bool + ) -> tuple[tuple | None, np.dtype]: + """Extract shape and dtype metadata from a tensor.""" + default_dtype = self._get_dq_output_dtype(self.config.default_dq_dtype) + + if is_constant and hasattr(tensor, "values") and tensor.values is not None: + return tensor.values.shape, tensor.values.dtype + elif hasattr(tensor, "shape"): + dtype = ( + tensor.dtype + if hasattr(tensor, "dtype") and tensor.dtype is not None + else default_dtype + ) + return tensor.shape, dtype + return None, default_dtype + + def _fix_zero_point_initializers(self, model: onnx.ModelProto) -> None: + """Fix INT8 zero_point initializers to use int32_data instead of raw_data.""" + fixed_count = 0 + + for initializer in model.graph.initializer: + if ( + "_zp_" in initializer.name + and initializer.data_type == onnx.TensorProto.INT8 + and len(initializer.raw_data) > 0 + and len(initializer.int32_data) == 0 + ): + np_array = onnx.numpy_helper.to_array(initializer) + int32_values = np_array.astype(np.int32).flatten().tolist() + + new_tensor = onnx.helper.make_tensor( + initializer.name, + onnx.TensorProto.INT8, + list(initializer.dims), + int32_values, + ) + initializer.CopyFrom(new_tensor) + fixed_count += 1 + + if fixed_count > 0: + logger.debug(f"Fixed {fixed_count} zero_point initializers (int32_data format)") + + def _create_qdq_nodes( + self, + tensor_name: str, + qdq_input: gs.Tensor, + output_shape: tuple | None, + output_dtype: np.dtype, + quant_dtype: np.dtype, + quant_type: str, + q_scale: float, + ) -> tuple[gs.Node, gs.Node]: + """Create QuantizeLinear and DequantizeLinear node pair.""" + # Create unique names for Q/DQ nodes + q_name = f"QDQ_Q_{tensor_name}".replace("/", "_").replace(":", "_") + dq_name = f"QDQ_DQ_{tensor_name}".replace("/", "_").replace(":", "_") + + # Determine scale dtype from output_dtype (fp16/tf32/fp32) + # Scale should match the precision of the original I/O tensor + dtype_map = {"float16": np.float16, "float32": np.float32} + if hasattr(np, "bfloat16"): + dtype_map["bfloat16"] = np.bfloat16 + scale_dtype = dtype_map.get(np.dtype(output_dtype).name, np.float32) + + logger.debug( + f"Creating Q/DQ 
pair for '{tensor_name}' (scale_dtype={np.dtype(scale_dtype).name})" + ) + + q_scale_values = np.array([q_scale], dtype=scale_dtype) + q_zp_values = np.array([0], dtype=np.int8) + q_inputs = [ + qdq_input, + gs.Constant(f"q_scale_{tensor_name}", values=q_scale_values), + gs.Constant(f"q_zp_{tensor_name}", values=q_zp_values), + ] + q_node = gs.Node( + op="QuantizeLinear", + name=q_name, + inputs=q_inputs, + outputs=[ + gs.Variable(f"{tensor_name}_quantized", dtype=quant_dtype, shape=output_shape) + ], + ) + + dq_scale_values = np.array([q_scale], dtype=scale_dtype) + dq_zp_values = np.array([0], dtype=np.int8) + dq_inputs = [ + q_node.outputs[0], + gs.Constant(f"dq_scale_{tensor_name}", values=dq_scale_values), + gs.Constant(f"dq_zp_{tensor_name}", values=dq_zp_values), + ] + dq_node = gs.Node( + op="DequantizeLinear", + name=dq_name, + inputs=dq_inputs, + outputs=[ + gs.Variable(f"{tensor_name}_dequantized", dtype=output_dtype, shape=output_shape) + ], + ) + + return q_node, dq_node + + def _insert_qdq_at_tensors( + self, graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] + ) -> None: + """Insert Q/DQ (Quantize/Dequantize) node pairs at specified locations.""" + q_scale = self.config.default_q_scale + quant_type = self.config.default_quant_type + quant_dtype = self._get_quant_dtype(quant_type) + + logger.debug(f"Q/DQ parameters: type={quant_type}, scale={q_scale}, zero_point=0") + + resolved_insertion_points = merge_resolved_insertion_points( + graph, resolved_insertion_points + ) + + tensor_map = self._build_tensor_map(graph) + tensor_users_map = get_tensor_consumer_node_indices(graph) + logger.debug( + f"Built tensor maps: {len(tensor_map)} tensors, {len(tensor_users_map)} with users" + ) + + for insertion_point in resolved_insertion_points: + tensor_name = insertion_point.tensor_name + node_index = insertion_point.node_index + input_index = insertion_point.input_index + + original_tensor = tensor_map[tensor_name] + if node_index is not None: + assert node_index < len(graph.nodes), "Node index out of range" + target_node = graph.nodes[node_index] + assert input_index is not None, "Input index must be set when node index is set" + assert input_index < len(target_node.inputs), ( + f"Input index out of range for node {target_node.name}" + ) + original_tensor = target_node.inputs[input_index] + assert tensor_name == original_tensor.name, ( + f"Tensor name mismatch for node {target_node.name} input {input_index}" + ) + else: + assert tensor_name in tensor_map, f"Tensor {tensor_name} not found in tensor map" + assert input_index is None, "Input index must be None when node index is None" + + is_constant = isinstance(original_tensor, gs.Constant) + output_shape, output_dtype = self._get_tensor_metadata(original_tensor, is_constant) + + unique_suffix = "qdq" + if node_index is not None: + unique_suffix = f"n{node_index}_i{input_index}" + unique_tensor_name = f"{tensor_name}_{unique_suffix}" + + q_node, dq_node = self._create_qdq_nodes( + unique_tensor_name, + original_tensor, + output_shape, + output_dtype, + quant_dtype, + quant_type, + q_scale, + ) + + graph.nodes.extend([q_node, dq_node]) + + if node_index is not None: + target_node.inputs[input_index] = dq_node.outputs[0] + logger.debug( + f" Q/DQ inserted: tensor '{tensor_name}' → node #{node_index} " + f"({target_node.name}) input #{input_index}" + ) + else: + users = tensor_users_map[tensor_name] + for user_index in users: + user_node = graph.nodes[user_index] + for i, input_tensor in enumerate(user_node.inputs): + 
if hasattr(input_tensor, "name") and input_tensor.name == tensor_name: + user_node.inputs[i] = dq_node.outputs[0] + break + logger.debug(f" Q/DQ inserted: tensor '{tensor_name}' → {len(users)} users") + + logger.debug("Running graph cleanup and topological sort") + try: + graph.cleanup().toposort() + logger.debug("Graph cleanup completed") + except Exception as e: + logger.warning(f"Graph cleanup failed: {e}") + logger.debug("Continuing anyway") + + +class QDQAutotuner(QDQAutotunerBase): + """Q/DQ autotuner with automatic region discovery around compute-intensive ops.""" + + def initialize( + self, config: Config | None = None, pattern_cache: PatternCache | None = None + ) -> None: + """Initialize autotuner and discover optimization regions automatically.""" + super().initialize(config, pattern_cache) + self._search_regions() + + def _visit_region_recursively(self, region: Region) -> list[Region]: + """Recursively traverse region hierarchy and collect all regions.""" + regions = [region] + + for child in region.get_children(): + regions.extend(self._visit_region_recursively(child)) + + return regions + + def _reassign_region_ids(self, regions: list[Region]) -> None: + """Reassign sequential IDs to regions in breadth-first order.""" + region_id = 0 + + queue = deque(regions) + + while queue: + region = queue.popleft() + region.id = region_id + region_id += 1 + queue.extend(region.get_children()) + + def _search_regions(self) -> None: + """Discover and organize optimization regions automatically.""" + logger.info("Discovering optimization regions") + search = CombinedRegionSearch( + self.graph, + maximum_sequence_region_size=self.config.maximum_sequence_region_size, + minimum_topdown_search_size=self.config.minimum_topdown_search_size, + ) + self.regions = search.search_regions() + + self._reassign_region_ids(self.regions) + logger.debug(f"Found {len(self.regions)} top-level regions") + + all_regions = [] + for region in self.regions: + all_regions.extend(self._visit_region_recursively(region)) + + logger.debug(f"Flattened hierarchy to {len(all_regions)} total regions") + + leaf_regions = [region for region in all_regions if region.type == RegionType.LEAF] + other_regions = [region for region in all_regions if region.type != RegionType.LEAF] + + all_regions = leaf_regions + other_regions + self.regions = all_regions + + num_leaf = sum(1 for r in self.regions if r.type == RegionType.LEAF) + num_composite = sum(1 for r in self.regions if r.type == RegionType.COMPOSITE) + num_root = sum(1 for r in self.regions if r.type == RegionType.ROOT) + + logger.info( + f"Discovery complete: {len(self.regions)} regions " + f"({num_leaf} LEAF, {num_composite} COMPOSITE, {num_root} ROOT)" + ) + logger.debug("Regions prioritized: LEAF regions first for profiling") diff --git a/modelopt/onnx/quantization/autotune/benchmark.py b/modelopt/onnx/quantization/autotune/benchmark.py new file mode 100644 index 000000000..fe852315a --- /dev/null +++ b/modelopt/onnx/quantization/autotune/benchmark.py @@ -0,0 +1,780 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TensorRT Utilities and Benchmark Module. + +This module provides comprehensive TensorRT utilities including: +- Benchmark framework for measuring TensorRT engine performance +- Graph utilities for tensor analysis + +**Benchmark Classes:** +- Benchmark: Abstract base class defining the benchmarking interface +- TrtExecBenchmark: Uses trtexec command-line tool for benchmarking +- TensorRTPyBenchmark: Uses TensorRT Python API for direct engine profiling + +**Features:** +- Timing cache management for faster subsequent builds +- File path or raw bytes as model input +- Configurable warmup and timing iterations +- Custom TensorRT plugin library loading +- Automatic cleanup of temporary resources +""" + +import ctypes +import os +import re +import shutil +import subprocess # nosec B404 +import tempfile +import time +from abc import ABC, abstractmethod +from pathlib import Path + +import numpy as np + +# Optional dependencies - gracefully handle missing packages +try: + import tensorrt as trt + + TRT_AVAILABLE = True +except ImportError: + TRT_AVAILABLE = False + +try: + import pycuda.autoinit # noqa: F401 # Automatically initializes CUDA (side-effect import) + import pycuda.driver as cuda + + PYCUDA_AVAILABLE = True +except ImportError: + PYCUDA_AVAILABLE = False + +from modelopt.onnx.logging_config import logger + +# ============================================================================= +# Benchmark Framework +# ============================================================================= + + +class Benchmark(ABC): + """Abstract base class for TensorRT model benchmarking. + + This class defines the interface that all benchmark implementations must follow. + It provides a consistent API for measuring inference latency of ONNX models + when converted to TensorRT engines. + + Attributes: + timing_cache_file: Path to the TensorRT timing cache file. + logger: Logger instance for this benchmark. + + Subclasses must implement: + run(): Execute the benchmark and return latency in milliseconds. + """ + + def __init__(self, timing_cache_file: str | None = None): + """Initialize the benchmark. + + Args: + timing_cache_file: Path to timing cache file to accelerate engine builds. + If None, uses '/tmp/trtexec_timing.cache' as default. + """ + global logger + self.timing_cache_file = timing_cache_file or "/tmp/trtexec_timing.cache" # nosec B108 + self.logger = logger + + @abstractmethod + def run(self, path_or_bytes: str | bytes, log_file: str | None = None) -> float: + """Run benchmark on the given ONNX model. + + Args: + path_or_bytes: Path to the ONNX model (str) or raw model data (bytes) + log_file: Optional path to save benchmark logs + + Returns: + Measured latency in milliseconds, or float("inf") on failure + """ + + def __call__(self, path_or_bytes: str | bytes, log_file: str | None = None) -> float: + """Convenience method to call benchmark as a function. 
+ + Args: + path_or_bytes: Path to the ONNX model (str) or raw model data (bytes) + log_file: Optional path to save benchmark logs + + Returns: + Measured latency in milliseconds + """ + return self.run(path_or_bytes, log_file) + + +class TrtExecBenchmark(Benchmark): + """TensorRT benchmark using trtexec command-line tool. + + This implementation uses the trtexec binary to build engines and measure + inference latency. It is the most straightforward method and closely + mirrors standard TensorRT workflows. + + Features: + - Uses subprocess to call trtexec binary + - Supports all trtexec command-line arguments + - Custom TensorRT plugin library loading + - Automatic temporary directory management for engines + - Timing cache persistence across benchmarks + - Supports both file paths and raw bytes as input + + Attributes: + trtexec_path: Path to the trtexec binary. + trtexec_args: Additional command-line arguments for trtexec. + warmup_runs: Number of warmup iterations before timing. + timing_runs: Number of iterations for latency measurement. + timeout: Maximum time in seconds for trtexec execution. + plugin_libraries: List of paths to plugin libraries. + engine_dir: Directory for storing temporary engine files. + engine_path: Path to the engine file. + temp_model_path: Path for temporary ONNX model (when using bytes). + """ + + def __init__( + self, + trtexec_path: str = "trtexec", + trtexec_args: list | None = None, + timing_cache_file: str | None = None, + warmup_runs: int = 5, + timing_runs: int = 10, + timeout: int = 300, + plugin_libraries: list[str] | None = None, + ): + """Initialize the trtexec benchmark. + + Args: + trtexec_path: Path to trtexec binary. Defaults to 'trtexec' which + looks for the binary in PATH. + trtexec_args: Additional command-line arguments to pass to trtexec. + These are appended after the standard arguments. + Example: ['--fp16', '--workspace=4096', '--verbose'] + timing_cache_file: Path to TensorRT timing cache file for faster + subsequent builds. Defaults to '/tmp/trtexec_timing.cache'. + warmup_runs: Number of warmup iterations before timing measurements. + timing_runs: Number of iterations for latency measurement. Results + are averaged across these runs. + timeout: Maximum time in seconds for trtexec execution before timeout. + plugin_libraries: List of paths to TensorRT plugin shared libraries (.so files). + These plugins will be loaded by trtexec during engine building. + If None, no custom plugins are loaded. 
+ """ + super().__init__(timing_cache_file) + + # Store configuration + self.trtexec_path = trtexec_path + self.trtexec_args = trtexec_args or [] + self.warmup_runs = warmup_runs + self.timing_runs = timing_runs + self.timeout = timeout + self.plugin_libraries = plugin_libraries or [] + + # Create persistent temporary directory for engine and model files + # This directory persists for the lifetime of this benchmark object + self._temp_dir = tempfile.mkdtemp(prefix="trtexec_benchmark_") + self.engine_dir = self._temp_dir + self.engine_path = os.path.join(self.engine_dir, "engine.trt") + self.temp_model_path = os.path.join(self.engine_dir, "temp_model.onnx") + self.logger.debug(f"Created temporary engine directory: {self.engine_dir}") + self.logger.debug(f"Temporary model path: {self.temp_model_path}") + + # Construct base trtexec command template + # The '--onnx' argument will be added dynamically in run() + self._base_cmd = [ + self.trtexec_path, + f"--avgRuns={self.timing_runs}", + f"--iterations={self.timing_runs}", + f"--warmUp={self.warmup_runs}", + "--stronglyTyped", # Enable strongly typed mode for Q/DQ ops + f"--saveEngine={self.engine_path}", + f"--timingCacheFile={self.timing_cache_file}", + ] + + # Add plugin libraries + for plugin_lib in self.plugin_libraries: + plugin_path = Path(plugin_lib).resolve() + if not plugin_path.exists(): + self.logger.warning(f"Plugin library not found: {plugin_path}") + else: + self._base_cmd.append(f"--staticPlugins={plugin_path}") + self.logger.debug(f"Added plugin library: {plugin_path}") + + # Append user-provided custom arguments + if self.trtexec_args: + self._base_cmd.extend(self.trtexec_args) + + self.logger.debug(f"Base command template: {' '.join(self._base_cmd)}") + + def __del__(self): + """Cleanup temporary directory.""" + if hasattr(self, "_temp_dir"): + try: + shutil.rmtree(self._temp_dir, ignore_errors=True) + self.logger.debug(f"Cleaned up temporary directory: {self._temp_dir}") + except Exception as e: + self.logger.warning(f"Failed to cleanup temporary directory: {e}") + + def run( + self, + path_or_bytes: str | bytes, + log_file: str | None = None, + flush_timing_cache: bool = False, + ) -> float: + """Run benchmark using trtexec. 
+ + Args: + path_or_bytes: Path to the ONNX model (str) or raw model data (bytes) + log_file: Optional path to save trtexec logs + + Returns: + Measured median latency in milliseconds + """ + cache_exists = os.path.exists(self.timing_cache_file) + if cache_exists: + self.logger.debug(f"Using existing timing cache: {self.timing_cache_file}") + else: + self.logger.debug(f"Will create timing cache: {self.timing_cache_file}") + + try: + # If bytes provided, write to temporary model path + if isinstance(path_or_bytes, bytes): + with open(self.temp_model_path, "wb") as f: + f.write(path_or_bytes) + model_path = self.temp_model_path + self.logger.debug(f"Wrote model bytes to temporary file: {model_path}") + else: + model_path = path_or_bytes + + # Build complete command from base template + cmd = [self._base_cmd[0], f"--onnx={model_path}", *self._base_cmd[1:]] + + self.logger.debug(f"Running: {' '.join(cmd)}") + + # Run trtexec and capture output + result = subprocess.run(cmd, capture_output=True, text=True, timeout=self.timeout) # nosec B603 + + # Save logs if requested + if log_file is not None: + try: + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + with open(log_path, "w") as f: + f.write(f"Command: {' '.join(cmd)}\n") + f.write(f"Return code: {result.returncode}\n") + f.write("=" * 80 + "\n") + f.write("STDOUT:\n") + f.write("=" * 80 + "\n") + f.write(result.stdout) + f.write("\n" + "=" * 80 + "\n") + f.write("STDERR:\n") + f.write("=" * 80 + "\n") + f.write(result.stderr) + self.logger.debug(f"Saved trtexec logs to: {log_file}") + except Exception as e: + self.logger.warning(f"Failed to save logs to {log_file}: {e}") + + if result.returncode != 0: + self.logger.error(f"trtexec failed with return code {result.returncode}") + self.logger.error(f"stderr: {result.stderr}") + return float("inf") + + # Parse output to extract latency + # trtexec outputs lines like: + # "[I] Latency: min = X ms, max = Y ms, mean = Z ms, median = W ms, ..." + output = result.stdout + + # Look for median latency in the main "[I] Latency:" line + pattern = r"\[I\]\s+Latency:.*?median\s*=\s*([\d.]+)\s*ms" + + match = re.search(pattern, output, re.IGNORECASE) + if match: + latency = float(match.group(1)) + self.logger.info(f"TrtExec benchmark (median): {latency:.2f} ms") + return latency + + self.logger.warning("Could not parse median latency from trtexec output") + self.logger.debug(f"trtexec stdout:\n{output}") + return float("inf") + + except subprocess.TimeoutExpired: + self.logger.error(f"trtexec timed out after {self.timeout} seconds") + return float("inf") + except FileNotFoundError: + self.logger.error(f"trtexec binary not found: {self.trtexec_path}") + self.logger.error("Please ensure TensorRT is installed and trtexec path is correct") + return float("inf") + except Exception as e: + self.logger.error(f"Benchmark failed: {e}") + return float("inf") + + +class TensorRTPyBenchmark(Benchmark): + """TensorRT benchmark using Python API with plugin support. + + This implementation directly uses the TensorRT Python API to build engines + and measure inference latency. It provides more control than trtexec and + can be faster for certain workflows as it avoids subprocess overhead. 
+ + Features: + - Direct TensorRT Python API usage (no subprocess) + - Persistent Builder, Logger, and Runtime objects + - Custom TensorRT plugin library loading + - Automatic dynamic shape handling + - In-memory timing cache management + - CUDA memory management via PyCUDA + - Detailed latency statistics (min, max, mean, median) + + Requirements: + - tensorrt package + - pycuda package + - CUDA-capable GPU + + Attributes: + trt_logger: TensorRT Logger instance (persistent). + builder: TensorRT Builder instance (persistent). + runtime: TensorRT Runtime instance (persistent). + config: Builder configuration (recreated per run). + warmup_runs: Number of warmup iterations. + timing_runs: Number of timing iterations. + plugin_libraries: List of loaded plugin library paths. + _shape_configs: Dictionary storing custom shape configurations. + _plugin_registry: TensorRT PluginRegistry instance. + + Methods: + set_shapes(): Configure min/opt/max shapes for dynamic inputs. + run(): Execute the benchmark and return latency. + """ + + def __init__( + self, + timing_cache_file: str | None = None, + warmup_runs: int = 5, + timing_runs: int = 20, + plugin_libraries: list[str] | None = None, + ): + """Initialize the TensorRT Python API benchmark. + + Creates persistent TensorRT objects (Logger, Builder, Runtime) and + loads the timing cache from disk if available. Optionally loads custom + TensorRT plugin libraries for models with custom operations. + + Args: + timing_cache_file: Path to TensorRT timing cache file. If None, + defaults to '/tmp/trtexec_timing.cache'. + warmup_runs: Number of warmup iterations before timing measurements. + timing_runs: Number of iterations for latency measurement. + plugin_libraries: List of paths to TensorRT plugin shared libraries (.so files). + These plugins will be loaded and registered for use during + engine building. If None, no custom plugins are loaded. + + Raises: + ImportError: If tensorrt or pycuda packages are not available. + FileNotFoundError: If a specified plugin library file does not exist. + RuntimeError: If plugin library loading fails. + """ + super().__init__(timing_cache_file) + self.warmup_runs = warmup_runs + self.timing_runs = timing_runs + self.plugin_libraries = plugin_libraries or [] + + if not TRT_AVAILABLE: + raise ImportError("TensorRT Python API not available. Please install tensorrt package.") + if not PYCUDA_AVAILABLE: + raise ImportError("PyCUDA not available. Please install pycuda package.") + + self.trt_logger = trt.Logger(trt.Logger.WARNING) + self.builder = trt.Builder(self.trt_logger) + self.runtime = trt.Runtime(self.trt_logger) + # Load custom plugin libraries before initializing TensorRT plugins + self._loaded_plugin_handles = [] + if self.plugin_libraries: + self._load_plugin_libraries() + # Get plugin registry (must be done after loading plugin libraries) + trt.init_libnvinfer_plugins(self.trt_logger, "") + + self._plugin_registry = trt.get_plugin_registry() + + # Set network flag + self.network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + self.network_flags |= 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + + # Load timing cache from disk or create new one + self._timing_cache = None + self._load_timing_cache() + + # Storage for user-defined shape configurations + # Format: {input_name: (min_shape, opt_shape, max_shape)} + self._shape_configs = {} + + def _load_plugin_libraries(self): + """Load custom TensorRT plugin libraries from shared object files. 
+ + This method loads plugin libraries using ctypes and initializes them + with the TensorRT plugin registry. Plugins must export the + initLibNvInferPlugins function to register their implementations. + + The loaded library handles are stored to prevent them from being + garbage collected during the benchmark's lifetime. + + Raises: + FileNotFoundError: If a plugin library file does not exist. + RuntimeError: If plugin initialization fails. + """ + for plugin_lib in self.plugin_libraries: + plugin_path = Path(plugin_lib).resolve() + + if not plugin_path.exists(): + raise FileNotFoundError(f"Plugin library not found: {plugin_path}") + + self.logger.info(f"Loading TensorRT plugin: {plugin_path}") + + try: + if hasattr(os, "RTLD_LAZY") and hasattr(os, "RTLD_GLOBAL"): + plugin_handle = ctypes.CDLL( + str(plugin_path), mode=os.RTLD_LAZY | os.RTLD_GLOBAL + ) + else: + # Fallback for platforms without RTLD flags (e.g., Windows) + plugin_handle = ctypes.CDLL(str(plugin_path)) + + # Store handle to prevent garbage collection + self._loaded_plugin_handles.append(plugin_handle) + + # Try to initialize plugin with TensorRT registry + # Most TensorRT plugins export initLibNvInferPlugins function + if hasattr(plugin_handle, "initLibNvInferPlugins"): + init_func = plugin_handle.initLibNvInferPlugins + # Function signature: bool initLibNvInferPlugins(void* logger, const char* namespace) + init_func.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + init_func.restype = ctypes.c_bool + + # Initialize with the TensorRT logger and default namespace + success = init_func(None, b"") + if not success: + self.logger.warning( + f"Plugin initialization returned false for: {plugin_path}" + ) + else: + self.logger.info(f"Successfully initialized plugin: {plugin_path.name}") + else: + self.logger.info( + f"Plugin loaded (no initLibNvInferPlugins function): {plugin_path.name}" + ) + + except Exception as e: + raise RuntimeError(f"Failed to load plugin library {plugin_path}: {e}") from e + + def set_shapes(self, input_name: str, min_shape: list, opt_shape: list, max_shape: list): + """Set custom min/opt/max shapes for a dynamic input. + + This method allows you to specify custom shape ranges for dynamic inputs + (inputs with -1 dimensions). If not specified, the benchmark will use + default shapes (all -1 dimensions become 1). + + Args: + input_name: Name of the input tensor to configure. + min_shape: Minimum shape for this input. List of integers. + opt_shape: Optimal/default shape for this input. List of integers. + max_shape: Maximum shape for this input. List of integers. + """ + if len(min_shape) != len(opt_shape) or len(opt_shape) != len(max_shape): + raise ValueError("min_shape, opt_shape, and max_shape must have the same length") + + for i, (min_dim, opt_dim, max_dim) in enumerate(zip(min_shape, opt_shape, max_shape)): + if not (min_dim <= opt_dim <= max_dim): + raise ValueError( + f"Invalid shape range at dimension {i}: " + f"min={min_dim}, opt={opt_dim}, max={max_dim}. " + f"Must satisfy min <= opt <= max" + ) + + self._shape_configs[input_name] = (min_shape, opt_shape, max_shape) + self.logger.debug( + f"Set shapes for input '{input_name}': " + f"min={min_shape}, opt={opt_shape}, max={max_shape}" + ) + + def run( + self, + path_or_bytes: str | bytes, + log_file: str | None = None, + flush_timing_cache: bool = False, + ) -> float: + """Run benchmark using TensorRT Python API. 
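+
+        A minimal sketch of intended usage (illustrative; assumes ``onnx_model``
+        is an in-memory ``onnx.ModelProto``):
+
+            >>> latency_ms = benchmark.run(onnx_model.SerializeToString(), log_file="run.log")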
+ + Args: + path_or_bytes: Path to the ONNX model (str) or raw model data (bytes) + log_file: Optional path to save benchmark logs + + Returns: + Measured median latency in milliseconds + """ + config = None + network = None + parser = None + serialized_engine = None + engine = None + context = None + inputs = [] + outputs = [] + stream = None + + try: + self.logger.debug("Creating TensorRT builder...") + config = self.builder.create_builder_config() + config.set_flag(trt.BuilderFlag.DIRECT_IO) + if not config.set_timing_cache(self._timing_cache, ignore_mismatch=True): + self.logger.warning("Failed to set timing cache to builder config") + network = self.builder.create_network(self.network_flags) + # Create network and parser using the shared builder and logger + parser = trt.OnnxParser(network, self.trt_logger) + + # Parse ONNX model + if isinstance(path_or_bytes, bytes): + self.logger.debug(f"Parsing ONNX model from bytes (size: {len(path_or_bytes)})") + model_data = path_or_bytes + else: + self.logger.debug(f"Parsing ONNX model: {path_or_bytes}") + with open(path_or_bytes, "rb") as f: + model_data = f.read() + + if not parser.parse(model_data): + self.logger.error("Failed to parse ONNX model") + for error_idx in range(parser.num_errors): + self.logger.error(f" {parser.get_error(error_idx)}") + return float("inf") + + has_dynamic_shapes = False + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + shape = input_tensor.shape + if any(dim == -1 for dim in shape): + has_dynamic_shapes = True + break + + if has_dynamic_shapes: + profile = self.builder.create_optimization_profile() + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_name = input_tensor.name + shape = list(input_tensor.shape) + + # Check if user provided custom shape configuration + if input_name in self._shape_configs: + min_shape, opt_shape, max_shape = self._shape_configs[input_name] + self.logger.debug( + f"Using custom shapes for input '{input_name}': " + f"min={min_shape}, opt={opt_shape}, max={max_shape}" + ) + else: + # Use default: replace -1 with concrete values (1) + min_shape = [1 if dim == -1 else dim for dim in shape] + opt_shape = [1 if dim == -1 else dim for dim in shape] + max_shape = [1 if dim == -1 else dim for dim in shape] + self.logger.debug( + f"Using default shapes for input '{input_name}': {opt_shape}" + ) + + profile.set_shape(input_name, min_shape, opt_shape, max_shape) + + config.add_optimization_profile(profile) + + self.logger.debug("Building TensorRT engine...") + build_start = time.perf_counter() + serialized_engine = self.builder.build_serialized_network(network, config) + build_time = time.perf_counter() - build_start + + if serialized_engine is None: + self.logger.error("Failed to build TensorRT engine") + return float("inf") + + self.logger.debug(f"Engine built successfully in {build_time:.2f}s") + + if flush_timing_cache: + self._save_timing_cache() + + engine = self.runtime.deserialize_cuda_engine(serialized_engine) + + if engine is None: + self.logger.error("Failed to deserialize engine") + return float("inf") + + context = engine.create_execution_context() + + inputs = [] + outputs = [] + + for i in range(engine.num_io_tensors): + tensor_name = engine.get_tensor_name(i) + dtype = trt.nptype(engine.get_tensor_dtype(tensor_name)) + shape = context.get_tensor_shape(tensor_name) + + size = trt.volume(shape) + host_mem = cuda.pagelocked_empty(size, dtype) + device_mem = cuda.mem_alloc(host_mem.nbytes) + + if 
engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: + np.copyto(host_mem, np.random.randn(size).astype(dtype)) + inputs.append({"host": host_mem, "device": device_mem, "name": tensor_name}) + else: + outputs.append({"host": host_mem, "device": device_mem, "name": tensor_name}) + + context.set_tensor_address(tensor_name, int(device_mem)) + + stream = cuda.Stream() + + self.logger.debug(f"Running {self.warmup_runs} warmup iterations...") + for _ in range(self.warmup_runs): + for inp in inputs: + cuda.memcpy_htod_async(inp["device"], inp["host"], stream) + context.execute_async_v3(stream_handle=stream.handle) + for out in outputs: + cuda.memcpy_dtoh_async(out["host"], out["device"], stream) + stream.synchronize() + + self.logger.debug(f"Running {self.timing_runs} timing iterations...") + latencies = [] + + for _ in range(self.timing_runs): + for inp in inputs: + cuda.memcpy_htod_async(inp["device"], inp["host"], stream) + + stream.synchronize() + start = time.perf_counter() + context.execute_async_v3(stream_handle=stream.handle) + stream.synchronize() + end = time.perf_counter() + + latency_ms = (end - start) * 1000.0 + latencies.append(latency_ms) + + for out in outputs: + cuda.memcpy_dtoh_async(out["host"], out["device"], stream) + + latencies = np.array(latencies) + median_latency = float(np.median(latencies)) + mean_latency = float(np.mean(latencies)) + std_latency = float(np.std(latencies)) + min_latency = float(np.min(latencies)) + max_latency = float(np.max(latencies)) + + self.logger.info("TensorRT Python API benchmark:") + self.logger.info( + f" min={min_latency:.3f}ms, max={max_latency:.3f}ms, " + f"mean={mean_latency:.3f}ms, std={std_latency:.3f}ms, median={median_latency:.3f}ms" + ) + + if log_file is not None: + try: + log_path = Path(log_file) + log_path.parent.mkdir(parents=True, exist_ok=True) + model_info = ( + f"" + if isinstance(path_or_bytes, bytes) + else path_or_bytes + ) + with open(log_path, "w") as f: + f.write(f"""TensorRT Python API Benchmark +Model: {model_info} +Build time: {build_time:.2f}s +Warmup runs: {self.warmup_runs} +Timing runs: {self.timing_runs} + +Latency Statistics: + Min: {min_latency:.3f} ms + Max: {max_latency:.3f} ms + Mean: {mean_latency:.3f} ms + Std: {std_latency:.3f} ms + Median: {median_latency:.3f} ms + +All latencies: {latencies.tolist()} +""") + self.logger.debug(f"Saved benchmark logs to: {log_file}") + except Exception as e: + self.logger.warning(f"Failed to save logs to {log_file}: {e}") + return median_latency + except Exception as e: + self.logger.error(f"Benchmark failed: {e}", exc_info=True) + return float("inf") + finally: + try: + for inp in inputs: + if "device" in inp: + inp["device"].free() + if "host" in inp: + del inp["host"] + for out in outputs: + if "device" in out: + out["device"].free() + if "host" in out: + del out["host"] + inputs.clear() + outputs.clear() + + if context is not None: + del context + if stream is not None: + del stream + if engine is not None: + del engine + if serialized_engine is not None: + del serialized_engine + if parser is not None: + del parser + if network is not None: + del network + if config is not None: + del config + except Exception as cleanup_error: + self.logger.warning(f"Error during cleanup: {cleanup_error}") + + def _load_timing_cache(self): + """Load timing cache from file or create a new one.""" + config = self.builder.create_builder_config() + if os.path.exists(self.timing_cache_file): + try: + with open(self.timing_cache_file, "rb") as f: + timing_cache_data = f.read() 
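+                    # Raw serialized timing cache bytes; rebuilt into an ITimingCache object below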
+
+                self._timing_cache = config.create_timing_cache(timing_cache_data)
+                self.logger.debug(f"Loaded timing cache from: {self.timing_cache_file}")
+            except Exception as e:
+                self.logger.warning(f"Failed to load timing cache: {e}")
+                self.logger.debug("Creating new timing cache")
+                self._timing_cache = None
+
+        if self._timing_cache is None:
+            self._timing_cache = config.create_timing_cache(b"")
+            self.logger.debug("Created new timing cache")
+        del config
+
+    def _save_timing_cache(self):
+        """Save timing cache to file."""
+        config = None
+        try:
+            if self._timing_cache is not None:
+                config = self.builder.create_builder_config()
+                output_cache = config.create_timing_cache(b"")
+                output_cache.combine(self._timing_cache, ignore_mismatch=True)
+                timing_cache_data = output_cache.serialize()
+                with open(self.timing_cache_file, "wb") as f:
+                    f.write(timing_cache_data)
+                self.logger.debug(f"Saved timing cache to: {self.timing_cache_file}")
+        except Exception as e:
+            self.logger.warning(f"Failed to save timing cache: {e}")
+        finally:
+            if config is not None:
+                del config
diff --git a/modelopt/onnx/quantization/autotune/cli.py b/modelopt/onnx/quantization/autotune/cli.py
new file mode 100644
index 000000000..a5809f9a5
--- /dev/null
+++ b/modelopt/onnx/quantization/autotune/cli.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""CLI argument parsing and execution for ONNX Q/DQ autotuning.
+
+This module provides `run_autotune` which handles both argument parsing and
+workflow execution. See `__main__.py` for usage examples.
+"""
+
+import argparse
+import sys
+from pathlib import Path
+
+from modelopt.onnx.logging_config import logger
+from modelopt.onnx.quantization.autotune.workflows import (
+    init_benchmark_instance,
+    region_pattern_autotuning_workflow,
+)
+
+DEFAULT_OUTPUT_DIR = "./autotuner_output"
+DEFAULT_NUM_SCHEMES = 30
+DEFAULT_QUANT_TYPE = "int8"
+DEFAULT_DQ_DTYPE = "float32"
+DEFAULT_TIMING_CACHE = "/tmp/trtexec_timing.cache"  # nosec B108
+DEFAULT_WARMUP_RUNS = 5
+DEFAULT_TIMING_RUNS = 20
+
+
+def validate_file_path(path: str | None, description: str) -> Path | None:
+    """Validate that a file path exists.
+
+    Args:
+        path: Path string to validate (can be None)
+        description: Description of the file for error messages
+
+    Returns:
+        Path object if valid, None if path is None
+
+    Raises:
+        SystemExit: If path is provided but doesn't exist
+    """
+    if path is None:
+        return None
+
+    path_obj = Path(path)
+    if not path_obj.exists():
+        logger.error(f"{description} not found: {path_obj}")
+        sys.exit(1)
+
+    return path_obj
+
+
+def log_benchmark_config(args):
+    """Log TensorRT benchmark configuration for transparency.
+
+    Logs timing cache path, warmup/timing run counts, and any custom
+    plugin libraries that will be loaded.
+ + Args: + args: Parsed command-line arguments with benchmark configuration + """ + logger.info("Initializing TensorRT benchmark") + logger.info(f" Timing cache: {args.timing_cache}") + logger.info(f" Warmup runs: {args.warmup_runs}") + logger.info(f" Timing runs: {args.timing_runs}") + if args.plugin_libraries: + logger.info(f" Plugin libraries: {', '.join(args.plugin_libraries)}") + + +def run_autotune(args=None) -> int: + """Execute the complete pattern-based Q/DQ autotuning workflow. + + This function orchestrates the entire optimization process: + 1. Parses command-line arguments (if not provided) + 2. Validates input paths (model, baseline, output directory) + 3. Initializes TensorRT benchmark instance + 4. Runs pattern-based region autotuning workflow + 5. Handles interruptions gracefully with state preservation + + Args: + args: Optional parsed command-line arguments. If None, parses sys.argv. + + Returns: + Exit code: + - 0: Success + - 1: Autotuning failed (exception occurred) + - 130: Interrupted by user (Ctrl+C) + """ + if args is None: + args = _get_autotune_parser().parse_args() + + model_path = validate_file_path(args.onnx_path, "Model file") + validate_file_path(args.qdq_baseline, "QDQ baseline model") + output_dir = Path(args.output) + + log_benchmark_config(args) + init_benchmark_instance( + use_trtexec=args.use_trtexec, + plugin_libraries=args.plugin_libraries, + timing_cache_file=args.timing_cache, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + ) + + logger.info("Autotuning Mode: Pattern-Based") + + try: + node_filter_list = None + if args.node_filter_list: + filter_file = validate_file_path(args.node_filter_list, "Node filter list file") + if filter_file: + with open(filter_file) as f: + node_filter_list = [ + line.strip() + for line in f + if line.strip() and not line.strip().startswith("#") + ] + logger.info(f"Loaded {len(node_filter_list)} filter patterns from {filter_file}") + + region_pattern_autotuning_workflow( + model_path=str(model_path), + output_dir=output_dir, + num_schemes_per_region=args.num_schemes, + pattern_cache_file=args.pattern_cache_file, + state_file=args.state_file, + quant_type=args.quant_type, + default_dq_dtype=args.default_dq_dtype, + qdq_baseline_model=args.qdq_baseline, + node_filter_list=node_filter_list, + ) + + logger.info("\n" + "=" * 70) + logger.info("✓ Autotuning completed successfully!") + logger.info(f"✓ Results: {output_dir}") + logger.info("=" * 70) + return 0 + + except KeyboardInterrupt: + logger.warning("\nInterrupted by user") + state_file = args.state_file or output_dir / "autotuner_state.yaml" + logger.info(f"Progress saved to: {state_file}") + return 130 + + except Exception as e: + logger.error(f"\nAutotuning failed: {e}", exc_info=args.verbose) + return 1 + + +def _get_autotune_parser() -> argparse.ArgumentParser: + """Create and configure the command-line argument parser.""" + parser = argparse.ArgumentParser( + prog="modelopt.onnx.quantization.autotune", + description="ONNX Q/DQ Autotuning with TensorRT", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic usage + python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx + + # Import patterns from QDQ baseline model + python -m modelopt.onnx.quantization.autotune \\ + --onnx_path model.onnx --qdq_baseline baseline.onnx + + # Use pattern cache for warm-start + python -m modelopt.onnx.quantization.autotune --onnx_path model.onnx --pattern_cache cache.yaml + + # Full example with all options + python -m 
modelopt.onnx.quantization.autotune \\ + --onnx_path model.onnx --schemes_per_region 50 \\ + --pattern_cache cache.yaml --qdq_baseline baseline.onnx \\ + --quant_type int8 --verbose + """, + ) + + # Model and Output + io_group = parser.add_argument_group("Model and Output") + io_group.add_argument( + "--onnx_path", "-m", type=str, required=True, help="Path to ONNX model file" + ) + io_group.add_argument( + "--output", + "-o", + type=str, + default=DEFAULT_OUTPUT_DIR, + help=f"Output directory for results (default: {DEFAULT_OUTPUT_DIR})", + ) + + # Autotuning Strategy + strategy_group = parser.add_argument_group("Autotuning Strategy") + strategy_group.add_argument( + "--schemes_per_region", + "-s", + type=int, + default=DEFAULT_NUM_SCHEMES, + dest="num_schemes", + help=f"Number of schemes to test per region (default: {DEFAULT_NUM_SCHEMES})", + ) + strategy_group.add_argument( + "--pattern_cache", + type=str, + default=None, + dest="pattern_cache_file", + help="Path to pattern cache YAML for warm-start (optional)", + ) + strategy_group.add_argument( + "--qdq_baseline", + type=str, + default=None, + help="Path to QDQ baseline ONNX model to import quantization patterns (optional)", + ) + strategy_group.add_argument( + "--state_file", + type=str, + default=None, + help="State file path for resume capability (default: /autotuner_state.yaml)", + ) + strategy_group.add_argument( + "--node_filter_list", + type=str, + default=None, + help="Path to a file containing wildcard patterns to filter ONNX nodes (one pattern per line). " + "Regions without any matching nodes are skipped during autotuning.", + ) + + # Quantization + quant_group = parser.add_argument_group("Quantization") + quant_group.add_argument( + "--quant_type", + type=str, + default=DEFAULT_QUANT_TYPE, + choices=["int8", "fp8"], + help=f"Quantization data type (default: {DEFAULT_QUANT_TYPE})", + ) + quant_group.add_argument( + "--default_dq_dtype", + type=str, + default=DEFAULT_DQ_DTYPE, + choices=["float16", "float32", "bfloat16"], + help="Default DQ output dtype if cannot be deduced (optional)", + ) + + # TensorRT Benchmark + trt_group = parser.add_argument_group("TensorRT Benchmark") + trt_group.add_argument( + "--use_trtexec", + action="store_true", + help="Use trtexec for benchmarking (default: False)", + default=False, + ) + trt_group.add_argument( + "--timing_cache", + type=str, + default=DEFAULT_TIMING_CACHE, + help=f"TensorRT timing cache file (default: {DEFAULT_TIMING_CACHE})", + ) + trt_group.add_argument( + "--warmup_runs", + type=int, + default=DEFAULT_WARMUP_RUNS, + help=f"Number of warmup runs (default: {DEFAULT_WARMUP_RUNS})", + ) + trt_group.add_argument( + "--timing_runs", + type=int, + default=DEFAULT_TIMING_RUNS, + help=f"Number of timing runs (default: {DEFAULT_TIMING_RUNS})", + ) + trt_group.add_argument( + "--plugin_libraries", + "--plugins", + type=str, + nargs="+", + default=None, + dest="plugin_libraries", + help="TensorRT plugin libraries (.so files) to load (optional, space-separated)", + ) + + # Logging + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose DEBUG logging") + + return parser diff --git a/modelopt/onnx/quantization/autotune/workflows.py b/modelopt/onnx/quantization/autotune/workflows.py new file mode 100644 index 000000000..94a41fade --- /dev/null +++ b/modelopt/onnx/quantization/autotune/workflows.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ONNX Q/DQ Autotuning Workflows. + +This module provides high-level workflow functions for automated Q/DQ (Quantization/Dequantization) +optimization of ONNX models using pattern-based region analysis and TensorRT performance measurement. + +**Core Capabilities:** + +1. **Automated Region Discovery**: Discovers hierarchical regions in the computation graph + - LEAF regions: Contain actual graph nodes + - COMPOSITE regions: Contain child regions with hierarchical structure + +2. **Pattern-Based Optimization**: Groups regions by structural pattern + - Regions with identical patterns share optimization schemes + - One optimization applies to all matching regions simultaneously + +3. **TensorRT Benchmarking**: Measures actual inference performance + - Builds TensorRT engines for each Q/DQ configuration + - Measures median latency across multiple runs + - Caches timing data for faster iteration + +4. **Incremental State Management**: Supports crash recovery and resume + - Saves state after each region profiling + - Resumes from last checkpoint automatically + - Preserves baseline and all measurements + +5. **Pattern Cache Warm-Start**: Leverages previous optimization results + - Loads known-good schemes from cache + - Reduces exploration time for similar models + - Transfers learned patterns across runs + +**Key Functions:** + +- **benchmark_onnx_model()**: Benchmark ONNX model inference latency using TensorRT +- **init_benchmark_instance()**: Initialize global TensorRT benchmark instance +- **region_pattern_autotuning_workflow()**: Complete end-to-end Q/DQ optimization workflow + +**Workflow Overview:** + +1. Initialize autotuner with automatic region discovery +2. Measure baseline performance (no Q/DQ) +3. For each region pattern: + - Generate Q/DQ insertion schemes + - Benchmark each scheme with TensorRT + - Select best scheme for pattern + - Apply to all regions with matching pattern +4. Export final optimized model + +**Performance Optimization:** + +- Pattern-based approach reduces redundant evaluation +- TensorRT timing cache speeds up engine builds +- Incremental state saves enable long-running optimizations +- Pattern cache enables cross-model learning +""" + +import fnmatch +from pathlib import Path + +import onnx + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.autotuner import QDQAutotuner +from modelopt.onnx.quantization.autotune.benchmark import TensorRTPyBenchmark, TrtExecBenchmark +from modelopt.onnx.quantization.autotune.common import Config, PatternCache +from modelopt.onnx.quantization.qdq_utils import get_quantized_tensors + +_benchmark_instance = None + + +def benchmark_onnx_model( + model_path: str | bytes, log_file: str | None = None, flush_timing_cache: bool = False +) -> float: + """Benchmark ONNX model inference latency using TensorRT Python API. 
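+
+    The call is dispatched to the global benchmark instance created by
+    ``init_benchmark_instance``. A minimal sketch (illustrative; assumes
+    ``model.onnx`` exists):
+
+        >>> init_benchmark_instance(use_trtexec=False)
+        >>> latency_ms = benchmark_onnx_model("model.onnx", log_file="logs/model.log")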
+ + Args: + model_path: Path to ONNX model file, or bytes containing serialized model protobuf + log_file: Optional path to save detailed TensorRT build and benchmark logs + (default: None, no logging) + flush_timing_cache: If True, flushes TensorRT timing cache before building engine. + Useful for periodic cache refresh (default: False) + + Returns: + Measured median inference latency in milliseconds. + Returns float('inf') on failure (invalid model, build error, etc.) + + Raises: + No exceptions raised - errors are caught and logged, returning float('inf') + """ + global _benchmark_instance + + if _benchmark_instance is None: + logger.error("Benchmark instance not initialized") + return float("inf") + + try: + latency = _benchmark_instance.run( + model_path, log_file=log_file, flush_timing_cache=flush_timing_cache + ) + + if latency == float("inf"): + if isinstance(model_path, bytes): + logger.warning("Benchmark failed for model bytes") + else: + logger.warning(f"Benchmark failed: {model_path}") + return float("inf") + + logger.debug(f"Benchmark result: {latency:.2f} ms") + return latency + + except Exception as e: + logger.error(f"Benchmark error: {e}", exc_info=True) + return float("inf") + + +def init_benchmark_instance( + use_trtexec: bool = False, + plugin_libraries: list[str] | None = None, + timing_cache_file: str | None = None, + warmup_runs: int = 5, + timing_runs: int = 20, +): + """Initialize global TensorRT benchmark instance for model performance measurement. + + Args: + use_trtexec: Whether to use trtexec for benchmarking. + plugin_libraries: List of paths to TensorRT plugin shared libraries (.so files). + These plugins will be loaded by trtexec or TensorRT Python API during engine building. + If None, no custom plugins are loaded. + timing_cache_file: Path to TensorRT timing cache file for faster engine builds. + If None, uses default "trtexec_timing.cache" (default: None) + warmup_runs: Number of warmup inference iterations before measurement. + Allows GPU to reach stable performance state (default: 5) + timing_runs: Number of timed inference iterations for latency measurement. + Higher values give more stable median (default: 20) + """ + global _benchmark_instance + try: + if use_trtexec: + _benchmark_instance = TrtExecBenchmark( + timing_cache_file=timing_cache_file, + warmup_runs=warmup_runs, + timing_runs=timing_runs, + plugin_libraries=plugin_libraries, + ) + logger.info("Trtexec benchmark initialized") + else: + _benchmark_instance = TensorRTPyBenchmark( + timing_cache_file=timing_cache_file, + warmup_runs=warmup_runs, + timing_runs=timing_runs, + plugin_libraries=plugin_libraries, + ) + logger.info("TensorRT Python API benchmark initialized") + logger.debug( + f"Settings: warmup={warmup_runs}, timing={timing_runs}, " + f"cache={timing_cache_file or 'trtexec_timing.cache'}, plugin_libraries={plugin_libraries}" + ) + return _benchmark_instance + except Exception as e: + logger.error(f"TensorRT initialization failed: {e}", exc_info=True) + return None + + +def _region_matches_filter(region, graph, filter_patterns: list[str]) -> bool: + """Check if any node in the region matches any of the filter patterns. 
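+
+    Patterns use shell-style wildcards via ``fnmatch``; for example (illustrative
+    node names), ``["*attention*", "/encoder/layer.0/*"]`` matches any node whose
+    name contains "attention" or starts with "/encoder/layer.0/".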
+ + Args: + region: Region object to check + graph: ONNX graph (graphsurgeon) containing node information + filter_patterns: List of wildcard patterns to match against node names + + Returns: + True if at least one node in the region matches any pattern, False otherwise + """ + if not filter_patterns: + return True + + node_indices = region.get_all_nodes_recursive() + + for node_idx in node_indices: + if node_idx < len(graph.nodes): + node_name = graph.nodes[node_idx].name + for pattern in filter_patterns: + if fnmatch.fnmatch(node_name, pattern): + return True + + return False + + +def region_pattern_autotuning_workflow( + model_path: str, + output_dir: Path, + num_schemes_per_region: int = 30, + pattern_cache_file: str | None = None, + state_file: str | None = None, + quant_type: str = "int8", + default_dq_dtype: str = "float32", + qdq_baseline_model: str | None = None, + node_filter_list: list[str] | None = None, +) -> QDQAutotuner: + """Run automated Q/DQ (Quantization/Dequantization) optimization on an ONNX model. + + This workflow uses pattern-based region optimization to efficiently find optimal + Q/DQ insertion points. The key insight: regions with identical structural patterns + can share the same Q/DQ scheme. When a best scheme is found for a pattern, it + automatically applies to all regions matching that pattern, making optimization + both efficient and consistent. + + Automatically discovers regions, generates and tests Q/DQ insertion schemes, + and exports optimized model. Supports incremental state saving for crash recovery + and pattern cache-based warm-start. + + **Workflow Steps:** + 1. Load model and initialize autotuner with automatic hierarchical region discovery + 2. Resume from checkpoint if state file exists (crash recovery) + 3. Load pattern cache if provided (warm-start with known-good schemes) + 4. Import Q/DQ patterns from baseline model if provided (transfer learning) + 5. Measure baseline performance without Q/DQ insertions + 6. For each discovered region pattern: + a. Generate Q/DQ insertion schemes (pattern-relative) + b. Build TensorRT engine and measure latency for each scheme + c. Select best scheme for this pattern (applies to all matching regions) + d. Save checkpoint and intermediate model + 7. Export final optimized model with best Q/DQ scheme for each pattern + + Args: + model_path: Path to ONNX model file to optimize + output_dir: Directory for output files (state, logs, models). Created if doesn't exist. + num_schemes_per_region: Number of Q/DQ insertion schemes to test per region pattern. + Higher values explore more configurations but take longer (default: 30) + pattern_cache_file: Optional path to pattern cache YAML file containing known-good schemes + from previous runs. Enables warm-start optimization (default: None) + state_file: Optional path to state file for checkpoint/resume. If None, automatically + uses /autotuner_state.yaml (default: None) + quant_type: Quantization data type - "int8" for INT8 quantization (default), + "fp8" for FP8 quantization + qdq_baseline_model: Optional path to a pre-quantized ONNX model. 
If provided, + extracts Q/DQ insertion patterns and adds them to pattern cache + for warm-start (default: None) + + Returns: + QDQAutotuner instance after autotuning + """ + output_dir.mkdir(parents=True, exist_ok=True) + logs_dir = output_dir / "logs" + logs_dir.mkdir(exist_ok=True) + models_dir = output_dir / "region_models" + models_dir.mkdir(exist_ok=True) + + if state_file is None: + state_file = str(output_dir / "autotuner_state.yaml") + state_path = Path(state_file) + + logger.info(f"Loading model: {model_path}") + model = onnx.load(model_path) + + pattern_cache = None + if pattern_cache_file: + pattern_cache_path = Path(pattern_cache_file) + if pattern_cache_path.exists(): + pattern_cache = PatternCache.load(str(pattern_cache_path)) + logger.info( + f"Loaded pattern cache: {pattern_cache.num_patterns} patterns, " + f"{pattern_cache.total_schemes} schemes" + ) + else: + logger.warning(f"Pattern cache not found: {pattern_cache_file}") + + logger.info( + f"Initializing autotuner (quant_type={quant_type}, default_dq_dtype={default_dq_dtype})" + ) + config = Config( + default_quant_type=quant_type, + default_dq_dtype=default_dq_dtype, + verbose=True, + ) + + autotuner = QDQAutotuner(model) + autotuner.initialize(config, pattern_cache) + + if state_path.exists(): + logger.info(f"Resuming from checkpoint: {state_path}") + autotuner.load_state(str(state_path)) + else: + logger.info("Starting new autotuning session") + + if qdq_baseline_model: + qdq_baseline_path = Path(qdq_baseline_model) + if qdq_baseline_path.exists(): + logger.info(f"Importing patterns from QDQ baseline: {qdq_baseline_model}") + qdq_model = onnx.load(str(qdq_baseline_path)) + quantized_tensors = get_quantized_tensors(qdq_model) + logger.debug(f"Found {len(quantized_tensors)} quantized tensors in baseline") + autotuner.import_insertion_points(quantized_tensors) + logger.info("Pattern import complete") + else: + logger.warning(f"QDQ baseline not found: {qdq_baseline_model}") + + regions = autotuner.regions + logger.info(f"Ready to profile {len(regions)} regions") + + if autotuner.baseline_latency_ms is None: + logger.info("Measuring baseline (no Q/DQ)") + baseline_path = output_dir / "baseline.onnx" + autotuner.export_onnx(str(baseline_path), insert_qdq=False) + baseline_log = logs_dir / "baseline.log" + baseline_latency = benchmark_onnx_model(str(baseline_path), str(baseline_log)) + autotuner.submit(baseline_latency) + logger.info(f"Baseline: {baseline_latency:.2f} ms") + else: + baseline_latency = autotuner.baseline_latency_ms + logger.info(f"Using baseline from checkpoint: {baseline_latency:.2f} ms") + + logger.info(f"Starting region profiling ({num_schemes_per_region} schemes per region)") + + iteration_count = 0 + + for region_idx, region in enumerate(regions): + logger.info( + f"Region {region_idx + 1}/{len(regions)} (ID={region.id}, level={region.get_level()})" + ) + + if node_filter_list and not _region_matches_filter( + region, autotuner.graph, node_filter_list + ): + logger.info(" Skipping (no nodes match filter patterns)") + continue + + commit = region_idx > 0 + autotuner.set_profile_region(region, commit=commit) + + if autotuner.current_profile_pattern_schemes is None: + logger.info(" Skipping (already profiled)") + continue + + schemes_tested = 0 + for scheme_num in range(num_schemes_per_region): + iteration_count += 1 + scheme_idx = autotuner.generate() + + if scheme_idx == -1: + logger.debug(f" Stopping at scheme {scheme_num + 1} (no more unique schemes)") + break + + schemes_tested += 1 + model_bytes = 
autotuner.export_onnx(None, insert_qdq=True) + test_log = logs_dir / f"region_{region.id}_scheme_{scheme_idx}.log" + flush_timing_cache = (iteration_count % 10) == 0 + latency = benchmark_onnx_model( + model_bytes, str(test_log), flush_timing_cache=flush_timing_cache + ) + + autotuner.submit(latency, success=(latency != float("inf"))) + + ps = autotuner.current_profile_pattern_schemes + if ps and ps.schemes: + best_scheme = ps.best_scheme + if best_scheme and best_scheme.latency_ms < float("inf") and baseline_latency > 0: + speedup = baseline_latency / best_scheme.latency_ms + logger.info( + f" Tested {schemes_tested} schemes: " + f"best {best_scheme.latency_ms:.2f} ms ({speedup:.3f}x speedup)" + ) + else: + logger.info(f" Tested {schemes_tested} schemes: no valid measurements") + else: + logger.info(f" Tested {schemes_tested} schemes") + + region_model_path = models_dir / f"region_{region.id}_level_{region.get_level()}.onnx" + autotuner.export_onnx(str(region_model_path), insert_qdq=True, best=True) + logger.debug(f" Saved best model: {region_model_path.name}") + + # Save state after each region (incremental, crash recovery) + autotuner.save_state(str(state_path)) + logger.debug(" Checkpoint saved") + + # Commit final region + autotuner.set_profile_region(None, commit=True) + + logger.info("Exporting final optimized model") + final_model_path = output_dir / "optimized_final.onnx" + autotuner.export_onnx(str(final_model_path), insert_qdq=True) + final_log = logs_dir / "final.log" + final_latency = benchmark_onnx_model(str(final_model_path), str(final_log)) + + if final_latency > 0 and final_latency != float("inf"): + speedup = baseline_latency / final_latency + logger.info( + f"Results: {baseline_latency:.2f} ms → {final_latency:.2f} ms ({speedup:.3f}x speedup)" + ) + else: + logger.info(f"Results: {baseline_latency:.2f} ms → failed (invalid measurement)") + + autotuner.save_state(str(state_path)) + + logger.info("Autotuning complete") + logger.info(f" Final model: {final_model_path}") + logger.info(f" State: {state_path}") + logger.debug(f" Logs: {logs_dir}") + logger.debug(f" Region models: {models_dir}") + + return autotuner diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 67596d5df..a30a113ec 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -302,6 +302,23 @@ def get_tensor_consumer_nodes( return tensor_consumers +def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[str, list[int]]: + """Build a mapping from tensor names to the indices of nodes that use them. 
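+
+    A minimal usage sketch (illustrative; assumes ``graph`` is a loaded
+    GraphSurgeon graph and "input" is one of its tensor names):
+
+        >>> consumers = get_tensor_consumer_node_indices(graph)
+        >>> first_consumer_idx = consumers["input"][0]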
+
+    Args:
+        graph: ONNX GraphProto or GraphSurgeon graph to analyze
+    Returns:
+        Dictionary mapping tensor names to lists of node indices that consume them
+    """
+    tensor_consumer_map: dict[str, list[int]] = defaultdict(list)
+    nodes = graph.nodes if isinstance(graph, gs.Graph) else graph.node
+    for node_idx, node in enumerate(nodes):
+        inputs = node.inputs if isinstance(node, gs.Node) else node.input
+        for tensor in inputs:
+            tensor_consumer_map[getattr(tensor, "name", tensor)].append(node_idx)
+    return tensor_consumer_map
+
+
 def filter_quantizable_kgen_heads(
     cask_fusible_partitions: list[list[Node]],
     kgen_partitions: list[list[Node]],
diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index 026b8d062..794f8c728 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -1035,3 +1035,47 @@ def cast_initializer_to_dtype(
     input_onnx = onnx.numpy_helper.from_array(input, input_name)
     input_onnx.data_type = onnx_dtype_map[dtype]
     initializer_map[input_name].CopyFrom(input_onnx)
+
+
+def get_quantized_tensors(onnx_model: onnx.ModelProto) -> set[str]:
+    """Get the names of all quantized tensors from an ONNX model.
+
+    This function identifies all QuantizeLinear nodes in the ONNX model
+    and extracts the names of tensors being quantized (the first input of
+    each QuantizeLinear node, excluding scale and zero-point inputs).
+
+    Args:
+        onnx_model: ONNX model protobuf to analyze
+
+    Returns:
+        Set of tensor names that are inputs to QuantizeLinear nodes
+        (i.e., the tensors being quantized)
+
+    Example:
+        >>> import onnx
+        >>> from modelopt.onnx.quantization.qdq_utils import get_quantized_tensors
+        >>>
+        >>> # Load a quantized model
+        >>> model = onnx.load("quantized_model.onnx")
+        >>>
+        >>> # Get all quantized tensor names
+        >>> quantized_tensors = get_quantized_tensors(model)
+        >>> print(f"Found {len(quantized_tensors)} quantized tensors")
+        >>>
+        >>> # Use with autotuner to import insertion points
+        >>> from modelopt.onnx.quantization.autotune import QDQAutotuner
+        >>> autotuner = QDQAutotuner(new_model)
+        >>> autotuner.initialize()
+        >>> autotuner.import_insertion_points(quantized_tensors)
+    """
+    quantized_tensors = set()
+
+    for node in onnx_model.graph.node:
+        if node.op_type == "QuantizeLinear":
+            # First input is the tensor being quantized
+            # (inputs[1] is scale, inputs[2] is zero-point)
+            if node.input:
+                quantized_tensors.add(node.input[0])
+
+    logger.debug(f"Found {len(quantized_tensors)} quantized tensors in ONNX model")
+    return quantized_tensors
diff --git a/tests/unit/onnx/quantization/autotune/test_autotuner.py b/tests/unit/onnx/quantization/autotune/test_autotuner.py
new file mode 100644
index 000000000..411256a49
--- /dev/null
+++ b/tests/unit/onnx/quantization/autotune/test_autotuner.py
@@ -0,0 +1,409 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for QDQAutotuner class. + +Tests the main autotuner class public API. +Note: Full integration tests with TensorRT benchmarking should be in separate integration test files. +""" + +import os +import sys +import tempfile +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import onnx +import onnx_graphsurgeon as gs +from onnx import helper + +from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern +from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType + + +def create_simple_conv_model(): + """ + Create a simple ONNX model: Input -> Conv -> Relu -> Output. + + This is a minimal model for testing autotuner initialization. + """ + # Input + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + + # Output + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + # Conv node + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + + # Relu node + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node], + "simple_conv", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + return model + + +class TestQDQAutotuner(unittest.TestCase): + """Test QDQAutotuner functionality.""" + + @staticmethod + def _create_test_config(): + """ + Create a reasonable config for testing. 
+ + Uses sensible defaults suitable for unit tests: + - verbose=False: Keep test output clean + - maximum_sequence_region_size=50: Allow larger test regions + - Other parameters: Match Config defaults for typical behavior + """ + return Config( + # Logging + verbose=False, + # Performance Requirements + # Quantization Parameters + default_q_scale=0.1, + default_q_zero_point=0, + default_quant_type="int8", + # Region Builder Settings + maximum_sequence_region_size=50, + minimum_topdown_search_size=10, + # Scheme Generation Settings + top_percent_to_mutate=0.1, + minimum_schemes_to_mutate=10, + maximum_mutations=3, + maximum_generation_attempts=100, + # Pattern Cache Settings + pattern_cache_minimum_distance=4, + pattern_cache_max_entries_per_pattern=32, + ) + + def test_creation_with_onnx_model(self): + """Test creating autotuner with ONNX ModelProto.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + assert autotuner is not None + assert autotuner.onnx_model is not None + assert autotuner.graph is not None + print("✓ QDQAutotuner creation with ONNX model") + + def test_creation_with_gs_graph(self): + """Test creating autotuner with GraphSurgeon graph.""" + model = create_simple_conv_model() + gs_graph = gs.import_onnx(model) + + autotuner = QDQAutotuner(gs_graph) + + assert autotuner is not None + assert autotuner.graph is not None + print("✓ QDQAutotuner creation with GS graph") + + def test_initialize_with_default_config(self): + """Test initialization with default test config.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + autotuner.initialize(config) + + # Should have provided config + assert autotuner.config is not None + assert autotuner.config.maximum_sequence_region_size == 50 + + # Should have discovered regions + assert len(autotuner.regions) > 0 + print("✓ QDQAutotuner initialize with default config") + + def test_initialize_with_config(self): + """Test initialization with custom config (different from default).""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + # Create custom config with different values + config = Config( + verbose=True, + default_q_scale=0.05, + default_q_zero_point=128, + default_quant_type="fp8", + maximum_sequence_region_size=20, + minimum_topdown_search_size=5, + top_percent_to_mutate=0.2, + minimum_schemes_to_mutate=5, + maximum_mutations=5, + maximum_generation_attempts=50, + pattern_cache_minimum_distance=2, + pattern_cache_max_entries_per_pattern=16, + ) + autotuner.initialize(config) + + # Should use provided custom config values + assert autotuner.config.verbose + assert autotuner.config.default_q_scale == 0.05 + assert autotuner.config.default_q_zero_point == 128 + assert autotuner.config.default_quant_type == "fp8" + assert autotuner.config.maximum_sequence_region_size == 20 + assert autotuner.config.minimum_topdown_search_size == 5 + assert autotuner.config.top_percent_to_mutate == 0.2 + assert autotuner.config.minimum_schemes_to_mutate == 5 + assert autotuner.config.maximum_mutations == 5 + assert autotuner.config.maximum_generation_attempts == 50 + assert autotuner.config.pattern_cache_minimum_distance == 2 + assert autotuner.config.pattern_cache_max_entries_per_pattern == 16 + print("✓ QDQAutotuner initialize with config") + + def test_initialize_with_pattern_cache(self): + """Test initialization with pattern cache.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = 
self._create_test_config() + pattern_cache = PatternCache() + autotuner.initialize(config, pattern_cache=pattern_cache) + + assert autotuner.pattern_cache is not None + print("✓ QDQAutotuner initialize with pattern cache") + + def test_region_discovery(self): + """Test that regions are automatically discovered.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + + config = self._create_test_config() + autotuner.initialize(config) + + # Should discover at least one region + assert len(autotuner.regions) > 0 + + # Regions should be valid + for region in autotuner.regions: + assert region.get_id() is not None + assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT] + + print("✓ QDQAutotuner region discovery") + + def test_export_baseline_model(self): + """Test exporting baseline model without Q/DQ.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f: + output_path = f.name + + try: + # Export baseline without Q/DQ insertion + autotuner.export_onnx(output_path, insert_qdq=False) + + # Verify file was created + assert os.path.exists(output_path) + + # Verify it's a valid ONNX model + exported_model = onnx.load(output_path) + assert exported_model is not None + print("✓ QDQAutotuner export baseline model") + finally: + if os.path.exists(output_path): + os.unlink(output_path) + + def test_set_profile_region(self): + """Test setting a region for profiling.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + # Should set current profile region + assert autotuner.current_profile_region == region + assert autotuner.current_profile_pattern_schemes is not None + print("✓ QDQAutotuner set_profile_region") + else: + self.skipTest("No regions discovered") + + def test_generate_scheme(self): + """Test generating an insertion scheme.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + # Generate a scheme + scheme_idx = autotuner.generate() + + # Should return a valid index (>= 0) or -1 if no more unique schemes + assert isinstance(scheme_idx, int) + print("✓ QDQAutotuner generate scheme") + else: + self.skipTest("No regions discovered") + + def test_submit_latency(self): + """Test submitting performance measurement.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Submit baseline latency + autotuner.submit(10.5) + + # Baseline should be recorded + assert autotuner.baseline_latency_ms == 10.5 + print("✓ QDQAutotuner submit latency") + + def test_save_and_load_state(self): + """Test saving and loading autotuner state.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Submit some results + autotuner.submit(10.5) # baseline + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + state_path = f.name + + try: + # Save state + autotuner.save_state(state_path) + assert 
os.path.exists(state_path) + + # Create new autotuner and load state + autotuner2 = QDQAutotuner(model) + config2 = self._create_test_config() + autotuner2.initialize(config2) + autotuner2.load_state(state_path) + + # Baseline should match + assert autotuner2.baseline_latency_ms == 10.5 + print("✓ QDQAutotuner save and load state") + finally: + if os.path.exists(state_path): + os.unlink(state_path) + + def test_regions_prioritization(self): + """Test that LEAF regions are prioritized.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Check that LEAF regions come before non-LEAF + leaf_indices = [ + i for i, r in enumerate(autotuner.regions) if r.get_type() == RegionType.LEAF + ] + non_leaf_indices = [ + i for i, r in enumerate(autotuner.regions) if r.get_type() != RegionType.LEAF + ] + + if leaf_indices and non_leaf_indices: + # All LEAF should come before non-LEAF + assert max(leaf_indices) < min(non_leaf_indices) + print("✓ QDQAutotuner LEAF region prioritization") + else: + print("✓ QDQAutotuner regions (not enough for prioritization test)") + + def test_profiled_patterns_tracking(self): + """Test that profiled patterns are tracked.""" + model = create_simple_conv_model() + autotuner = QDQAutotuner(model) + config = self._create_test_config() + autotuner.initialize(config) + + # Submit baseline latency first + autotuner.submit(10.0) + + if len(autotuner.regions) > 0: + region = autotuner.regions[0] + autotuner.set_profile_region(region) + + # Generate and submit a scheme + scheme_idx = autotuner.generate() + if scheme_idx >= 0: + autotuner.submit(12.0) + autotuner.set_profile_region(None, commit=True) + + # Pattern should be tracked + pattern_sig = RegionPattern.from_region(region, autotuner.graph).signature + profiled_patterns = [p.pattern.signature for p in autotuner.profiled_patterns] + assert pattern_sig in profiled_patterns + print("✓ QDQAutotuner profiled patterns tracking") + else: + print("✓ QDQAutotuner (no schemes to test tracking)") + else: + self.skipTest("No regions discovered") + + +def run_tests(): + """Run all QDQAutotuner tests.""" + print("=" * 70) + print("QDQAutotuner Test Suite") + print("=" * 70) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTests(loader.loadTestsFromTestCase(TestQDQAutotuner)) + + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + print("\n" + "=" * 70) + print("Test Summary") + print("=" * 70) + print(f"Tests run: {result.testsRun}") + print(f"Successes: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f"Failures: {len(result.failures)}") + print(f"Errors: {len(result.errors)}") + + if result.wasSuccessful(): + print("\n✓ All QDQAutotuner tests passed!") + return 0 + else: + print("\n✗ Some tests failed") + return 1 + + +if __name__ == "__main__": + sys.exit(run_tests()) diff --git a/tests/unit/onnx/quantization/autotune/test_config.py b/tests/unit/onnx/quantization/autotune/test_config.py new file mode 100644 index 000000000..db6b02aa3 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_config.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the Config class in the autotuner. + +Tests configuration parameter validation and defaults. +""" + +import os +import sys +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from modelopt.onnx.quantization.autotune.common import Config + + +class TestConfig(unittest.TestCase): + """Test Config class functionality.""" + + def test_default_values(self): + """Test that Config has correct default values.""" + config = Config() + + # Logging + assert not config.verbose + + # Performance thresholds + + # Q/DQ defaults + assert config.default_q_scale == 0.1 + assert config.default_q_zero_point == 0 + assert config.default_quant_type == "int8" + + # Region builder settings + assert config.maximum_sequence_region_size == 10 + assert config.minimum_topdown_search_size == 10 + + # Scheme generation parameters + assert config.top_percent_to_mutate == 0.1 + assert config.minimum_schemes_to_mutate == 10 + assert config.maximum_mutations == 3 + assert config.maximum_generation_attempts == 100 + + # Pattern cache parameters + assert config.pattern_cache_minimum_distance == 4 + assert config.pattern_cache_max_entries_per_pattern == 32 + + print("✓ Config default values are correct") + + def test_custom_values(self): + """Test creating Config with custom values.""" + config = Config( + verbose=True, + default_q_scale=0.05, + default_q_zero_point=128, + default_quant_type="fp8", + maximum_sequence_region_size=20, + ) + + assert config.verbose + assert config.default_q_scale == 0.05 + assert config.default_q_zero_point == 128 + assert config.default_quant_type == "fp8" + assert config.maximum_sequence_region_size == 20 + print("✓ Config custom values work correctly") + + def test_region_size_validation(self): + """Test that region size parameters are positive.""" + config = Config(maximum_sequence_region_size=50, minimum_topdown_search_size=5) + assert config.maximum_sequence_region_size > 0 + assert config.minimum_topdown_search_size > 0 + print("✓ Config region size validation") + + def test_genetic_algorithm_params(self): + """Test genetic algorithm parameters.""" + config = Config( + top_percent_to_mutate=0.2, + minimum_schemes_to_mutate=2, + maximum_mutations=5, + maximum_generation_attempts=50, + ) + + assert config.top_percent_to_mutate == 0.2 + assert config.minimum_schemes_to_mutate == 2 + assert config.maximum_mutations == 5 + assert config.maximum_generation_attempts == 50 + print("✓ Config genetic algorithm parameters") + + def test_pattern_cache_params(self): + """Test pattern cache parameters.""" + config = Config(pattern_cache_minimum_distance=3, pattern_cache_max_entries_per_pattern=10) + + assert config.pattern_cache_minimum_distance == 3 + assert config.pattern_cache_max_entries_per_pattern == 10 + print("✓ Config pattern cache parameters") + + +def run_tests(): + """Run all Config tests.""" + print("=" * 70) + print("Config Class Test Suite") + print("=" * 70) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTests(loader.loadTestsFromTestCase(TestConfig)) 
+ + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + print("\n" + "=" * 70) + print("Test Summary") + print("=" * 70) + print(f"Tests run: {result.testsRun}") + print(f"Successes: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f"Failures: {len(result.failures)}") + print(f"Errors: {len(result.errors)}") + + if result.wasSuccessful(): + print("\n✓ All Config tests passed!") + return 0 + else: + print("\n✗ Some tests failed") + return 1 + + +if __name__ == "__main__": + sys.exit(run_tests())