diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index cc94a221f..7e11d25e6 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -96,7 +96,7 @@ def is_fusible_scaling_op(op_type: str): ] -def get_copy_ops(): +def get_copy_ops() -> list[str]: """Returns list of copy operators.""" return [ "Flatten", @@ -303,3 +303,86 @@ def is_data_dependent_shape_op(op_type: str): "NonZero", "RoiAlign", ] + + +def get_bool_ops(): + """Returns set of bool operations.""" + return { + "Not", + "And", + "Or", + "Xor", + } + + +def get_bitwise_ops(): + """Returns set of bitwise operations.""" + return { + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + "BitShift", + } + + +def get_value_check_ops(): + """Returns set of value checking operations.""" + return { + "IsNaN", + "IsInf", + "Sign", + "Abs", + } + + +def get_comparison_ops(): + """Returns set of comparison operations.""" + return { + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + } + + +def get_conditional_ops(): + """Returns set of conditional operations.""" + return { + "Where", + } + + +def get_aggregation_ops(): + """Returns set of aggregation operations.""" + return { + "All", + "Any", + } + + +def get_set_ops(): + """Returns set of set/search operations.""" + return { + "Unique", + "NonZero", + } + + +def get_symmetric_ops(): + """Returns set of commutative/symmetric operations where operand order doesn't matter.""" + return { + "Add", + "Mul", + "And", + "Or", + "Xor", + "Equal", + "Max", + "Min", + "Sum", + "Mean", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + } diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py new file mode 100644 index 000000000..f60dc917f --- /dev/null +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pattern-Based Q/DQ Autotuning for ONNX Models. + +This package provides automated optimization of Quantize/Dequantize (Q/DQ) node placement +in ONNX computation graphs to minimize TensorRT inference latency. It uses pattern-based +region analysis to efficiently explore and optimize Q/DQ insertion strategies. 
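+
+A minimal usage sketch (illustrative only; "model.onnx" and the cache path are
+placeholders, and the parameter values are assumptions, not recommendations):
+
+    import onnx
+    import onnx_graphsurgeon as gs
+
+    from modelopt.onnx.quantization.autotune import (
+        CombinedRegionSearch,
+        PatternCache,
+        RegionPattern,
+    )
+
+    graph = gs.import_onnx(onnx.load("model.onnx"))
+    graph.cleanup().toposort()
+    regions = CombinedRegionSearch(graph, maximum_sequence_region_size=10).search_regions()
+    patterns = [RegionPattern.from_region(r, graph) for r in regions]
+    cache = PatternCache.load("pattern_cache.yaml")  # previously saved seed schemes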
+""" + +# Core data structures +from .common import ( + AutotunerError, + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionType, +) +from .insertion_points import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + NodeInputInsertionPoint, + ResolvedInsertionPoint, +) +from .region_pattern import RegionPattern +from .region_search import CombinedRegionSearch + +__all__ = [ + "AutotunerError", + "AutotunerNotInitializedError", + "ChildRegionInputInsertionPoint", + "ChildRegionOutputInsertionPoint", + "CombinedRegionSearch", + "Config", + "InsertionScheme", + "InvalidSchemeError", + "NodeInputInsertionPoint", + "PatternCache", + "PatternSchemes", + "Region", + "RegionPattern", + "RegionType", + "ResolvedInsertionPoint", +] diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py new file mode 100644 index 000000000..db7c9b373 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/common.py @@ -0,0 +1,812 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common data structures and types for the QDQ Autotuner.""" + +import hashlib +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Optional + +import onnx_graphsurgeon as gs +import yaml + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + NodeInputInsertionPoint, + ResolvedInsertionPoint, +) + +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + +class AutotunerError(Exception): + """Base exception for autotuner-related errors.""" + + +class AutotunerNotInitializedError(AutotunerError): + """Exception raised when autotuner is used without initialization.""" + + +class InvalidSchemeError(AutotunerError): + """Exception raised when an invalid scheme is referenced.""" + + +class RegionType(Enum): + """Region type enumeration for hierarchical graph structure. + + - LEAF: Atomic region containing direct nodes with no child regions + - COMPOSITE: Hierarchical region containing child regions (and optionally direct nodes) + - ROOT: Top-level region encompassing the entire computation graph + """ + + LEAF = "LEAF" + COMPOSITE = "COMPOSITE" + ROOT = "ROOT" + + +class Region: + """A subgraph region in an ONNX graph, used as the unit for Q/DQ insertion. + + Regions form a hierarchy: ROOT contains the entire graph, COMPOSITE regions + contain child regions, and LEAF regions contain only nodes. Each region tracks + its direct nodes, input/output tensors, and a pattern signature for matching + regions with identical structure. 
+ """ + + def __init__(self, region_id: int, level: int, region_type: RegionType): + """Initialize a new region. + + Args: + region_id: Unique identifier within the region hierarchy + level: Hierarchical level (0 = leaf, higher = more composite) + region_type: Type classification (LEAF, COMPOSITE, or ROOT) + """ + self.id = region_id + self.level = level + self.type = region_type + self.parent: Region | None = None + self.children: list[Region] = [] + self.nodes: set[int] = set() + self.inputs: list[str] = [] + self.outputs: list[str] = [] + self.metadata: dict[str, str] = {} + + def get_children(self, *, sort: bool = False) -> list["Region"]: + """Get all child regions. If sort is True, sort the children by level and size. + + Args: + sort: Whether to sort the children by level and size + + Returns: + List of child regions + """ + if sort: + return sorted( + self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants()) + ) + return self.children + + def remove_child(self, child: "Region") -> bool: + """Remove a child region from this region's children list.""" + if child not in self.children: + return False + self.children.remove(child) + if child.parent and child.parent.id == self.id: + child.parent = None + return True + + def add_child(self, child: "Region") -> None: + """Add a child sub-region.""" + if child.id == self.id: + logger.warning(f"Cannot add region {self.id} as its own child") + return + + if self.is_descendant_of(child): + logger.warning( + f"Cycle detected: region {self.id} is already a descendant of region {child.id}" + ) + return + + if child.parent is not None and child.parent.id != self.id: + old_parent_id = child.parent.id + logger.debug( + f"Re-parenting region {child.id}: moving from parent {old_parent_id} to {self.id}" + ) + child.parent.remove_child(child) + + if any(c.id == child.id for c in self.children): + logger.debug(f"Region {child.id} already child of {self.id}") + return + + self.children.append(child) + child.parent = self + + def is_descendant_of(self, potential_ancestor: "Region") -> bool: + """Check if this region is a descendant of potential_ancestor.""" + visited = set() + current = self.parent + while current: + if current.id in visited: + return False + visited.add(current.id) + if current.id == potential_ancestor.id: + return True + current = current.parent + return False + + def get_nodes(self, *, sort: bool = False) -> list[int]: + """Get direct node indices in this region only.""" + if sort: + return sorted(self.nodes) + return list(self.nodes) + + def get_region_nodes_and_descendants(self, _visited: set[int] | None = None) -> set[int]: + """Get all node indices recursively, including descendants.""" + if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, f"Cycle detected in region {self.id} during node traversal" + + _visited.add(self.id) + all_nodes = set(self.nodes) + for child in self.children: + all_nodes.update(child.get_region_nodes_and_descendants(_visited)) + return all_nodes + + def contains_node(self, node_index: int) -> bool: + """Check if region contains a specific node (direct only).""" + return node_index in self.nodes + + def contains_node_within_region_and_descendants(self, node_index: int) -> bool: + """Check if region contains a node recursively.""" + return node_index in self.get_region_nodes_and_descendants() + + def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int: + """Get total node count recursively including all descendants.""" 
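+        # Note: the shared _visited set lets the traversal touch each region at
+        # most once, turning any malformed parent/child cycle into an assertion
+        # failure instead of infinite recursion.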
+ if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, ( + f"Cycle detected in region {self.id} during size calculation" + ) + + _visited.add(self.id) + total = len(self.nodes) + for child in self.children: + total += child.get_size_of_region_and_descendants(_visited) + return total + + def merge(self, other: "Region") -> None: + """Merge another region into this one.""" + if not other: + return + self.nodes.update(other.nodes) + for child in other.children: + self.add_child(child) + + def __repr__(self) -> str: + type_str = self.type.value + return ( + f"Region[id={self.id}, level={self.level}, type={type_str}, " + f"nodes={len(self.nodes)}, children={len(self.children)}, " + f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" + ) + + def compute_structural_signature(self, graph: gs.Graph) -> str: + """Compute deterministic structural signature for pattern matching. + + Creates a signature that uniquely identifies the region's topology, + node operations, and hierarchical structure. Regions with identical + signatures can share Q/DQ insertion schemes. + + Args: + graph: The ONNX graph containing the region's nodes + + Returns: + Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)") + """ + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + return RegionPattern.from_region(self, graph).signature + + +@dataclass +class InsertionScheme: + """Complete Q/DQ insertion specification for a region pattern. + + An InsertionScheme defines a complete Q/DQ configuration for a pattern, + combining both node-level and region-level insertion points. The scheme + is applied to all regions matching the pattern. + """ + + node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) + child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) + region_outputs: list[ChildRegionOutputInsertionPoint] = field(default_factory=list) + latency_ms: float = float("inf") + error: bool = False + profile_timestamp: str | None = None + + @property + def hash(self) -> str: + """Compute deterministic hash for scheme identity. + + The hash uniquely identifies this scheme configuration based on its + insertion points. Two schemes with identical insertion points produce + the same hash, regardless of their measured latencies. + """ + sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) + sorted_regions = sorted( + [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] + ) + sorted_region_outputs = sorted( + [(pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs] + ) + + hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}" + + return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32] + + @property + def is_empty(self) -> bool: + """Check if this is a baseline scheme with no Q/DQ insertions.""" + return not self.node_inputs and not self.child_region_inputs and not self.region_outputs + + @property + def is_profiled(self) -> bool: + """Check if this scheme has been profiled (measured). + + A scheme is considered profiled if it has been measured (has non-infinite latency) + or has encountered an error during measurement. 
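+
+        For example, a freshly constructed InsertionScheme() has
+        latency_ms == float("inf") and error == False, so it reports
+        is_profiled == False until a measurement or failure is recorded.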
+ """ + return self.error or self.latency_ms != float("inf") + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "latency_ms": self.latency_ms, + "error": self.error, + "profile_timestamp": self.profile_timestamp, + "nodes_insertion_points": [pt.to_dict() for pt in self.node_inputs], + "child_region_inputs": [pt.to_dict() for pt in self.child_region_inputs], + "region_outputs": [pt.to_dict() for pt in self.region_outputs], + "hash": self.hash, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": + """Create InsertionScheme from serialized dictionary.""" + scheme = cls() + scheme.latency_ms = data.get("latency_ms", float("inf")) + scheme.error = data.get("error", False) + scheme.profile_timestamp = data.get("profile_timestamp") + + scheme.node_inputs = [ + NodeInputInsertionPoint.from_dict(pt) for pt in data.get("nodes_insertion_points", []) + ] + scheme.child_region_inputs = [ + ChildRegionInputInsertionPoint.from_dict(pt) + for pt in data.get("child_region_inputs", []) + ] + scheme.region_outputs = [ + ChildRegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) + ] + + return scheme + + def distance(self, other: "InsertionScheme") -> int: + """Compute edit distance between this scheme and another scheme. + + The edit distance is the minimum number of add/remove operations needed + to transform this scheme into the other scheme. This is computed as the + symmetric difference between the insertion point sets. + + Args: + other: InsertionScheme to compare against + + Returns: + Total edit distance (number of add + remove operations) + """ + return ( + len(set(self.node_inputs).symmetric_difference(other.node_inputs)) + + len(set(self.child_region_inputs).symmetric_difference(other.child_region_inputs)) + + len(set(self.region_outputs).symmetric_difference(other.region_outputs)) + ) + + def __str__(self) -> str: + """String representation for debugging.""" + error_str = ", error=True" if self.error else "" + return ( + f"InsertionScheme(node_insertions={len(self.node_inputs)}, " + f"region_insertions={len(self.child_region_inputs)}, " + f"region_output_insertions={len(self.region_outputs)}, " + f"latency={self.latency_ms:.3f}ms{error_str})" + ) + + +@dataclass +class PatternSchemes: + """Collection of Q/DQ insertion schemes for a single pattern. + + Manages multiple InsertionScheme candidates for a region pattern, tracking + their performance and identifying the best-performing configuration. This + enables pattern-based optimization where all regions with the same structure + use the same Q/DQ insertion strategy. + + **Attributes:** + pattern: RegionPattern defining the structural signature + schemes: List of InsertionScheme candidates with measurements + """ + + pattern: Optional["RegionPattern"] = None + schemes: list[InsertionScheme] = field(default_factory=list) + + @property + def pattern_signature(self) -> str: + """Get the pattern signature string.""" + return self.pattern.signature if self.pattern else "" + + @property + def pattern_size(self) -> int: + """Get the pattern size (total node count).""" + return self.pattern.size if self.pattern else 0 + + @property + def best_scheme_index(self) -> int: + """Get index of the best performing scheme (lowest latency). + + Scans all schemes to find the one with minimum latency_ms, + excluding schemes with errors. + If no schemes exist or all have errors, returns -1. 
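+
+        For example, given schemes measured at [2.1 ms, 1.7 ms (error), 1.9 ms],
+        this returns 2: the errored 1.7 ms scheme is excluded, and 1.9 ms beats 2.1 ms.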
+ + Returns: + Index of best scheme (without errors), or -1 if no valid schemes available + """ + if len(self.schemes) == 0: + return -1 + min_idx, min_latency = -1, float("inf") + for idx, scheme in enumerate(self.schemes): + if not scheme.error and scheme.latency_ms < min_latency: + min_idx = idx + min_latency = scheme.latency_ms + return min_idx + + @property + def best_scheme(self) -> InsertionScheme | None: + """Get the best performing scheme (lowest latency). + + Convenience property for accessing the optimal scheme directly + without needing to look up by index. Excludes schemes with errors. + + Returns: + InsertionScheme with lowest latency (excluding error schemes), + or None if no valid schemes exist + """ + index = self.best_scheme_index + if index < 0 or index >= len(self.schemes): + return None + return self.schemes[index] + + @property + def num_schemes(self) -> int: + """Get total number of schemes.""" + return len(self.schemes) + + @property + def has_schemes(self) -> bool: + """Check if any schemes have been added.""" + return len(self.schemes) > 0 + + def get_measured_schemes(self) -> list[InsertionScheme]: + """Get schemes that have been measured (finite latency). + + Returns: + List of schemes with performance measurements (excludes unmeasured schemes with inf latency) + """ + return [s for s in self.schemes if s.latency_ms != float("inf")] + + def get_valid_schemes(self) -> list[InsertionScheme]: + """Get schemes without errors. + + Returns: + List of schemes that completed successfully without errors + """ + return [s for s in self.schemes if not s.error] + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "pattern_signature": self.pattern_signature, + "pattern_size": self.pattern_size, + "schemes": [scheme.to_dict() for scheme in self.schemes], + } + + @classmethod + def from_dict( + cls, data: dict[str, Any], pattern: Optional["RegionPattern"] = None + ) -> "PatternSchemes": + """Create PatternSchemes from serialized dictionary. + + Reconstructs the pattern schemes collection from saved data. The + RegionPattern object must be provided separately since it's not + serialized (it's a runtime object computed from the graph). + + If no pattern is provided, creates a minimal RegionPattern from the + saved signature and size for signature matching purposes. + + Args: + data: Dictionary containing 'pattern_signature', 'pattern_size', + and 'schemes' keys + pattern: RegionPattern object to associate (must match signature). + If None, creates minimal pattern from saved data. + + Returns: + Reconstructed PatternSchemes instance + """ + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + ps = cls() + + if pattern is None and "pattern_signature" in data: + pattern = RegionPattern( + signature=data["pattern_signature"], size=data.get("pattern_size", 0) + ) + + ps.pattern = pattern + + ps.schemes = [ + InsertionScheme.from_dict(scheme_data) for scheme_data in data.get("schemes", []) + ] + + return ps + + def __str__(self) -> str: + """String representation for debugging.""" + best_latency = self.best_scheme.latency_ms if self.best_scheme else 0.0 + return ( + f"PatternSchemes(pattern='{self.pattern_signature[:40]}...', " + f"schemes={self.num_schemes}, best_latency={best_latency:.3f}ms)" + ) + + +@dataclass +class PatternCache: + """Pattern cache containing best-performing schemes for patterns with automatic eviction. + + Stores a collection of PatternSchemes that can be used as seeds for autotuning. 
+ Each PatternSchemes contains high-performing insertion schemes for a specific + pattern signature. The cache automatically evicts non-performant schemes based on: + - Error status (schemes with errors are evicted) + - Duplicate schemes (only better-performing duplicate is kept) + - Similarity (similar schemes where only better-performing one is kept) + - Count limit (only top N best schemes are kept per pattern) + + **Attributes:** + pattern_schemes: List of PatternSchemes, one per pattern + minimum_distance: Minimum edit distance required between schemes in cache. + When adding new schemes, if a scheme is too similar (distance < minimum_distance) + to an existing scheme, only the better-performing one is kept (default: 4) + max_entries_per_pattern: Maximum number of schemes to keep per pattern. + Only the top N best-performing schemes are kept for each pattern. + Use 0 to keep all schemes (default: 32) + """ + + # List of PatternSchemes in the cache. + pattern_schemes: list[PatternSchemes] = field(default_factory=list) + # Minimum distance between schemes in cache. + minimum_distance: int = 4 + # Maximum number of schemes per pattern. + max_entries_per_pattern: int = 32 + + def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None: + """Add PatternSchemes to pattern cache with automatic eviction of non-performant entries. + + Merges new schemes with existing schemes for the same pattern, automatically + evicting schemes that are non-performant based on multiple criteria. + + Args: + pattern_schemes: PatternSchemes to add to the cache + """ + if not pattern_schemes or not pattern_schemes.pattern: + return + + pattern_sig = pattern_schemes.pattern_signature + + # Find existing PatternSchemes for this pattern + existing_idx = None + for idx, ps in enumerate(self.pattern_schemes): + if ps.pattern_signature == pattern_sig: + existing_idx = idx + break + + # Collect all schemes (existing + new) + all_schemes = list(pattern_schemes.schemes) + if existing_idx is not None: + all_schemes.extend(self.pattern_schemes[existing_idx].schemes) + + # Filter out schemes with errors and deduplicate by hash + valid_schemes = [s for s in all_schemes if not s.error] + unique_schemes = {} + for scheme in valid_schemes: + scheme_hash = scheme.hash + if ( + scheme_hash not in unique_schemes + or scheme.latency_ms < unique_schemes[scheme_hash].latency_ms + ): + unique_schemes[scheme_hash] = scheme + + # Sort by latency to get best schemes + sorted_schemes = sorted(unique_schemes.values(), key=lambda s: s.latency_ms) + + # Apply distance-based filtering if minimum_distance > 0 + if self.minimum_distance > 0: + filtered_schemes = [] + for scheme in sorted_schemes: + # Check if this scheme is too similar to any already-filtered scheme + too_similar = False + schemes_to_replace = [] + for existing_scheme in filtered_schemes: + distance = scheme.distance(existing_scheme) + if distance < self.minimum_distance: + # Schemes are too similar, keep the better one + too_similar = True + if scheme.latency_ms < existing_scheme.latency_ms: + # New scheme is better, mark existing for replacement + schemes_to_replace.append(existing_scheme) + + if not too_similar: + filtered_schemes.append(scheme) + elif schemes_to_replace: + for scheme_to_replace in schemes_to_replace: + filtered_schemes.remove(scheme_to_replace) + filtered_schemes.append(scheme) + + sorted_schemes = filtered_schemes + + # Apply count limit if max_entries_per_pattern > 0 + # Keep only the top N best-performing schemes per pattern + if 
self.max_entries_per_pattern > 0: + sorted_schemes = sorted_schemes[: self.max_entries_per_pattern] + + # Create PatternSchemes with all schemes that passed the eviction criteria + result = PatternSchemes(pattern=pattern_schemes.pattern) + result.schemes = sorted_schemes + + # Replace existing or append new + if existing_idx is not None: + self.pattern_schemes[existing_idx] = result + else: + self.pattern_schemes.append(result) + + def get_pattern_schemes(self, pattern_signature: str) -> PatternSchemes | None: + """Get PatternSchemes for a specific pattern signature. + + Args: + pattern_signature: Pattern signature to lookup + + Returns: + PatternSchemes if found, None otherwise + """ + for ps in self.pattern_schemes: + if ps.pattern_signature == pattern_signature: + return ps + return None + + def has_pattern(self, pattern_signature: str) -> bool: + """Check if pattern cache contains a specific pattern. + + Args: + pattern_signature: Pattern signature to check + + Returns: + True if pattern exists in pattern cache + """ + return any(ps.pattern_signature == pattern_signature for ps in self.pattern_schemes) + + def add_pattern_from_region( + self, region: Region, graph: gs.Graph, quantized_tensors: set[str] + ) -> None: + """Build and add a pattern cache entry from a region in a quantized model. + + Analyzes a region from an already-quantized model to extract its Q/DQ + insertion scheme. This allows capturing known-good quantization strategies + from existing models and using them as seeds for autotuning. + + Args: + region: Region from the quantized model to analyze + graph: ONNX graph containing the region + quantized_tensors: Set of tensor names that have Q/DQ nodes + + """ + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + pattern = RegionPattern.from_region(region, graph) + scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + full_insertion_scheme = pattern.get_full_insertion_scheme(region, graph) + for point in full_insertion_scheme.node_inputs: + temp_scheme = InsertionScheme( + node_inputs=[point], + child_region_inputs=[], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + temp_ips: list[ResolvedInsertionPoint] = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.node_inputs.append(point) + if region.type == RegionType.COMPOSITE: + for child_point in full_insertion_scheme.child_region_inputs: + temp_scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[child_point], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + temp_ips = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.child_region_inputs.append(child_point) + for output_point in full_insertion_scheme.region_outputs: + temp_scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[], + region_outputs=[output_point], + latency_ms=float("inf"), + error=False, + ) + temp_ips = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.region_outputs.append(output_point) + pattern_schemes = PatternSchemes(pattern=pattern, schemes=[scheme]) + 
self.add_pattern_schemes(pattern_schemes) + num_points = ( + len(scheme.node_inputs) + len(scheme.child_region_inputs) + len(scheme.region_outputs) + ) + logger.debug(f"Added pattern from region {region.id} with {num_points} insertion points") + if region.type == RegionType.COMPOSITE: + for child_region in region.get_children(): + self.add_pattern_from_region(child_region, graph, quantized_tensors) + + @property + def num_patterns(self) -> int: + """Get number of patterns in pattern cache.""" + return len(self.pattern_schemes) + + @property + def total_schemes(self) -> int: + """Get total number of schemes across all patterns.""" + return sum(ps.num_schemes for ps in self.pattern_schemes) + + def get_all_pattern_signatures(self) -> list[str]: + """Get list of all pattern signatures in pattern cache.""" + return [ps.pattern_signature for ps in self.pattern_schemes] + + def clear(self) -> None: + """Clear all pattern cache data.""" + self.pattern_schemes.clear() + + def merge(self, other: "PatternCache", prefer_existing: bool = True) -> None: + """Merge another PatternCache into this one.""" + for schemes in other.pattern_schemes: + if not self.has_pattern(schemes.pattern_signature) or not prefer_existing: + self.add_pattern_schemes(schemes) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "minimum_distance": self.minimum_distance, + "max_entries_per_pattern": self.max_entries_per_pattern, + "pattern_schemes": [ps.to_dict() for ps in self.pattern_schemes], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "PatternCache": + """Create PatternCache from serialized dictionary.""" + cache = cls( + minimum_distance=data.get("minimum_distance", 4), + max_entries_per_pattern=data.get("max_entries_per_pattern", 32), + ) + + for ps_data in data.get("pattern_schemes", []): + ps = PatternSchemes.from_dict(ps_data, pattern=None) + cache.pattern_schemes.append(ps) + + return cache + + def save(self, output_path: str) -> None: + """Save pattern cache to a YAML file.""" + state = self.to_dict() + + with open(output_path, "w") as f: + yaml.dump(state, f, default_flow_style=False, sort_keys=False) + + logger.info( + f"Saved pattern cache → {output_path} ({self.num_patterns} patterns, " + f"{self.total_schemes} schemes)" + ) + logger.debug( + f"Cache settings: min_distance={self.minimum_distance}, " + f"max_per_pattern={self.max_entries_per_pattern}" + ) + + @classmethod + def load(cls, input_path: str) -> "PatternCache": + """Load pattern cache from a YAML file.""" + with open(input_path) as f: + state = yaml.safe_load(f) + + cache = cls.from_dict(state) + + logger.info( + f"Loaded pattern cache from {input_path} ({cache.num_patterns} patterns, " + f"{cache.total_schemes} schemes)" + ) + logger.debug( + f"Cache settings: min_distance={cache.minimum_distance}, " + f"max_per_pattern={cache.max_entries_per_pattern}" + ) + + return cache + + def __str__(self) -> str: + """String representation for debugging.""" + return ( + f"PatternCache(patterns={self.num_patterns}, " + f"schemes={self.total_schemes}, " + f"minimum_distance={self.minimum_distance}, " + f"max_entries_per_pattern={self.max_entries_per_pattern})" + ) + + +@dataclass +class Config: + """Configuration parameters for QDQ autotuning.""" + + # Logging + verbose: bool = False + + # Quantization Parameters + default_q_scale: float = 0.1 + default_q_zero_point: int = 0 + default_quant_type: str = "int8" + default_dq_dtype: str = "float32" + + # Region Builder Settings + 
maximum_sequence_region_size: int = 10 + minimum_topdown_search_size: int = 10 + + # Scheme Generation Settings + top_percent_to_mutate: float = 0.1 + minimum_schemes_to_mutate: int = 10 + maximum_mutations: int = 3 + maximum_generation_attempts: int = 100 + + # Pattern Cache Settings + pattern_cache_minimum_distance: int = 4 + pattern_cache_max_entries_per_pattern: int = 32 diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py new file mode 100644 index 000000000..dd01848dd --- /dev/null +++ b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -0,0 +1,531 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Q/DQ insertion point management for ONNX quantization autotune. + +This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) +insertion points in ONNX computational graphs during autotune optimization. It enables pattern-based +Q/DQ insertion that can be reused across multiple matching regions in a model. +""" + +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import onnx_graphsurgeon as gs + +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.common import Region + +from modelopt.onnx.op_types import ( + get_aggregation_ops, + get_bitwise_ops, + get_bool_ops, + get_comparison_ops, + get_conditional_ops, + get_copy_ops, + get_set_ops, + get_value_check_ops, + is_fusible_reduction_op, +) +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class InsertionPoint(ABC): + """Abstract base class for pattern-relative Q/DQ insertion points.""" + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + ... + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionPoint": + """Create from dictionary.""" + ... + + @abstractmethod + def resolve(self, region: "Region", graph: gs.Graph) -> set["ResolvedInsertionPoint"]: + """Resolve pattern-relative insertion point to actual tensor names.""" + ... + + @staticmethod + @abstractmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["InsertionPoint"]: + """Collect all valid insertion points of this type from a region.""" + ... + + +@dataclass(frozen=True) +class ResolvedInsertionPoint: + """Resolved Q/DQ insertion point with actual tensor name and optional node context. + + After resolving pattern-relative insertion points, this class represents the + actual location where Q/DQ pairs should be inserted in the graph. It contains the + tensor name and the node index (if applicable) and input index (if applicable). + + This class is immutable (frozen) to allow safe use in sets and as dict keys. 
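+
+    For example, ResolvedInsertionPoint(tensor_name="conv1_out", node_index=7,
+    input_index=0) pins a Q/DQ pair to input 0 of graph node 7, while
+    ResolvedInsertionPoint(tensor_name="conv1_out") requests a single
+    tensor-level insertion shared by every consumer of that tensor
+    (tensor and index values here are hypothetical).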
+ """ + + tensor_name: str + node_index: int | None = None # Absolute graph node index (or None for tensor-level insertion) + input_index: int | None = None # Input tensor index of that node (or None) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + +@dataclass(frozen=True) +class NodeInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a node's input (frozen/hashable). + + Specifies where to insert a Q/DQ pair within a region pattern using + pattern-relative indices rather than absolute node IDs. This enables + insertion scheme reuse across all regions matching the same pattern. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + node_index: int # Pattern-relative node index + input_index: int # Input tensor index of that node + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": + """Create from dictionary.""" + return cls(node_index=data["node_index"], input_index=data["input_index"]) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a node input insertion point to actual tensor names for a matching region.""" + node_indices = region.get_nodes(sort=True) + assert self.node_index < len(node_indices), "Node index out of range" + actual_node_idx = node_indices[self.node_index] + node = graph.nodes[actual_node_idx] + assert self.input_index < len(node.inputs), "Input index out of range" + + resolved_ips = set() + # Determine which input indices to resolve (include weights for Conv/ConvTranspose) + input_indices = [self.input_index] + if node.op in ["Conv", "ConvTranspose"]: + assert self.input_index == 0, ( + "Conv/ConvTranspose inputs and weights must be quantized together" + ) + assert len(node.inputs) >= 2, "Conv/ConvTranspose should have at least 2 inputs" + input_indices.append(1) + + for idx in input_indices: + inp = node.inputs[idx] + if hasattr(inp, "name") and inp.name: + resolved_ips.add( + ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=idx + ) + ) + return resolved_ips + + @staticmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputInsertionPoint"]: + """Collect all valid node input insertion points from a region.""" + node_indices = region.get_nodes(sort=True) + insertion_points = [] + for local_idx, node_idx in enumerate(node_indices): + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + name = getattr(inp, "name", None) + if not name or skip_invalid_insertion_points(graph, name, node): + continue + insertion_points.append( + NodeInputInsertionPoint(node_index=local_idx, input_index=input_idx) + ) + return insertion_points + + +@dataclass(frozen=True) +class ChildRegionInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region's input boundary (frozen/hashable). + + Specifies where to insert Q/DQ pairs at the input boundaries of child regions + within COMPOSITE regions. This allows parent regions to control quantization + at child boundaries, potentially overriding or complementing child region + optimizations. 
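+
+    For example, ChildRegionInputInsertionPoint(region_index=0, input_index=1)
+    targets the second input tensor of the first child region, with children
+    ordered as returned by Region.get_children(sort=True).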
+ + Only applies to COMPOSITE regions; LEAF regions have no children. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Pattern-relative child region index + region_index: int + # Input tensor index of that child region + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a child region input insertion point to actual tensor names.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type == RegionType.LEAF: + return set() + + children_regions = region.get_children(sort=True) + assert self.region_index < len(children_regions), "Child region index out of range" + child_region = children_regions[self.region_index] + assert self.input_index < len(child_region.inputs), "Input index out of range" + tensor_name = child_region.inputs[self.input_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionInputInsertionPoint"]: + """Collect all valid child region input insertion points from a region.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type == RegionType.LEAF: + return [] + + insertion_points = [] + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for input_idx, inp in enumerate(child_region.inputs): + if skip_invalid_insertion_points(graph, inp, child_region): + continue + insertion_points.append( + ChildRegionInputInsertionPoint(region_index=local_idx, input_index=input_idx) + ) + return insertion_points + + +@dataclass(frozen=True) +class ChildRegionOutputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region or node output (frozen/hashable). + + Specifies where to insert Q/DQ pairs at output boundaries. This can be either: + 1. Output from a child region (in COMPOSITE regions) + 2. Output from a node within the region + + This class is immutable (frozen) to allow safe use in sets and as dict keys. 
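+
+    Exactly one of region_index and node_index is expected to be set. For
+    example, ChildRegionOutputInsertionPoint(region_index=2, node_index=None,
+    output_index=0) targets the first output of child region 2, while
+    ChildRegionOutputInsertionPoint(region_index=None, node_index=5,
+    output_index=0) targets the first output of pattern-relative node 5.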
+ """ + + region_index: int | None # Pattern-relative child region index (or None) + node_index: int | None # Pattern-relative node index (or None) + output_index: int # Output tensor index + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionOutputInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a region output insertion point to actual tensor names.""" + if self.region_index is not None: + children_regions = region.get_children(sort=True) + assert self.region_index < len(children_regions), "Region index out of range" + child_region = children_regions[self.region_index] + assert self.output_index < len(child_region.outputs), "Output index out of range" + tensor_name = child_region.outputs[self.output_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + if self.node_index is not None: + node_indices = region.get_nodes(sort=True) + assert self.node_index < len(node_indices), "Node index out of range" + node = graph.nodes[node_indices[self.node_index]] + assert self.output_index < len(node.outputs), "Output index out of range" + tensor = node.outputs[self.output_index] + assert hasattr(tensor, "name") and tensor.name, "Tensor name is required" + return resolve_region_io_insertion_points(None, graph, tensor.name) + + return set() + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionOutputInsertionPoint"]: + """Collect all valid region output insertion points from a region.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + region_outputs_set = set(region.outputs) + insertion_points = [] + + # For COMPOSITE regions: collect child region outputs + if region.type != RegionType.LEAF: + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for output_idx, out in enumerate(child_region.outputs): + if out in region_outputs_set and not skip_invalid_insertion_points( + graph, out, child_region + ): + insertion_points.append( + ChildRegionOutputInsertionPoint( + region_index=local_idx, node_index=None, output_index=output_idx + ) + ) + + # For all regions: collect node outputs + for local_idx, node_idx in enumerate(region.get_nodes(sort=True)): + node = graph.nodes[node_idx] + for output_idx, out in enumerate(node.outputs): + if not (hasattr(out, "name") and out.name): + continue + if out.name in region_outputs_set and not skip_invalid_insertion_points( + graph, out.name, node + ): + insertion_points.append( + ChildRegionOutputInsertionPoint( + region_index=None, node_index=local_idx, output_index=output_idx + ) + ) + + return insertion_points + + +def skip_invalid_insertion_points( + graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" +) -> bool: + """Determine if a tensor should be skipped for Q/DQ insertion. 
+
+    Filters out tensors that are not suitable for quantization based on various criteria:
+    - Boolean and shape operations (not quantizable)
+    - Fused operation patterns (Conv->BatchNorm->ReLU)
+    - Operation-specific non-quantizable inputs (weights, biases, BN parameters)
+    - Non-floating-point tensors (indices, masks)
+    - Small tensors (scalars, small vectors with < 8 elements)
+
+    Args:
+        graph: The ONNX graph containing the nodes
+        tensor_name: Name of the tensor to evaluate
+        region_or_node: Either a Region or a Node to check for usage of this tensor
+
+    Returns:
+        True if the insertion point should be skipped, False if it's valid for quantization
+    """
+    from modelopt.onnx.quantization.autotune.common import Region
+
+    if isinstance(region_or_node, Region):
+        node_indices = region_or_node.get_region_nodes_and_descendants()
+        nodes: list[gs.Node] = [graph.nodes[node_idx] for node_idx in node_indices]
+    else:
+        assert isinstance(region_or_node, gs.Node)
+        nodes = [region_or_node]
+
+    for node in nodes:
+        for input_idx, inp in enumerate(node.inputs):
+            if hasattr(inp, "name") and inp.name == tensor_name:
+                # Skip weights of Conv and ConvTranspose; they are quantized together with their inputs
+                if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1:
+                    return True
+                # Conv -> ReLU/Softmax or Conv -> BatchNormalization -> ReLU/Softmax
+                if node.op in ["Relu", "Softmax"]:
+                    if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1:
+                        producer = node.inputs[0].inputs[0]
+                        if producer.op in ["Conv", "ConvTranspose"]:
+                            return True
+                        if (
+                            producer.op == "BatchNormalization"
+                            and len(producer.inputs[0].inputs) == 1
+                            and producer.inputs[0].inputs[0].op in ["Conv", "ConvTranspose"]
+                        ):
+                            return True
+                # Conv -> BatchNormalization
+                if node.op == "BatchNormalization":
+                    assert len(node.inputs) >= 1, "BatchNormalization node should have at least one input"
+                    if len(node.inputs[0].inputs) == 1:
+                        producer = node.inputs[0].inputs[0]
+                        if producer.op in ["Conv", "ConvTranspose"]:
+                            return True
+                # Filter 1: skip boolean/bitwise/comparison/conditional/aggregation/set and reduction ops
+                if node.op in (
+                    get_bool_ops()
+                    | get_bitwise_ops()
+                    | get_value_check_ops()
+                    | get_comparison_ops()
+                    | get_conditional_ops()
+                    | get_aggregation_ops()
+                    | get_set_ops()
+                ) or is_fusible_reduction_op(node.op):
+                    return True
+                # Filter 2: skip shape/structural operations
+                if node.op in get_autotuner_skip_ops():
+                    return True
+                # Filter 3: skip operation-specific non-quantizable inputs
+                if node.op in ["BatchNormalization", "Resize"] and input_idx >= 1:
+                    return True
+                if node.op in ["Conv", "Gemm"] and input_idx >= 2:
+                    return True
+                # Filter 4: skip non-floating-point tensors (int/bool indices, masks, etc.)
+                if hasattr(inp, "dtype") and inp.dtype not in [
+                    None,
+                    np.float32,
+                    np.float16,
+                    np.float64,
+                ]:
+                    return True
+                # Filter 5: skip small tensors (scalars, small vectors)
+                if hasattr(inp, "shape") and inp.shape is not None:
+                    if all(isinstance(s, int) for s in inp.shape):
+                        if np.prod(inp.shape) < 8:
+                            return True
+    return False
+
+
+def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool:
+    """Check if a region contains major quantizable operations (only checks LEAF regions).
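+
+    For example, a LEAF region containing only Transpose and Reshape nodes is
+    reported as non-quantizable, while one containing a Conv is kept. COMPOSITE
+    and ROOT regions are conservatively treated as quantizable.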
+ + Args: + region: The region to check + graph: The ONNX graph containing the nodes + + Returns: + True if the region contains major quantizable operations, False otherwise + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type != RegionType.LEAF: + return True + region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} + return bool(region_ops & get_autotuner_quantizable_ops()) + + +def resolve_region_io_insertion_points( + region: "Region | None", graph: gs.Graph, tensor_name: str +) -> set[ResolvedInsertionPoint]: + """Resolve region input/output boundaries to actual Q/DQ insertion points. + + For a given tensor at a region boundary (input or output), this function + identifies all the actual node inputs where Q/DQ pairs should be inserted. + It considers both nodes within the region (if provided) and all users of + the tensor in the graph. + + Args: + region: The region to search within (or None to search entire graph) + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor at the region boundary + + Returns: + Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs + """ + tensor_users_map = getattr(graph, "tensor_users_map", None) or get_tensor_consumer_node_indices( + graph + ) + + node_indices: set[int] = set() + if region is not None: + node_indices.update(region.get_region_nodes_and_descendants()) + node_indices.update(tensor_users_map.get(tensor_name, [])) + + resolved = set() + for node_idx in node_indices: + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + if not skip_invalid_insertion_points(graph, tensor_name, node): + resolved.add( + ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + ) + return resolved + + +def merge_resolved_insertion_points( + graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] +) -> set[ResolvedInsertionPoint]: + """Optimize insertion points by merging node-specific insertions into tensor-level insertions. + + When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient + to insert Q/DQ once at the tensor level rather than at each individual node input. + This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. 
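+
+    For example, if tensor "t" feeds nodes 3 and 5 and both (t, node 3) and
+    (t, node 5) carry insertion points, the pair collapses into one tensor-level
+    point (node_index=None); if only node 3 is covered, the per-node point is
+    kept so node 5 continues to consume the unquantized tensor.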
+ + Args: + graph: The ONNX graph containing the nodes + resolved_insertion_points: Set of resolved insertion points to optimize + + Returns: + Optimized set of insertion points with merged tensor-level insertions where possible + """ + tensor_users_map = get_tensor_consumer_node_indices(graph) + node_ips = {ip for ip in resolved_insertion_points if ip.node_index is not None} + + results = resolved_insertion_points - node_ips + for tensor_name in {ip.tensor_name for ip in node_ips}: + all_users = set(tensor_users_map.get(tensor_name, [])) + qdq_users = {ip for ip in node_ips if ip.tensor_name == tensor_name} + if all_users == {ip.node_index for ip in qdq_users}: + results.add( + ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) + ) + else: + results.update(qdq_users) + return results + + +def get_autotuner_skip_ops(): + """Returns set of shape/structural operations that are not quantizable.""" + return set(get_copy_ops()) | { + # Additional indexing/scatter/reshape ops + "Compress", + "Scatter", + "ExpandDims", + "Unsqueeze", + "View", + "Pad", + # Utility ops + "Cast", + "Ceil", + "Clip", + "Identity", + "Range", + "Shape", + } + + +def get_autotuner_quantizable_ops(): + """Returns set of key operations that benefit from quantization.""" + return { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + "Add", + "Sum", + "Mul", + "Relu", + } diff --git a/modelopt/onnx/quantization/autotune/region_inspect.py b/modelopt/onnx/quantization/autotune/region_inspect.py new file mode 100644 index 000000000..8c0950fe9 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_inspect.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Region search inspection tool for ONNX models.""" + +import argparse +import logging +import sys +from collections import Counter + +import onnx +import onnx_graphsurgeon as gs + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.insertion_points import has_quantizable_operations +from modelopt.onnx.quantization.autotune.region_search import ( + DEFAULT_MAX_STEPS, + CombinedRegionSearch, +) + + +def inspect_region_search( + onnx_path: str, + max_sequence_size: int = 10, + include_all_regions: bool = False, +) -> list[Region]: + """Inspect region search results for an ONNX model. + + This function loads an ONNX model, runs CombinedRegionSearch (which performs + both bottom-up partitioning and top-down refinement internally), and prints + detailed information about the discovered regions including their hierarchical + structure. + + **What it does:** + 1. Loads ONNX model and converts to GraphSurgeon format + 2. Creates CombinedRegionSearch instance with specified parameters + 3. 
Runs two-phase search (partitioning + refinement) via search() + 4. Displays detailed region structure and statistics + 5. Returns the final list of refined regions + + **Output Sections:** + - Initialization: Shows search parameters + - Two-Phase Search: Runs automatically via CombinedRegionSearch.search() + - Detailed Structure: Shows each region's hierarchy and properties + - Summary Statistics: Shows region counts and node coverage + + Args: + onnx_path: Path to the ONNX model file + max_sequence_size: Maximum size for sequence regions during refinement (default: 10) + include_all_regions: Include all regions, even those without major quantizable + operations (Conv, MatMul, etc.). Default: False (skips such regions) + + Returns: + List of discovered and refined regions (LEAF and COMPOSITE) + """ + # Load ONNX model + logger.info(f"Loading model: {onnx_path}") + onnx_model = onnx.load(onnx_path) + # Convert to onnx_graphsurgeon Graph + graph = gs.import_onnx(onnx_model) + graph.cleanup().toposort() + logger.info( + f"Loaded graph: {len(graph.nodes)} nodes, {len(graph.inputs)} inputs, {len(graph.outputs)} outputs" + ) + # Initialize CombinedRegionSearch (contains RegionPartitioner internally) + logger.debug( + f"Search parameters: max_steps={DEFAULT_MAX_STEPS}, max_sequence_size={max_sequence_size}" + ) + + combined_search = CombinedRegionSearch(graph, maximum_sequence_region_size=max_sequence_size) + + # Run complete two-phase region search + logger.info("Running region search") + regions = combined_search.search_regions() + # Show detailed region structure + logger.info("Analyzing region structure") + all_regions = [] + for i, region in enumerate(regions): + region.children = [ + c + for c in region.get_children() + if include_all_regions or has_quantizable_operations(c, graph) + ] + if not include_all_regions and not has_quantizable_operations(region, graph): + logger.debug(f"Filtered out region {i} (no quantizable operations)") + continue + logger.debug( + f"Region {i}: {region.type.value}, {len(region.get_region_nodes_and_descendants())} nodes, " + f"{len(region.inputs)} inputs, {len(region.outputs)} outputs" + ) + all_regions.append(region) + if region.type == RegionType.COMPOSITE: + logger.debug(f" {len(region.get_children())} child regions") + all_regions.extend(region.get_children()) + combined_search.print_tree(region, indent=2) + + # Summary statistics + type_counts = Counter(r.type for r in all_regions) + leaf_regions, composite_regions = ( + type_counts[RegionType.LEAF], + type_counts[RegionType.COMPOSITE], + ) + + all_nodes = {n for r in all_regions for n in r.get_region_nodes_and_descendants()} + total_nodes = len(all_nodes) + coverage_pct = 100 * total_nodes / len(graph.nodes) if graph.nodes else 0 + + logger.info( + f"Summary: {len(all_regions)} regions ({leaf_regions} LEAF, {composite_regions} COMPOSITE), " + f"{total_nodes}/{len(graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + # Print histogram of region sizes + region_sizes = [ + len(r.get_region_nodes_and_descendants()) for r in all_regions if r.type == RegionType.LEAF + ] + + if region_sizes: + min_size = min(region_sizes) + max_size = max(region_sizes) + avg_size = sum(region_sizes) / len(region_sizes) + + logger.info(f"LEAF region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + size_counts = Counter(region_sizes) + logger.debug("Size distribution:") + for size in sorted(size_counts.keys()): + count = size_counts[size] + bar = "█" * min(count, 50) + logger.debug(f" {size:4d} nodes: {bar} ({count} 
regions)") + + return all_regions + + +def main(): + """Command-line entry point for region search inspection.""" + parser = argparse.ArgumentParser( + prog="modelopt.onnx.quantization.autotune.region_inspect", + description="Inspect region search results for ONNX models", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic inspection + python -m modelopt.onnx.quantization.autotune.region_inspect --model model.onnx + + # Verbose mode for debug logging + python -m modelopt.onnx.quantization.autotune.region_inspect \\ + --model model.onnx --verbose + + # Custom maximum sequence size + python -m modelopt.onnx.quantization.autotune.region_inspect \\ + --model model.onnx --max-sequence-size 20 + """, + ) + + parser.add_argument("--model", "-m", type=str, required=True, help="Path to ONNX model file") + parser.add_argument( + "--max-sequence-size", + type=int, + default=10, + help="Maximum size for sequence regions during refinement (default: 10)", + ) + parser.add_argument( + "--include-all-regions", + action="store_true", + help="Include all regions, even those without major quantizable operations. " + "Default: False (skips such regions)", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug logging") + + args = parser.parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") + logger.setLevel(log_level) + + try: + regions = inspect_region_search( + onnx_path=args.model, + max_sequence_size=args.max_sequence_size, + include_all_regions=args.include_all_regions, + ) + logger.info(f"✓ Inspection complete: {len(regions)} top-level regions discovered") + return 0 + except Exception as e: + logger.error(f"Inspection failed: {e}", exc_info=args.verbose) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/modelopt/onnx/quantization/autotune/region_pattern.py b/modelopt/onnx/quantization/autotune/region_pattern.py new file mode 100644 index 000000000..60d3225b6 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_pattern.py @@ -0,0 +1,300 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Region pattern signature generator for grouping structurally similar regions.""" + +import hashlib +from typing import Union, overload + +import onnx_graphsurgeon as gs + +from modelopt.onnx.op_types import get_symmetric_ops +from modelopt.onnx.quantization.autotune.common import InsertionScheme, Region +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + NodeInputInsertionPoint, + ResolvedInsertionPoint, +) + + +class RegionPattern: + """Represents a structural pattern of a region.""" + + def __init__(self, signature: str, size: int): + """Initialize a region pattern.""" + self.signature = signature + self.size = size + + @property + def is_empty(self) -> bool: + """Check if the pattern represents an empty region.""" + return self.size == 0 + + @property + def is_composite(self) -> bool: + """Check if the pattern represents a composite region.""" + return self.signature.startswith("COMPOSITE(") + + @property + def is_leaf(self) -> bool: + """Check if the pattern represents a leaf region (no composite structure).""" + return not self.is_composite and not self.is_empty + + def __str__(self) -> str: + """String representation of the pattern.""" + return self.signature + + def __repr__(self) -> str: + """Developer-friendly representation with signature and size.""" + return f"RegionPattern('{self.signature}', size={self.size})" + + def __eq__(self, other) -> bool: + """Check equality based on signature only.""" + if not isinstance(other, RegionPattern): + return False + return self.signature == other.signature + + def __hash__(self) -> int: + """Hash based on signature for use as dict key.""" + return hash(self.signature) + + def get_hash(self) -> str: + """Get a 128-bit cryptographic hash of the pattern signature.""" + return hashlib.sha256(self.signature.encode("utf-8")).hexdigest()[:32] + + def get_short_signature(self, max_length: int = 80) -> str: + """Get a truncated version of the signature for display purposes.""" + if len(self.signature) <= max_length: + return self.signature + return self.signature[: max_length - 3] + "..." + + @classmethod + def from_region(cls, region: Region, graph: gs.Graph) -> "RegionPattern": + """Compute a structural pattern for a region.""" + signature_str = cls._compute_signature_recursive(region, graph) + total_size = len(region.get_region_nodes_and_descendants()) + return cls(signature_str, total_size) + + @overload + def matches(self, other: "RegionPattern") -> bool: ... + @overload + def matches(self, other: Region, graph: gs.Graph, scheme: None = None) -> list[int] | None: ... + @overload + def matches( + self, other: Region, graph: gs.Graph, scheme: InsertionScheme + ) -> set[ResolvedInsertionPoint]: ... + + def matches( + self, + other: Union["RegionPattern", Region], + graph: gs.Graph | None = None, + scheme: InsertionScheme | None = None, + ) -> bool | list[int] | set[ResolvedInsertionPoint] | None: + """Check if this pattern matches another pattern or region. 
+ + Args: + other: Either a RegionPattern or Region to compare with + graph: Required when other is a Region (for computing its pattern) + scheme: Optional InsertionScheme containing node_inputs, + child_region_inputs, and region_outputs + to resolve to tensor names + + Returns: + - True if other is RegionPattern and patterns match + - List of node IDs in pattern order if other is Region and scheme is None, None if no match + - Set of resolved insertion points for Q/DQ insertion if other is Region and scheme is provided + + Raises: + ValueError: If other is Region but graph is not provided, or if scheme + is provided but other is not a Region + TypeError: If other is neither RegionPattern nor Region + """ + if isinstance(other, RegionPattern): + if scheme is not None: + raise ValueError("scheme parameter can only be used when matching against a Region") + return self._matches_pattern(other) + elif isinstance(other, Region) and scheme is None: + return self._matches_region(other, graph) + elif isinstance(other, Region) and scheme is not None: + if graph is None: + raise ValueError("graph parameter is required") + + region_pattern = RegionPattern.from_region(other, graph) + if self != region_pattern: + return set() + + resolved_ips = set() + for ip in scheme.node_inputs: + resolved_ips.update(ip.resolve(other, graph)) + for ip in scheme.child_region_inputs: + resolved_ips.update(ip.resolve(other, graph)) + for ip in scheme.region_outputs: + resolved_ips.update(ip.resolve(other, graph)) + return resolved_ips + else: + raise TypeError(f"Expected RegionPattern or Region, got {type(other).__name__}") + + def _matches_pattern(self, other: "RegionPattern") -> bool: + """Internal function: Match this pattern against another pattern.""" + return self == other + + def _matches_region(self, region: Region, graph: gs.Graph | None) -> list[int] | None: + """Internal function: Match this pattern against a region.""" + if graph is None: + raise ValueError("graph parameter is required when matching against a Region") + + region_pattern = RegionPattern.from_region(region, graph) + + if self == region_pattern: + return self._collect_nodes_in_match_order(region) + else: + return None + + def get_full_insertion_scheme(self, region: Region, graph: gs.Graph) -> InsertionScheme: + """Get all possible insertion points for a region in a single InsertionScheme.""" + region_pattern = RegionPattern.from_region(region, graph) + assert self == region_pattern, "Region pattern mismatch" + + scheme = InsertionScheme() + scheme.node_inputs = NodeInputInsertionPoint.collect_from_region(region, graph) + scheme.child_region_inputs = ChildRegionInputInsertionPoint.collect_from_region( + region, graph + ) + scheme.region_outputs = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) + + return scheme + + def format_tree(self, region: Region, graph: gs.Graph, indent: int = 0) -> str: + """Format this pattern and region as a human-readable tree.""" + prefix = " " * indent + result = f"{prefix}Region {region.id}: {self.signature} (size={self.size})\n" + + for child in region.get_children(): + child_pattern = RegionPattern.from_region(child, graph) + result += child_pattern.format_tree(child, graph, indent + 1) + + return result + + @staticmethod + def _collect_nodes_in_match_order(region: Region) -> list[int]: + """Collect node IDs in the same order as signature computation.""" + node_ids = [] + + node_ids.extend(region.get_nodes(sort=True)) + sorted_children = region.get_children(sort=True) + + for child in 
sorted_children:
+            node_ids.extend(RegionPattern._collect_nodes_in_match_order(child))
+
+        return node_ids
+
+    @staticmethod
+    def _compute_signature_recursive(region: Region, graph: gs.Graph) -> str:
+        """Recursively compute structural signature for a region."""
+        nodes_list = list(graph.nodes)
+        node_indices_set = set(region.get_nodes())
+
+        node_ops = [
+            RegionPattern._make_node_with_params_signature(nodes_list[idx], graph, node_indices_set)
+            for idx in sorted(node_indices_set)
+            if idx < len(nodes_list)
+        ]
+
+        sorted_children = region.get_children(sort=True)
+
+        if not sorted_children and not node_ops:
+            return "EMPTY"
+
+        if not sorted_children:
+            return "->".join(node_ops)  # e.g. "Conv[kernel_shape=1x1]->Relu"
+
+        child_sigs = "+".join(
+            [RegionPattern._compute_signature_recursive(child, graph) for child in sorted_children]
+        )
+
+        if node_ops:
+            node_sig = "->".join(node_ops)
+            return f"COMPOSITE({node_sig}|{child_sigs})"
+        # child_sigs is already "+"-joined above; joining it again would
+        # interleave "+" between individual characters.
+        return f"COMPOSITE({child_sigs})"  # e.g. "COMPOSITE(Conv->Relu+Add)"
+
+    @staticmethod
+    def _get_symmetric_input_signature(
+        node: gs.Node, graph: gs.Graph, region_node_indices: set
+    ) -> str | None:
+        """Compute normalized input source signature for symmetric operations."""
+        if node.op not in get_symmetric_ops() or len(node.inputs) <= 1:
+            return None
+
+        nodes_list = list(graph.nodes)
+        node_to_idx = {id(n): idx for idx, n in enumerate(nodes_list)}
+
+        input_sources = []
+        for inp in node.inputs:
+            if inp is None or not hasattr(inp, "inputs") or not inp.inputs:
+                input_sources.append(("external", "input-or-constant"))
+            else:
+                producer_node = inp.inputs[0] if inp.inputs else None
+                if producer_node and id(producer_node) in node_to_idx:
+                    producer_idx = node_to_idx[id(producer_node)]
+                    location = "internal" if producer_idx in region_node_indices else "external"
+                    input_sources.append((location, producer_node.op))
+                else:
+                    input_sources.append(("external", "unknown"))
+
+        sorted_sources = sorted(input_sources)
+        return ",".join(f"{loc}:{op}" for loc, op in sorted_sources)
+
+    @staticmethod
+    def _format_attr_value(value: object) -> str:
+        """Format an attribute value for inclusion in a signature."""
+        if isinstance(value, (list, tuple)):
+            if len(value) > 0 and all(isinstance(v, (int, float)) for v in value):
+                if all(isinstance(v, int) for v in value):
+                    return "x".join(str(v) for v in value)
+                return "x".join(f"{v:.4g}" if isinstance(v, float) else str(v) for v in value)
+            return ",".join(str(v) for v in value)
+        if isinstance(value, float):
+            return f"{value:.4g}"
+        if isinstance(value, bool):
+            return "1" if value else "0"
+        if isinstance(value, bytes):
+            hex_str = value.hex()
+            return hex_str if len(hex_str) <= 16 else f"{hex_str[:16]}..."
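+        # Fall through for ints, strings, and any other scalar types.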
+ return str(value) + + @staticmethod + def _make_node_with_params_signature( + node: gs.Node, graph: gs.Graph, region_node_indices: set + ) -> str: + """Create signature for a single node including its parameters.""" + op = node.op + sym_sig = RegionPattern._get_symmetric_input_signature(node, graph, region_node_indices) + + attr_sig = "" + if node.attrs: + attr_parts = [ + f"{key}={RegionPattern._format_attr_value(node.attrs[key])}" + for key in sorted(node.attrs.keys()) + ] + attr_sig = f"[{','.join(attr_parts)}]" + + if attr_sig and sym_sig: + return f"{op}{attr_sig}<{sym_sig}>" + if sym_sig: + return f"{op}<{sym_sig}>" + if attr_sig: + return f"{op}{attr_sig}" + return op diff --git a/modelopt/onnx/quantization/autotune/region_search.py b/modelopt/onnx/quantization/autotune/region_search.py new file mode 100644 index 000000000..37a9b175d --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_search.py @@ -0,0 +1,821 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hierarchical region discovery and partitioning for ONNX graphs.""" + +import sys +from collections import deque + +import onnx_graphsurgeon as gs + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +DEFAULT_MAX_STEPS = 10 +DEFAULT_MAX_NODES_TO_SHOW = 20 +MAX_PROBE_STEPS_AFTER_CONVERGE = 3 + + +class RegionSearchBase: + """Base class for region search algorithms providing common graph analysis utilities.""" + + def __init__( + self, + graph: gs.Graph, + root: Region | None = None, + max_steps: int = DEFAULT_MAX_STEPS, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the base region search with graph analysis.""" + self.graph = graph + if root is None: + root = self._build_root_region() + self.root = root + if tensor_users_map is None: + tensor_users_map = get_tensor_consumer_node_indices(self.graph) + self.tensor_users_map = tensor_users_map + if forward_reachable_nodes_map is None: + forward_reachable_nodes_map = self._build_forward_reachable_nodes_map( + max_steps=max_steps + ) + self.forward_reachable_nodes_map = forward_reachable_nodes_map + + def _build_root_region(self) -> Region: + """Create a root region containing all nodes in the graph.""" + root = Region(region_id=0, level=0, region_type=RegionType.ROOT) + root.nodes.update(range(len(self.graph.nodes))) + self.compute_region_boundaries(root) + return root + + def _is_tensor_divergent(self, tensor_name: str) -> bool: + """Check if a tensor is consumed by multiple nodes (divergent).""" + return len(self.tensor_users_map.get(tensor_name, [])) > 1 + + def _is_node_divergent(self, 
node_idx: int) -> bool: + """Check if a node has outputs that branch to multiple consumers.""" + if node_idx not in self.root.get_nodes(): + logger.debug(f"Node {node_idx} not in root region") + return False + + node = self.graph.nodes[node_idx] + divergent_outputs = [ + out.name for out in node.outputs if self._is_tensor_divergent(out.name) + ] + is_divergent = len(divergent_outputs) > 0 + + if is_divergent: + logger.debug( + f"Divergent node {node_idx} ({node.op}): {len(divergent_outputs)} branches" + ) + + return is_divergent + + def _compute_forward_reachable_nodes( + self, start_node_idx: int, max_steps: int + ) -> dict[int, int]: + """Compute all nodes reachable forward from a starting node with distances.""" + reachable: dict[int, int] = {start_node_idx: 0} + queue: deque[tuple[int, int]] = deque([(start_node_idx, 0)]) + while queue: + current_node_idx, distance = queue.popleft() + if distance >= max_steps: + continue + current_node = self.graph.nodes[current_node_idx] + for output in current_node.outputs: + if output.name not in self.tensor_users_map: + continue + for next_node_idx in self.tensor_users_map[output.name]: + if next_node_idx not in reachable: + reachable[next_node_idx] = distance + 1 + queue.append((next_node_idx, distance + 1)) + return reachable + + def _build_forward_reachable_nodes_map(self, max_steps: int) -> dict[int, dict[int, int]]: + """Pre-compute forward reachability for all nodes in the graph.""" + logger.debug(f"Building forward reachability map (max_steps={max_steps})...") + forward_reachable_nodes_map: dict[int, dict[int, int]] = {} + for node_idx in self.root.get_nodes(): + forward_reachable_nodes_map[node_idx] = self._compute_forward_reachable_nodes( + node_idx, max_steps + ) + + total_reachable = sum(len(reachable) for reachable in forward_reachable_nodes_map.values()) + avg_reachable = total_reachable / len(self.root.get_nodes()) if self.root.get_nodes() else 0 + logger.debug(f"Reachability map complete: avg {avg_reachable:.1f} reachable nodes per node") + return forward_reachable_nodes_map + + def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: + """Find convergence point and intermediate nodes for a divergent node.""" + node = self.graph.nodes[node_idx] + logger.debug(f"Finding convergence for node {node_idx} ({node.op})") + + branches: list[int] = [] + for output in node.outputs: + if output.name in self.tensor_users_map: + branches.extend(self.tensor_users_map[output.name]) + + branches = list(dict.fromkeys(branches)) + + logger.debug(f" {len(branches)} unique branches found") + + # Need at least 2 branches for convergence to be meaningful + if len(branches) <= 1: + logger.debug(" Insufficient branches for convergence") + return None, set() + + # STEP 1: Find Common Reachable Nodes (Potential Convergence Points) + branch_reachable = [self.forward_reachable_nodes_map.get(b, {}) for b in branches] + + if not branch_reachable: + logger.debug(" No reachable nodes from branches") + return None, set() + + common_nodes = set.intersection(*[set(r.keys()) for r in branch_reachable]) + logger.debug(f" {len(common_nodes)} common nodes found") + # Remove the divergent node itself (not a convergence point) + common_nodes.discard(node_idx) + + if not common_nodes: + logger.debug(" No valid convergence candidates") + return None, set() + + # STEP 2: Select Best Convergence Node with Region Validity Check + converge_node_idx: int | None = None + min_max_distance = float("inf") + + reachable_from_start = 
self.forward_reachable_nodes_map.get(node_idx, {}) + + # Evaluate each candidate convergence point + for candidate_idx in common_nodes: + # Define the potential region: nodes between start and candidate + region_nodes: set[int] = set() + region_nodes.update(set(reachable_from_start.keys())) + reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) + region_nodes.difference_update(set(reachable_from_candidate.keys())) + + broken_region = False + for rnode_index in region_nodes: + reachable_from_rnode = self.forward_reachable_nodes_map.get(rnode_index, {}) + rnode_to_candidate_distance = reachable_from_rnode.get(candidate_idx, float("inf")) + for test_node_idx in reachable_from_rnode: + # Skip nodes that are inside the region (they're fine) + if test_node_idx in region_nodes: + continue + # test_node is OUTSIDE the region. Check if it's "escaping" + # An escaping edge: region_node reaches test_node BEFORE candidate + rnode_to_test_distance = reachable_from_rnode.get(test_node_idx, float("inf")) + # If either distance is infinite, region is broken + # (indicates disconnected components or unreachable convergence) + if rnode_to_test_distance == float( + "inf" + ) or rnode_to_candidate_distance == float("inf"): + broken_region = True + break + # If test_node is closer than candidate, we have an escape! + # This means computation flows OUT of region before converging + if rnode_to_test_distance < rnode_to_candidate_distance: + broken_region = True + break + if broken_region: + break + # Skip this candidate if region is invalid + if broken_region: + continue + # Valid candidate! Check if it's the nearest one + max_distance = max(reachable[candidate_idx] for reachable in branch_reachable) + + if max_distance < min_max_distance: + min_max_distance = max_distance + converge_node_idx = candidate_idx + # If no valid convergence found, this divergence has no convergence + if converge_node_idx is None: + logger.debug(" No valid convergence found") + return None, set() + + converge_node = self.graph.nodes[converge_node_idx] + logger.debug( + f" Convergence at node {converge_node_idx} ({converge_node.op}), distance {min_max_distance}" + ) + # STEP 3: Compute All Nodes Between Divergence and Convergence + visited_nodes: set[int] = set() + for candidate_idx in reachable_from_start: + if candidate_idx == converge_node_idx: + continue + reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) + if converge_node_idx in reachable_from_candidate: + visited_nodes.add(candidate_idx) + logger.debug(f" {len(visited_nodes)} nodes between divergence and convergence") + return converge_node_idx, visited_nodes + + def _max_distance_to_nodes(self, src_idx: int, dst_indices: set[int]) -> int: + """Compute maximum distance from a source node to a set of destination nodes.""" + max_distance = 0 + for dst_idx in dst_indices: + reachable = self.forward_reachable_nodes_map.get(src_idx, {}) + if dst_idx in reachable: + max_distance = max(max_distance, reachable[dst_idx]) + + logger.debug( + f"Max distance from node {src_idx}: {max_distance} steps to {len(dst_indices)} nodes" + ) + return max_distance + + def compute_region_boundaries(self, region: Region, include_constant: bool = False) -> None: + """Compute input and output tensor boundaries for a region. 
+ + Args: + region: The region to compute boundaries for + include_constant: Whether to include constant tensors in inputs + """ + node_indices = region.get_region_nodes_and_descendants() + + consumed_tensors: set[str] = set() + produced_tensors: set[str] = set() + region_outputs: set[str] = set() + + for node_idx in node_indices: + if node_idx >= len(self.graph.nodes): + continue + node = self.graph.nodes[node_idx] + + # Collect consumed tensors (potential inputs) + for input_tensor in node.inputs: + if isinstance(input_tensor, gs.Constant) and not include_constant: + continue + consumed_tensors.add(input_tensor.name) + + # Collect produced tensors and determine outputs + for output_tensor in node.outputs: + tensor_name = output_tensor.name + produced_tensors.add(tensor_name) + + consumer_indices = self.tensor_users_map.get(tensor_name, []) + has_external_consumer = any(idx not in node_indices for idx in consumer_indices) + is_graph_output = output_tensor in self.graph.outputs + + if has_external_consumer or is_graph_output or not consumer_indices: + region_outputs.add(tensor_name) + + # Region inputs = consumed tensors not produced internally + region.inputs = sorted(consumed_tensors - produced_tensors) + region.outputs = sorted(region_outputs) + + logger.debug( + f"Computed boundaries: {len(region.inputs)} inputs, {len(region.outputs)} outputs" + ) + + def print_tree( + self, + region: Region | None = None, + indent: int = 0, + max_items: int = DEFAULT_MAX_NODES_TO_SHOW, + file=None, + ) -> None: + """Print hierarchical region tree in human-readable text format.""" + region = region or self.root + + file = file or sys.stdout + p = " " * indent + + def print_items(items, label, formatter=str): + """Print a truncated list of items.""" + items = list(items) + print(f"{p}│ ├─ {label}: {len(items)}", file=file) + for item in items[:max_items]: + print(f"{p}│ │ - {formatter(item)}", file=file) + if len(items) > max_items: + print(f"{p}│ │ ... and {len(items) - max_items} more", file=file) + + # Header + print( + f"{p}├─ Region {region.id} (Level {region.level}, Type: {region.type.value})", + file=file, + ) + + # Counts + direct_nodes = region.get_nodes() + children = region.get_children() + print(f"{p}│ ├─ Direct nodes: {len(direct_nodes)}", file=file) + print(f"{p}│ ├─ Total nodes: {len(region.get_region_nodes_and_descendants())}", file=file) + print(f"{p}│ ├─ Children: {len(children)}", file=file) + + # I/O + print_items(region.inputs, "Inputs") + print_items(region.outputs, "Outputs") + + # Direct nodes + if direct_nodes: + print(f"{p}│\n{p}│ Nodes in this region:", file=file) + for node_idx in sorted(direct_nodes)[:max_items]: + if node_idx < len(self.graph.nodes): + node = self.graph.nodes[node_idx] + print(f"{p}│ - Node {node_idx}: {node.op} ({node.name})", file=file) + if len(direct_nodes) > max_items: + print(f"{p}│ ... and {len(direct_nodes) - max_items} more", file=file) + + # Children + if children: + print(f"{p}│\n{p}│ Child regions:", file=file) + for child in children: + print(f"{p}│", file=file) + self.print_tree(child, indent + 1, max_items, file) + + +class RegionPartitioner(RegionSearchBase): + """Bottom-up graph partitioner that creates initial regions based on divergence patterns. + + This class implements Phase 1 of the combined region search strategy. It performs + a systematic traversal of the computation graph from inputs to outputs, identifying + natural boundaries for region formation based on computation flow patterns. 
+ """ + + def __init__( + self, + graph: gs.Graph, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the partitioner with a computation graph.""" + super().__init__( + graph, + root=None, + tensor_users_map=tensor_users_map, + forward_reachable_nodes_map=forward_reachable_nodes_map, + ) + self.regions: list[Region] = [] + self.current_region: Region | None = None + self.current_region_id: int = 0 + self.visited_nodes: set[int] = set() + + def _append_node_to_region(self, node_idx: int): + """Add a node to the current region, creating a new region if needed.""" + node = self.graph.nodes[node_idx] + if self.current_region is None: + self.current_region = Region( + region_id=self.current_region_id, level=0, region_type=RegionType.LEAF + ) + logger.debug(f"Started region {self.current_region_id}") + self.current_region_id += 1 + + self.current_region.nodes.add(node_idx) + logger.debug( + f" Added node {node_idx} ({node.op}), region size: {len(self.current_region.nodes)}" + ) + + def _commit_region(self): + """Finalize and store the current region being built.""" + if self.current_region is not None: + region_size = len(self.current_region.nodes) + region_id = self.current_region.id + + self.compute_region_boundaries(self.current_region) + + self.regions.append(self.current_region) + logger.debug( + f"Committed region {region_id}: {region_size} nodes (total: {len(self.regions)})" + ) + self.current_region = None + else: + logger.debug("No region to commit") + + def _build_sequence_from_node(self, node_idx: int, max_nodes: int = -1): + """Build a region from a linear sequence of nodes.""" + logger.debug(f"Building sequence from node {node_idx} ({self.graph.nodes[node_idx].op})") + + queue: deque[int] = deque([node_idx]) + nodes_added = 0 + + while queue: + current_idx = queue.popleft() + node = self.graph.nodes[current_idx] + + self._append_node_to_region(current_idx) + self.visited_nodes.add(current_idx) + nodes_added += 1 + + if self._is_node_divergent(current_idx): + logger.debug(f" Stopped at divergent node {current_idx} ({node.op})") + else: + # Queue successors for non-divergent nodes + for output in node.outputs: + if output.name in self.tensor_users_map: + queue.extend(self.tensor_users_map[output.name]) + + if 0 < max_nodes <= nodes_added: + logger.debug(" Max nodes reached") + break + + logger.debug(f"Sequence complete: {nodes_added} nodes") + + def _build_small_converged_region( + self, start_node_idx: int, converge_node_idx: int, visited_nodes: set[int] + ): + r"""Create a region encompassing divergence, branches, and convergence. + + Builds a single region containing: + - The divergent node (where branches split) + - All nodes in the branches + - The convergence node (where branches rejoin) + + This creates a "diamond" or "funnel" shaped region that captures + parallel computation paths and their merge point. 
+ + **Structure:** + ``` + start (divergent) + / \ + path1 path2 (visited_nodes) + \\ / + convergence + ``` + """ + visited_nodes.remove(start_node_idx) + for node_idx in sorted(visited_nodes): + self._append_node_to_region(node_idx) + self.visited_nodes.add(node_idx) + if not self._is_node_divergent(converge_node_idx): + self._append_node_to_region(converge_node_idx) + self.visited_nodes.add(converge_node_idx) + self._build_sequence_from_node(converge_node_idx, max_nodes=MAX_PROBE_STEPS_AFTER_CONVERGE) + + def _build_region_from_node(self, node_idx: int): + """Process a single node and create appropriate region(s) based on its pattern.""" + node = self.graph.nodes[node_idx] + + # Skip nodes already assigned to regions + if node_idx in self.visited_nodes: + logger.debug(f"Skipping node {node_idx} ({node.op}): already visited") + return + + logger.debug(f"Processing node {node_idx} ({node.op})") + + # Pattern 1 & 2: Handle divergent nodes + if self._is_node_divergent(node_idx): + logger.debug(" Divergent node, searching for convergence") + # Attempt to find where branches rejoin + converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) + # Check if convergence creates a reasonable-sized region + max_distance = self._max_distance_to_nodes(node_idx, visited_nodes) + # Pattern 1: Convergence found and region size is acceptable + if converge_node_idx is not None and max_distance < DEFAULT_MAX_STEPS: + converge_node = self.graph.nodes[converge_node_idx] + logger.debug( + f" Creating converged region: {len(visited_nodes)} nodes, " + f"convergence at {converge_node_idx} ({converge_node.op}), distance {max_distance}" + ) + # Create region containing: divergence + all branches + convergence + self._build_small_converged_region(node_idx, converge_node_idx, visited_nodes) + self._commit_region() + # Pattern 2: No convergence or region would be too large + else: + logger.debug(" Creating orphan region for divergent node") + # Create single-node region for this divergent node + # Its successors will be processed separately + self._append_node_to_region(node_idx) + self.visited_nodes.add(node_idx) + self._commit_region() + else: + # Pattern 3: Handle non-divergent (sequential) nodes + logger.debug(" Non-divergent node, building sequence") + # Build region by following the linear chain forward + self._build_sequence_from_node(node_idx) + self._commit_region() + + def partition_graph(self): + """Partition the entire graph into non-overlapping LEAF regions.""" + logger.info(f"Partitioning graph ({len(self.graph.nodes)} nodes)") + logger.debug( + f"Initial state: {len(self.visited_nodes)} visited, {len(self.regions)} regions" + ) + + for node_idx in range(len(self.graph.nodes)): + self._build_region_from_node(node_idx) + + coverage_pct = ( + 100 * len(self.visited_nodes) / len(self.graph.nodes) if self.graph.nodes else 0 + ) + logger.info( + f"Partitioning complete: {len(self.regions)} regions, " + f"{len(self.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + if self.regions: + region_sizes = [len(r.nodes) for r in self.regions] + avg_size = sum(region_sizes) / len(region_sizes) + min_size = min(region_sizes) + max_size = max(region_sizes) + logger.debug(f"Region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + + return self.regions + + +class TopDownRegionBuilder(RegionSearchBase): + """Top-down region refiner that creates hierarchical structure from initial regions. + + This class implements Phase 2 of the combined region search strategy. 
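A rough sketch of how the two phases compose, mirroring what `CombinedRegionSearch.search_regions` does later in this file (the model path and sizes are illustrative, not prescriptive):

```python
# Minimal sketch: Phase 1 partitioning followed by Phase 2 refinement of one region.
import onnx
import onnx_graphsurgeon as gs

from modelopt.onnx.quantization.autotune.region_search import (
    RegionPartitioner,
    TopDownRegionBuilder,
)

graph = gs.import_onnx(onnx.load("model.onnx"))
graph.cleanup().toposort()

# Phase 1: bottom-up partitioning into flat LEAF regions.
partitioner = RegionPartitioner(graph)
leaf_regions = partitioner.partition_graph()

# Phase 2: refine one Phase-1 region into a hierarchical COMPOSITE region.
builder = TopDownRegionBuilder(
    graph,
    leaf_regions[0],
    next_region_id=partitioner.current_region_id,
    maximum_sequence_region_size=10,
)
composite = builder.build_composite_region()
print(composite.type, len(composite.get_children()))
```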
It takes + a region created by RegionPartitioner and refines it by: + 1. Identifying and merging converged sub-patterns + 2. Splitting long sequences into optimal sub-regions + 3. Creating a hierarchical COMPOSITE region structure + """ + + def __init__( + self, + graph: gs.Graph, + root: Region, + next_region_id: int = 0, + maximum_sequence_region_size: int = 10, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the refiner with a region to refine.""" + super().__init__( + graph, + root=root, + tensor_users_map=tensor_users_map, + forward_reachable_nodes_map=forward_reachable_nodes_map, + ) + self.regions: list[Region] = [] + self.next_region_id = next_region_id + self.maximum_sequence_region_size = maximum_sequence_region_size + self.boundary_op_types = { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + } + + def _create_leaf_region(self, node_indices: set[int]) -> Region: + """Create a new LEAF region containing specified nodes.""" + region = Region( + region_id=self.next_region_id, level=self.root.level + 1, region_type=RegionType.LEAF + ) + self.next_region_id += 1 + for node_idx in node_indices: + region.nodes.add(node_idx) + self.compute_region_boundaries(region) + return region + + def _build_region_usage_map(self, regions: list[Region]) -> dict[str, list[Region]]: + """Build mapping from tensor names to regions that consume them.""" + region_usage_map: dict[str, list[Region]] = {} + for region in regions: + for tensor_name in region.inputs: + if tensor_name not in region_usage_map: + region_usage_map[tensor_name] = [] + region_usage_map[tensor_name].append(region) + return region_usage_map + + def _split_sequence_regions(self, root: Region) -> list[Region]: + """Split a region into smaller sub-regions by merging producer-consumer chains.""" + result_regions: list[Region] = [] + removed_regions: set[int] = set() + + # PHASE 1: Split into Single-Node Regions + for node_idx in root.get_nodes(): + region = Region( + region_id=self.next_region_id, level=root.level + 1, region_type=RegionType.LEAF + ) + region.nodes.add(node_idx) + self.compute_region_boundaries(region) + result_regions.append(region) + self.next_region_id += 1 + + region_usage_map = self._build_region_usage_map(result_regions) + + # PHASE 2: Merge Regions in Data Flow Order + queue = deque(root.inputs) + + while len(queue) > 0: + tensor_name = queue.popleft() + # Skip tensors not produced by any region in our scope + if tensor_name not in region_usage_map: + continue + # Process each region consuming this tensor (potential merge targets) + consumers = region_usage_map[tensor_name] + for consumer in consumers: + # Skip regions already merged into others + if consumer.id in removed_regions: + continue + # Merging criteria: ALL outputs go to same single region + common_use_region = None + can_merge = True + # Check all outputs of the consumer region + for output_tensor in consumer.outputs: + queue.append(output_tensor) + if output_tensor not in region_usage_map: + can_merge = False + break + use_regions = region_usage_map[output_tensor] + if len(use_regions) != 1: + can_merge = False + break + if common_use_region is None: + common_use_region = use_regions[0] + elif common_use_region != use_regions[0]: + can_merge = False + break + # No valid downstream region to merge with + if common_use_region is None or common_use_region.id in 
removed_regions: + can_merge = False + continue + # Constraint 1: Limit the number of boundary operations after merge + nodes_after_merge = set() + nodes_after_merge.update(consumer.get_nodes()) + nodes_after_merge.update(common_use_region.get_nodes()) + node_ops = [self.graph.nodes[idx].op for idx in nodes_after_merge] + boundary_op_count = sum( + [1 if op in self.boundary_op_types else 0 for op in node_ops] + ) + if boundary_op_count > 3: + can_merge = False + continue + # Constraint 2: Size limits to avoid overly large regions + # Keep regions manageable for optimization passes + if ( + len(consumer.nodes) >= self.maximum_sequence_region_size + or len(common_use_region.nodes) >= self.maximum_sequence_region_size + ): + # One or both regions too large - don't merge + can_merge = False + continue + # All criteria met: merge consumer into its downstream region + if can_merge: + common_use_region.merge(consumer) + removed_regions.add(consumer.id) + # Remove regions that were merged into others + result_regions = [region for region in result_regions if region.id not in removed_regions] + # Recompute boundaries for all remaining regions + for region in result_regions: + self.compute_region_boundaries(region) + + return result_regions + + def _merge_converged_regions(self, root: Region): + """Identify and merge convergence patterns within a region.""" + result_regions: list[Region] = [] + removed_nodes: set[int] = set() + queue = deque(root.inputs) + while len(queue) > 0: + tensor_name = queue.popleft() + if tensor_name not in self.tensor_users_map: + continue + consumer_nodes = self.tensor_users_map[tensor_name] + for node_idx in consumer_nodes: + # stop at boundary nodes + if node_idx not in root.get_nodes(): + continue + consumer = self.graph.nodes[node_idx] + for output_tensor in consumer.outputs: + if output_tensor.name not in self.tensor_users_map: + continue + queue.append(output_tensor.name) + # if the node is already in a region, skip + if node_idx in removed_nodes: + continue + if not self._is_node_divergent(node_idx): + continue + converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) + visited_nodes = visited_nodes.intersection(root.get_region_nodes_and_descendants()) + # if no convergence found, skip + if converge_node_idx is None: + continue + # group converged nodes into a region + if converge_node_idx in root.get_nodes(): + converged_region = self._create_leaf_region(visited_nodes) + result_regions.append(converged_region) + removed_nodes.update(visited_nodes) + continue + # create a leaf region for the remaining nodes + remaining_nodes = set(root.get_nodes()) - removed_nodes + if len(remaining_nodes) > 0: + result_regions.append(self._create_leaf_region(remaining_nodes)) + # compute region boundaries for all regions + for region in result_regions: + self.compute_region_boundaries(region) + return result_regions + + def build_composite_region(self) -> Region: + """Refine a flat region into a hierarchical COMPOSITE region.""" + # merge converged regions into composite regions + self.regions = self._merge_converged_regions(self.root) + # split sequence regions into smaller regions + result_regions: list[Region] = [] + for region in self.regions: + result_regions.extend(self._split_sequence_regions(region)) + for region in result_regions: + self.compute_region_boundaries(region, include_constant=True) + self.regions = result_regions + # merge all regions into a single composite region + if len(self.regions) > 1: + composite = Region( + 
region_id=self.next_region_id, + level=self.root.level, + region_type=RegionType.COMPOSITE, + ) + self.next_region_id += 1 + self.regions = sorted( + self.regions, key=lambda x: RegionPattern.from_region(x, self.graph).signature + ) + for region in self.regions: + composite.add_child(region) + self.compute_region_boundaries(composite) + self.regions = [composite] + return self.regions[0] + + +class CombinedRegionSearch(RegionSearchBase): + """Two-phase region search combining bottom-up partitioning with top-down refinement. + + This class implements a sophisticated region discovery algorithm that combines two + complementary strategies to create well-formed, hierarchical regions from an ONNX + computation graph. + + """ + + def __init__( + self, + graph: gs.Graph, + maximum_sequence_region_size: int = 10, + minimum_topdown_search_size: int = 10, + ): + """Initialize CombinedRegionSearch for a given ONNX graph.""" + super().__init__(graph) + self.regions: list[Region] = [] + self.minimum_topdown_search_size = minimum_topdown_search_size + self.maximum_sequence_region_size = maximum_sequence_region_size + + def search_regions(self) -> list[Region]: + """Execute two-phase region search to partition the graph into hierarchical regions.""" + logger.info("Phase 1: Bottom-up partitioning") + logger.debug("Initializing RegionPartitioner") + region_partitioner = RegionPartitioner(self.graph) + + # Execute the bottom-up partitioning algorithm. + self.regions = region_partitioner.partition_graph() + + coverage_pct = ( + 100 * len(region_partitioner.visited_nodes) / len(self.graph.nodes) + if self.graph.nodes + else 0 + ) + logger.info( + f"Phase 1 complete: {len(self.regions)} regions, " + f"{len(region_partitioner.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + logger.debug("Proceeding to Phase 2: Top-down refinement") + + logger.info("Phase 2: Top-down refinement") + next_region_id = region_partitioner.current_region_id + + refined_count = 0 + for idx, region in enumerate(self.regions): + node_count = len(region.get_region_nodes_and_descendants()) + if node_count < self.minimum_topdown_search_size: + logger.debug(f"Skipping region {idx}: {node_count} nodes (below minimum)") + continue + + logger.debug(f"Refining region {idx}: {node_count} nodes") + region_builder = TopDownRegionBuilder( + self.graph, + region, + next_region_id=next_region_id, + maximum_sequence_region_size=self.maximum_sequence_region_size, + tensor_users_map=region_partitioner.tensor_users_map, + forward_reachable_nodes_map=region_partitioner.forward_reachable_nodes_map, + ) + + self.regions[idx] = region_builder.build_composite_region() + node_count_after = len(self.regions[idx].get_region_nodes_and_descendants()) + if node_count != node_count_after: + logger.warning( + f"Node count mismatch in region {idx}: {node_count} → {node_count_after}" + ) + + region_partitioner.compute_region_boundaries(self.regions[idx]) + next_region_id = region_builder.next_region_id + refined_count += 1 + + logger.info(f"Phase 2 complete: refined {refined_count}/{len(self.regions)} regions") + + return self.regions diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 67596d5df..77dec9441 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -302,6 +302,24 @@ def get_tensor_consumer_nodes( return tensor_consumers +def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[str, list[int]]: + """Build a 
mapping from tensor names to the indices of nodes that use them.
+
+    Args:
+        graph: ONNX GraphProto or GraphSurgeon graph to analyze
+
+    Returns:
+        Dictionary mapping tensor names to lists of node indices that consume them
+    """
+    tensor_consumer_map: dict[str, list[int]] = defaultdict(list)
+    nodes = graph.nodes if isinstance(graph, gs.Graph) else graph.node
+    for node_idx, node in enumerate(nodes):
+        inputs = node.inputs if isinstance(node, gs.Node) else node.input
+        for tensor in inputs:
+            tensor_name = tensor.name if isinstance(tensor, gs.Tensor) else tensor
+            tensor_consumer_map[tensor_name].append(node_idx)
+    return tensor_consumer_map
+
+
 def filter_quantizable_kgen_heads(
     cask_fusible_partitions: list[list[Node]],
     kgen_partitions: list[list[Node]],
diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index 026b8d062..abda4dacd 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -1035,3 +1035,30 @@ def cast_initializer_to_dtype(
         input_onnx = onnx.numpy_helper.from_array(input, input_name)
         input_onnx.data_type = onnx_dtype_map[dtype]
         initializer_map[input_name].CopyFrom(input_onnx)
+
+
+def get_quantized_tensors(onnx_model: onnx.ModelProto) -> set[str]:
+    """Get the names of all quantized tensors from an ONNX model.
+
+    This function identifies all DequantizeLinear nodes in the ONNX model
+    and extracts the names of tensors being dequantized (the first input of
+    each DequantizeLinear node, excluding scale and zero-point inputs).
+
+    Args:
+        onnx_model: ONNX model protobuf to analyze
+
+    Returns:
+        Set of tensor names that are inputs to DequantizeLinear nodes
+        (i.e., the tensors being dequantized)
+    """
+    quantized_tensors = set()
+
+    for node in onnx_model.graph.node:
+        if node.op_type == "DequantizeLinear":
+            # First input is the tensor being dequantized
+            # (node.input[1] is scale, node.input[2] is zero-point)
+            if node.input:
+                quantized_tensors.add(node.input[0])
+
+    logger.debug(f"Found {len(quantized_tensors)} dequantized tensors in ONNX model")
+    return quantized_tensors
diff --git a/tests/unit/onnx/quantization/autotune/test_pattern_cache.py b/tests/unit/onnx/quantization/autotune/test_pattern_cache.py
new file mode 100644
index 000000000..bd9b67f90
--- /dev/null
+++ b/tests/unit/onnx/quantization/autotune/test_pattern_cache.py
@@ -0,0 +1,237 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Tests for PatternCache functionality in the autotuner.
+
+Tests pattern cache creation, serialization, and scheme management.
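A minimal sketch of the round trip these tests exercise (the cache file name is illustrative):

```python
# Minimal sketch: build, persist, and reload a pattern cache.
from modelopt.onnx.quantization.autotune.common import (
    InsertionScheme,
    PatternCache,
    PatternSchemes,
)
from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern

cache = PatternCache()
ps = PatternSchemes(pattern=RegionPattern(signature="Conv->Relu", size=2))
scheme = InsertionScheme()
scheme.latency_ms = 10.0
ps.schemes.append(scheme)
cache.add_pattern_schemes(ps)

cache.save("pattern_cache.yaml")  # YAML serialization, as tested below
restored = PatternCache.load("pattern_cache.yaml")
assert restored.pattern_schemes[0].pattern_signature == "Conv->Relu"
```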
+""" + +import os +import sys +import tempfile +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from modelopt.onnx.quantization.autotune.common import ( + InsertionScheme, + NodeInputInsertionPoint, + PatternCache, + PatternSchemes, +) +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + +class TestPatternCache(unittest.TestCase): + """Test PatternCache functionality.""" + + @staticmethod + def _create_test_pattern(signature: str, size: int = 2): + """Create a test RegionPattern.""" + return RegionPattern(signature=signature, size=size) + + def test_empty_cache_creation(self): + """Test creating an empty PatternCache.""" + cache = PatternCache() + + assert len(cache.pattern_schemes) == 0 + assert cache.pattern_schemes is not None + + def test_add_pattern_schemes(self): + """Test adding pattern schemes to cache.""" + cache = PatternCache() + + # Create a pattern scheme + pattern = self._create_test_pattern("Conv->Relu") + ps = PatternSchemes(pattern=pattern) + scheme = InsertionScheme() + scheme.latency_ms = 10.0 + ps.schemes.append(scheme) + + cache.add_pattern_schemes(ps) + + assert len(cache.pattern_schemes) == 1 + assert cache.pattern_schemes[0].pattern_signature == "Conv->Relu" + + def test_multiple_patterns(self): + """Test cache with multiple pattern schemes.""" + cache = PatternCache() + + # Add multiple patterns + pattern_sigs = ["Conv->Relu", "Gemm->Relu", "Conv->Add->Relu"] + for pattern_sig in pattern_sigs: + pattern = self._create_test_pattern(pattern_sig) + ps = PatternSchemes(pattern=pattern) + scheme = InsertionScheme() + scheme.latency_ms = 10.0 + len(pattern_sig) + ps.schemes.append(scheme) + cache.add_pattern_schemes(ps) + + assert len(cache.pattern_schemes) == 3 + found_patterns = [ps.pattern_signature for ps in cache.pattern_schemes] + for pattern_sig in pattern_sigs: + assert pattern_sig in found_patterns + + def test_serialization_empty(self): + """Test serialization of empty cache.""" + cache = PatternCache() + + data = cache.to_dict() + assert "pattern_schemes" in data + assert len(data["pattern_schemes"]) == 0 + + restored = PatternCache.from_dict(data) + assert len(restored.pattern_schemes) == 0 + + def test_serialization_with_data(self): + """Test serialization with pattern schemes.""" + # Create cache with minimum_distance=0 to keep both schemes + cache = PatternCache(minimum_distance=0) + + # Add a pattern scheme + pattern = self._create_test_pattern("Conv->Relu") + ps = PatternSchemes(pattern=pattern) + + # Create schemes that are sufficiently different (distance >= 4) + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme1.latency_ms = 10.0 + ps.schemes.append(scheme1) + + scheme2 = InsertionScheme() + scheme2.node_inputs = [ + NodeInputInsertionPoint(0, 0), + NodeInputInsertionPoint(1, 0), + NodeInputInsertionPoint(2, 0), + NodeInputInsertionPoint(3, 0), + NodeInputInsertionPoint(4, 0), # 5 total points, diff = 4 from scheme1 + ] + scheme2.latency_ms = 12.0 + ps.schemes.append(scheme2) + + cache.add_pattern_schemes(ps) + + # Serialize and restore + data = cache.to_dict() + restored = PatternCache.from_dict(data) + + assert len(restored.pattern_schemes) == 1 + + restored_ps = restored.pattern_schemes[0] + assert restored_ps.pattern_signature == "Conv->Relu" + assert len(restored_ps.schemes) == 2 + assert restored_ps.best_scheme_index == 0 + assert restored_ps.schemes[0].latency_ms == 10.0 + + def 
test_yaml_round_trip(self): + """Test saving and loading cache as YAML.""" + cache = PatternCache() + + # Add a pattern scheme + pattern = self._create_test_pattern("Gemm->Relu") + ps = PatternSchemes(pattern=pattern) + scheme = InsertionScheme() + scheme.latency_ms = 15.0 + ps.schemes.append(scheme) + cache.add_pattern_schemes(ps) + + # Save to YAML + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml_path = f.name + + try: + cache.save(yaml_path) + + # Load from YAML + restored = PatternCache.load(yaml_path) + + assert len(restored.pattern_schemes) == 1 + assert restored.pattern_schemes[0].pattern_signature == "Gemm->Relu" + assert restored.pattern_schemes[0].schemes[0].latency_ms == 15.0 + finally: + if os.path.exists(yaml_path): + os.unlink(yaml_path) + + def test_update_cache(self): + """Test updating existing pattern in cache (merges schemes).""" + # Use minimum_distance=0 to keep all schemes + cache = PatternCache(minimum_distance=0) + + # Add initial pattern + pattern1 = self._create_test_pattern("Conv->Relu") + ps1 = PatternSchemes(pattern=pattern1) + scheme1 = InsertionScheme() + scheme1.latency_ms = 10.0 + ps1.schemes.append(scheme1) + cache.add_pattern_schemes(ps1) + + # Update with new scheme for same pattern + pattern2 = self._create_test_pattern("Conv->Relu") + ps2 = PatternSchemes(pattern=pattern2) + scheme2 = InsertionScheme() + scheme2.latency_ms = 8.0 # Better performance + scheme2.node_inputs = [NodeInputInsertionPoint(0, 0)] # Make it different + ps2.schemes.append(scheme2) + cache.add_pattern_schemes(ps2) + + # Verify merge (should have both schemes now) + assert len(cache.pattern_schemes) == 1 + conv_relu_ps = cache.pattern_schemes[0] + assert conv_relu_ps.pattern_signature == "Conv->Relu" + assert len(conv_relu_ps.schemes) == 2 # Merged + # Best scheme should be the one with lowest latency + assert conv_relu_ps.best_scheme.latency_ms == 8.0 + + def test_get_best_scheme(self): + """Test retrieving best scheme for a pattern.""" + # Use minimum_distance=0 to keep all different schemes + cache = PatternCache(minimum_distance=0) + + pattern = self._create_test_pattern("Conv->Relu") + ps = PatternSchemes(pattern=pattern) + + # Add multiple schemes with different insertion points + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme1.latency_ms = 12.0 + ps.schemes.append(scheme1) + + scheme2 = InsertionScheme() + scheme2.node_inputs = [NodeInputInsertionPoint(1, 0)] # Different node + scheme2.latency_ms = 8.0 + ps.schemes.append(scheme2) + + scheme3 = InsertionScheme() + scheme3.node_inputs = [NodeInputInsertionPoint(2, 0)] # Different node + scheme3.latency_ms = 10.0 + ps.schemes.append(scheme3) + + cache.add_pattern_schemes(ps) + + # Verify best scheme retrieval (automatically computed) + conv_relu_ps = cache.pattern_schemes[0] + assert conv_relu_ps.pattern_signature == "Conv->Relu" + assert len(conv_relu_ps.schemes) == 3 # All 3 kept + + # Verify best scheme has lowest latency (cache may reorder schemes) + best = conv_relu_ps.best_scheme + assert best is not None + assert best.latency_ms == 8.0 + + # Verify all three latencies are present + latencies = sorted([s.latency_ms for s in conv_relu_ps.schemes]) + assert latencies == [8.0, 10.0, 12.0] diff --git a/tests/unit/onnx/quantization/autotune/test_region_pattern.py b/tests/unit/onnx/quantization/autotune/test_region_pattern.py new file mode 100644 index 000000000..939cefa76 --- /dev/null +++ 
b/tests/unit/onnx/quantization/autotune/test_region_pattern.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for RegionPattern functionality in the autotuner. + +Tests pattern generation, matching, and tree visualization. +""" + +import os +import sys +import unittest + +import numpy as np +import onnx_graphsurgeon as gs + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + +class TestRegionPattern(unittest.TestCase): + """Test RegionPattern functionality.""" + + # ========================================================================= + # Helper Methods + # ========================================================================= + + @staticmethod + def _create_simple_graph(): + """Create a simple Conv->Relu graph for testing. + + Graph structure: + input -> Conv -> Relu -> output + """ + # Create inputs and outputs + inp = gs.Variable(name="input", dtype=np.float32, shape=[1, 3, 224, 224]) + conv_out = gs.Variable(name="conv_out", dtype=np.float32) + relu_out = gs.Variable(name="output", dtype=np.float32) + + # Create weights + conv_weight = gs.Constant( + name="conv_weight", values=np.ones((64, 3, 3, 3), dtype=np.float32) + ) + + # Create nodes + conv = gs.Node( + name="Conv_0", + op="Conv", + inputs=[inp, conv_weight], + outputs=[conv_out], + attrs={"kernel_shape": [3, 3], "strides": [1, 1], "pads": [1, 1, 1, 1]}, + ) + relu = gs.Node( + name="Relu_0", + op="Relu", + inputs=[conv_out], + outputs=[relu_out], + ) + + # Create graph + graph = gs.Graph( + nodes=[conv, relu], + inputs=[inp], + outputs=[relu_out], + opset=13, + ) + return graph + + @staticmethod + def _create_hierarchical_graph(): + """Create a hierarchical graph with composite regions. 
+ + Graph structure: + input -> Conv -> Relu -> Add -> MatMul -> Relu -> output + ^ + | + other_input + + Region structure: + ROOT + ├── COMPOSITE (Conv->Relu->Add) + │ ├── LEAF (Conv->Relu) + │ └── LEAF (Add) + └── COMPOSITE (MatMul->Relu) + └── LEAF (MatMul->Relu) + """ + # Create inputs and intermediate tensors + inp = gs.Variable(name="input", dtype=np.float32, shape=[1, 64, 64, 64]) + other_inp = gs.Variable(name="other_input", dtype=np.float32, shape=[1, 64, 64, 64]) + conv_out = gs.Variable(name="conv_out", dtype=np.float32) + relu1_out = gs.Variable(name="relu1_out", dtype=np.float32) + add_out = gs.Variable(name="add_out", dtype=np.float32) + matmul_out = gs.Variable(name="matmul_out", dtype=np.float32) + output = gs.Variable(name="output", dtype=np.float32) + + # Create constants + conv_weight = gs.Constant( + name="conv_weight", values=np.ones((64, 64, 1, 1), dtype=np.float32) + ) + matmul_weight = gs.Constant( + name="matmul_weight", values=np.ones((64, 64), dtype=np.float32) + ) + + # Create nodes (order matters for node indices) + conv = gs.Node( + name="Conv_0", + op="Conv", + inputs=[inp, conv_weight], + outputs=[conv_out], + attrs={"kernel_shape": [1, 1]}, + ) # Node 0 + relu1 = gs.Node(name="Relu_0", op="Relu", inputs=[conv_out], outputs=[relu1_out]) # Node 1 + add = gs.Node( + name="Add_0", op="Add", inputs=[relu1_out, other_inp], outputs=[add_out] + ) # Node 2 + matmul = gs.Node( + name="MatMul_0", op="MatMul", inputs=[add_out, matmul_weight], outputs=[matmul_out] + ) # Node 3 + relu2 = gs.Node(name="Relu_1", op="Relu", inputs=[matmul_out], outputs=[output]) # Node 4 + + # Create graph + graph = gs.Graph( + nodes=[conv, relu1, add, matmul, relu2], + inputs=[inp, other_inp], + outputs=[output], + opset=13, + ) + return graph + + @staticmethod + def _create_test_region( + region_id: int, level: int, region_type: RegionType, node_indices: list[int] | None = None + ) -> Region: + """Create a test region.""" + region = Region(region_id, level, region_type) + if node_indices: + region.nodes.update(node_indices) + return region + + # ========================================================================= + # Test Cases + # ========================================================================= + + def test_pattern_creation(self): + """Test basic RegionPattern creation.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + + assert pattern.signature == "Conv->Relu" + assert pattern.size == 2 + assert not pattern.is_empty + assert pattern.is_leaf + assert not pattern.is_composite + + def test_pattern_equality(self): + """Test RegionPattern equality based on signature.""" + pattern1 = RegionPattern(signature="Conv->Relu", size=2) + pattern2 = RegionPattern(signature="Conv->Relu", size=5) # Different size + pattern3 = RegionPattern(signature="Gemm->Relu", size=2) + + # Same signature = equal (size doesn't affect equality) + assert pattern1 == pattern2 + # Different signature = not equal + assert pattern1 != pattern3 + + def test_pattern_hash(self): + """Test RegionPattern hashing for use in dicts/sets.""" + pattern1 = RegionPattern(signature="Conv->Relu", size=2) + pattern2 = RegionPattern(signature="Conv->Relu", size=5) + + # Same signature = same hash + assert hash(pattern1) == hash(pattern2) + + # Can be used as dict keys + pattern_dict = {pattern1: "scheme1"} + assert pattern_dict[pattern2] == "scheme1" # pattern2 finds pattern1's entry + + def test_pattern_from_simple_region(self): + """Test pattern computation from a simple region.""" + graph = 
self._create_simple_graph() + + # Create a leaf region with Conv and Relu nodes + region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + + pattern = RegionPattern.from_region(region, graph) + + # Should capture both operations + assert "Conv" in pattern.signature + assert "Relu" in pattern.signature + assert pattern.size == 2 + assert pattern.is_leaf + + def test_pattern_from_composite_region(self): + """Test pattern computation from a composite region with children.""" + graph = self._create_hierarchical_graph() + + # Create leaf regions + leaf1 = self._create_test_region( + region_id=1, + level=0, + region_type=RegionType.LEAF, + node_indices=[0, 1], # Conv, Relu + ) + leaf2 = self._create_test_region( + region_id=2, + level=0, + region_type=RegionType.LEAF, + node_indices=[2], # Add + ) + + # Create composite region + composite = self._create_test_region( + region_id=3, level=1, region_type=RegionType.COMPOSITE, node_indices=[] + ) + composite.add_child(leaf1) + composite.add_child(leaf2) + + pattern = RegionPattern.from_region(composite, graph) + + assert pattern.is_composite + assert "COMPOSITE" in pattern.signature + assert pattern.size == 3 # Total nodes in region hierarchy + + def test_pattern_get_hash(self): + """Test cryptographic hash generation.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + hash_val = pattern.get_hash() + + # Hash should be 32 hex characters (128-bit truncated SHA-256) + assert len(hash_val) == 32 + assert all(c in "0123456789abcdef" for c in hash_val) + + # Same signature = same hash + pattern2 = RegionPattern(signature="Conv->Relu", size=5) + assert pattern.get_hash() == pattern2.get_hash() + + def test_pattern_get_short_signature(self): + """Test signature truncation.""" + long_sig = "COMPOSITE(" + "Conv->Relu->" * 20 + "Output)" + pattern = RegionPattern(signature=long_sig, size=20) + + short_sig = pattern.get_short_signature(max_length=50) + assert len(short_sig) == 50 + assert short_sig.endswith("...") + + # Short signature stays unchanged + short_pattern = RegionPattern(signature="Conv", size=1) + assert short_pattern.get_short_signature(max_length=50) == "Conv" + + def test_print_tree(self): + """Test format_tree to visualize region structure. + + This test demonstrates how to use format_tree to display + the hierarchical structure of regions and their patterns. 
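Because `RegionPattern` hashes and compares on its signature alone (the property the dict-key test above relies on), grouping structurally identical regions is straightforward. A sketch, assuming `regions` and `graph` come from a prior region search:

```python
# Minimal sketch: bucketing regions by structural pattern.
# Assumes `regions` is e.g. the output of CombinedRegionSearch.search_regions()
# and `graph` is the corresponding onnx_graphsurgeon graph.
from collections import defaultdict

from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern

groups = defaultdict(list)
for region in regions:
    groups[RegionPattern.from_region(region, graph)].append(region)

for pattern, members in groups.items():
    print(f"{len(members)}x {pattern.get_short_signature()}")
```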
+ """ + graph = self._create_hierarchical_graph() + + # Build a hierarchical region structure: + # ROOT (level=2) + # ├── COMPOSITE (level=1) [Conv->Relu + Add] + # │ ├── LEAF (level=0) [Conv, Relu - nodes 0,1] + # │ └── LEAF (level=0) [Add - node 2] + # └── LEAF (level=0) [MatMul, Relu - nodes 3,4] + + # Create leaf regions + leaf_conv_relu = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + leaf_add = self._create_test_region( + region_id=2, level=0, region_type=RegionType.LEAF, node_indices=[2] + ) + leaf_matmul_relu = self._create_test_region( + region_id=3, level=0, region_type=RegionType.LEAF, node_indices=[3, 4] + ) + + # Create composite region containing conv_relu and add + composite = self._create_test_region( + region_id=4, level=1, region_type=RegionType.COMPOSITE, node_indices=[] + ) + composite.add_child(leaf_conv_relu) + composite.add_child(leaf_add) + + # Create root region containing everything + root = self._create_test_region( + region_id=5, level=2, region_type=RegionType.ROOT, node_indices=[] + ) + root.add_child(composite) + root.add_child(leaf_matmul_relu) + + # Generate pattern for root and print tree + root_pattern = RegionPattern.from_region(root, graph) + tree_output = root_pattern.format_tree(root, graph) + + print("\n" + "=" * 60) + print("Region Tree Structure:") + print("=" * 60) + print(tree_output) + print("=" * 60) + + # Verify tree output contains expected elements + assert "Region 5" in tree_output # Root + assert "Region 4" in tree_output # Composite + assert "Region 1" in tree_output # Leaf conv_relu + assert "Region 2" in tree_output # Leaf add + assert "Region 3" in tree_output # Leaf matmul_relu + + # Verify indentation shows hierarchy + lines = tree_output.strip().split("\n") + assert len(lines) >= 3 # At least root + children + + # Root should have no indentation + assert lines[0].startswith("Region 5") + + # Children should be indented + indented_lines = [line for line in lines if line.startswith(" ")] + assert len(indented_lines) > 0 + + def test_pattern_matches_pattern(self): + """Test pattern-to-pattern matching.""" + pattern1 = RegionPattern(signature="Conv->Relu", size=2) + pattern2 = RegionPattern(signature="Conv->Relu", size=5) + pattern3 = RegionPattern(signature="Gemm->Relu", size=2) + + assert pattern1.matches(pattern2) # Same signature + assert not pattern1.matches(pattern3) # Different signature + + def test_pattern_matches_region(self): + """Test pattern-to-region matching.""" + graph = self._create_simple_graph() + + # Create region + region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + + # Create pattern from region + pattern = RegionPattern.from_region(region, graph) + + # Match should return node IDs + node_ids = pattern.matches(region, graph) + assert node_ids is not None + assert set(node_ids) == {0, 1} + + def test_empty_region_pattern(self): + """Test pattern for empty region.""" + graph = self._create_simple_graph() + + # Create empty region + empty_region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[] + ) + + pattern = RegionPattern.from_region(empty_region, graph) + + assert pattern.is_empty + assert pattern.signature == "EMPTY" + assert pattern.size == 0 + + def test_symmetric_operation_signature(self): + """Test that symmetric operations (Add, Mul) have consistent signatures.""" + # Create two graphs with Add inputs in different order + inp1 = 
gs.Variable(name="input1", dtype=np.float32, shape=[1, 64]) + inp2 = gs.Variable(name="input2", dtype=np.float32, shape=[1, 64]) + out = gs.Variable(name="output", dtype=np.float32) + + # Graph 1: Add(inp1, inp2) + add1 = gs.Node(name="Add_0", op="Add", inputs=[inp1, inp2], outputs=[out]) + graph1 = gs.Graph(nodes=[add1], inputs=[inp1, inp2], outputs=[out], opset=13) + + # Graph 2: Add(inp2, inp1) - reversed inputs + add2 = gs.Node(name="Add_0", op="Add", inputs=[inp2, inp1], outputs=[out]) + graph2 = gs.Graph(nodes=[add2], inputs=[inp1, inp2], outputs=[out], opset=13) + + # Create regions + region1 = self._create_test_region(1, 0, RegionType.LEAF, [0]) + region2 = self._create_test_region(1, 0, RegionType.LEAF, [0]) + + pattern1 = RegionPattern.from_region(region1, graph1) + pattern2 = RegionPattern.from_region(region2, graph2) + + # Patterns should be equal regardless of input order + assert pattern1 == pattern2 + + def test_pattern_repr_and_str(self): + """Test string representations.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + + # str() shows just signature + assert str(pattern) == "Conv->Relu" + + # repr() shows full info + assert "RegionPattern" in repr(pattern) + assert "Conv->Relu" in repr(pattern) + assert "size=2" in repr(pattern) diff --git a/tests/unit/onnx/quantization/autotune/test_region_search.py b/tests/unit/onnx/quantization/autotune/test_region_search.py new file mode 100644 index 000000000..96f92eba7 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region_search.py @@ -0,0 +1,418 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for region search algorithms. + +Tests CombinedRegionSearch, RegionPartitioner, and TopDownRegionBuilder. +Note: Comprehensive integration tests with real ONNX graphs should be in separate integration test files. +""" + +import io +import os +import sys +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import onnx +import onnx_graphsurgeon as gs +from onnx import helper + +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_search import ( + CombinedRegionSearch, + RegionPartitioner, + TopDownRegionBuilder, +) + + +def create_simple_linear_graph(): + """ + Create a simple linear graph: Input -> Conv -> Relu -> Output. + + This is the simplest possible graph for testing region discovery. 
+    """
+    # Input
+    input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])
+
+    # Output
+    output_tensor = helper.make_tensor_value_info(
+        "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224]
+    )
+
+    # Conv node (pads keep the declared 224x224 spatial shape valid for a 3x3 kernel)
+    conv_node = helper.make_node(
+        "Conv",
+        inputs=["input", "conv_weight"],
+        outputs=["conv_out"],
+        name="conv",
+        pads=[1, 1, 1, 1],
+    )
+
+    # Relu node
+    relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu")
+
+    # Create graph
+    graph = helper.make_graph(
+        [conv_node, relu_node],
+        "simple_linear",
+        [input_tensor],
+        [output_tensor],
+        initializer=[
+            helper.make_tensor(
+                "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3)
+            )
+        ],
+    )
+
+    # Create model
+    model = helper.make_model(graph, producer_name="test")
+
+    # Convert to GraphSurgeon
+    gs_graph = gs.import_onnx(model)
+    return gs_graph
+
+
+def create_divergent_graph():
+    """Create a graph with divergence: Input -> Conv -> [Relu1, Relu2] -> Add -> Output.
+
+    Tests divergence/convergence pattern detection.
+    """
+    input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])
+    output_tensor = helper.make_tensor_value_info(
+        "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224]
+    )
+
+    # Conv node (pads keep the declared 224x224 spatial shape valid for a 3x3 kernel)
+    conv_node = helper.make_node(
+        "Conv",
+        inputs=["input", "conv_weight"],
+        outputs=["conv_out"],
+        name="conv",
+        pads=[1, 1, 1, 1],
+    )
+    relu1_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["relu1_out"], name="relu1")
+    relu2_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["relu2_out"], name="relu2")
+    add_node = helper.make_node(
+        "Add", inputs=["relu1_out", "relu2_out"], outputs=["output"], name="add"
+    )
+
+    graph = helper.make_graph(
+        [conv_node, relu1_node, relu2_node, add_node],
+        "divergent",
+        [input_tensor],
+        [output_tensor],
+        initializer=[
+            helper.make_tensor(
+                "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3)
+            )
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="test")
+    gs_graph = gs.import_onnx(model)
+    return gs_graph
+
+
+class TestRegionPartitioner(unittest.TestCase):
+    """Test RegionPartitioner basic functionality."""
+
+    def test_creation_linear_graph(self):
+        """Test creating RegionPartitioner with a simple linear graph."""
+        graph = create_simple_linear_graph()
+        partitioner = RegionPartitioner(graph)
+
+        assert partitioner is not None
+        assert partitioner.graph == graph
+
+    def test_partition_linear_graph(self):
+        """Test partitioning a simple linear graph."""
+        graph = create_simple_linear_graph()
+        partitioner = RegionPartitioner(graph)
+
+        regions = partitioner.partition_graph()
+
+        # Should create at least one region
+        assert len(regions) > 0
+
+        # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned)
+        total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions)
+        assert total_nodes > 0
+        assert total_nodes <= len(graph.nodes)
+
+    def test_partition_divergent_graph(self):
+        """Test partitioning a divergent graph."""
+        graph = create_divergent_graph()
+        partitioner = RegionPartitioner(graph)
+
+        regions = partitioner.partition_graph()
+
+        # Should create regions covering the graph
+        assert len(regions) > 0
+
+        # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned)
+        total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions)
+        assert total_nodes > 0
+        assert total_nodes <= len(graph.nodes)
+
+
+class 
TestTopDownRegionBuilder(unittest.TestCase): + """Test TopDownRegionBuilder basic functionality.""" + + def test_creation(self): + """Test creating TopDownRegionBuilder.""" + graph = create_simple_linear_graph() + + # Create a root region with all nodes + root = Region(region_id=0, level=0, region_type=RegionType.LEAF) + for idx in range(len(graph.nodes)): + root.nodes.add(idx) + + builder = TopDownRegionBuilder(graph, root) + + assert builder is not None + assert builder.graph == graph + + def test_build_composite_region(self): + """Test building a composite region.""" + graph = create_simple_linear_graph() + + # First partition to get initial regions + partitioner = RegionPartitioner(graph) + initial_regions = partitioner.partition_graph() + + if len(initial_regions) > 0: + # Use first region as root for top-down building + root_region = initial_regions[0] + + builder = TopDownRegionBuilder(graph, root_region, next_region_id=100) + + # Build composite region (may return LEAF or COMPOSITE depending on structure) + composite = builder.build_composite_region() + + assert composite is not None + # Region type depends on whether refinement created internal structure + # For simple linear graphs, may stay as LEAF + assert composite.type in [RegionType.LEAF, RegionType.COMPOSITE] + else: + self.skipTest("No initial regions to refine") + + +class TestCombinedRegionSearch(unittest.TestCase): + """Test CombinedRegionSearch two-phase algorithm.""" + + def test_creation(self): + """Test creating CombinedRegionSearch.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + + assert search is not None + assert search.graph == graph + + def test_search_linear_graph(self): + """Test searching regions in a simple linear graph.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + + regions = search.search_regions() + + # Should create regions + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(graph.nodes) + + # Each region should have valid inputs/outputs + for region in regions: + assert region.inputs is not None + assert region.outputs is not None + + def test_search_divergent_graph(self): + """Test searching regions in a divergent graph.""" + graph = create_divergent_graph() + search = CombinedRegionSearch(graph) + + regions = search.search_regions() + + # Should create regions + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(graph.nodes) + + def test_region_hierarchy(self): + """Test that regions have proper hierarchical structure.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + + regions = search.search_regions() + + # Check that regions have children (hierarchical structure) + for region in regions: + if region.type == RegionType.COMPOSITE: + assert len(region.get_children()) > 0 + + # Verify parent-child relationships + for child in region.get_children(): + assert child.parent == region + + def test_parameters(self): + """Test CombinedRegionSearch with custom parameters.""" + graph = create_simple_linear_graph() + + # Test with different parameter values + search = CombinedRegionSearch( + graph, 
maximum_sequence_region_size=5, minimum_topdown_search_size=5
+        )
+
+        regions = search.search_regions()
+
+        assert len(regions) > 0
+
+
+class TestPrintTree(unittest.TestCase):
+    """Test print_tree functionality."""
+
+    def test_print_tree_basic(self):
+        """Test basic print_tree output."""
+        graph = create_simple_linear_graph()
+        search = CombinedRegionSearch(graph)
+        search.search_regions()
+
+        # Capture output to StringIO
+        output = io.StringIO()
+        search.print_tree(file=output)
+
+        result = output.getvalue()
+
+        # Should contain region information
+        assert "Region" in result
+        assert "Level" in result
+        assert "Type:" in result
+
+    def test_print_tree_contains_node_info(self):
+        """Test that print_tree output contains node information."""
+        graph = create_simple_linear_graph()
+        search = CombinedRegionSearch(graph)
+        search.search_regions()
+
+        output = io.StringIO()
+        search.print_tree(file=output)
+
+        result = output.getvalue()
+
+        # Should contain node counts
+        assert "Direct nodes:" in result
+        assert "Total nodes:" in result
+        assert "Children:" in result
+
+    def test_print_tree_contains_io_info(self):
+        """Test that print_tree output contains input/output tensor info."""
+        graph = create_simple_linear_graph()
+        search = CombinedRegionSearch(graph)
+        search.search_regions()
+
+        output = io.StringIO()
+        search.print_tree(file=output)
+
+        result = output.getvalue()
+
+        # Should contain I/O information
+        assert "Inputs:" in result
+        assert "Outputs:" in result
+
+    def test_print_tree_divergent_graph(self):
+        """Test print_tree on a divergent graph with more complex structure."""
+        graph = create_divergent_graph()
+        search = CombinedRegionSearch(graph)
+        search.search_regions()
+
+        output = io.StringIO()
+        search.print_tree(file=output)
+
+        result = output.getvalue()
+
+        # Should produce valid output
+        assert "Region" in result
+        assert len(result) > 0
+
+    def test_print_tree_max_items(self):
+        """Test print_tree with a custom max_items limit on items shown per region."""
+        graph = create_simple_linear_graph()
+        search = CombinedRegionSearch(graph)
+        search.search_regions()
+
+        # Test with different max_items values
+        output1 = io.StringIO()
+        search.print_tree(max_items=1, file=output1)
+
+        output2 = io.StringIO()
+        search.print_tree(max_items=10, file=output2)
+
+        # Both should produce output
+        assert len(output1.getvalue()) > 0
+        assert len(output2.getvalue()) > 0
+
+    def test_print_tree_specific_region(self):
+        """Test print_tree with a specific region instead of the root."""
+        graph = create_simple_linear_graph()
+        search = CombinedRegionSearch(graph)
+        regions = search.search_regions()
+
+        if len(regions) > 0:
+            # Print a specific region
+            output = io.StringIO()
+            search.print_tree(region=regions[0], file=output)
+
+            result = output.getvalue()
+            assert "Region" in result
+            assert f"Region {regions[0].id}" in result
+
+    def test_print_tree_partitioner(self):
+        """Test print_tree on RegionPartitioner."""
+        graph = create_simple_linear_graph()
+        partitioner = RegionPartitioner(graph)
+        partitioner.partition_graph()
+
+        output = io.StringIO()
+        partitioner.print_tree(file=output)
+
+        result = output.getvalue()
+        assert "Region" in result
+        assert len(result) > 0
+
+    def test_print_tree_top_down_builder(self):
+        """Test print_tree on TopDownRegionBuilder."""
+        graph = create_simple_linear_graph()
+
+        # Create a root region with all nodes
+        root = Region(region_id=0, level=0, region_type=RegionType.LEAF)
+        root.nodes.update(range(len(graph.nodes)))
+
+        builder = TopDownRegionBuilder(graph, 
root) + # Compute region I/O boundaries before building + builder.compute_region_boundaries(root) + builder.build_composite_region() + + output = io.StringIO() + builder.print_tree(file=output) + + result = output.getvalue() + print("\n" + "=" * 60) + print("Region Tree Structure:") + print("=" * 60) + print(result) + print("=" * 60) + + assert "Region" in result + assert len(result) > 0
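+
+
+# Convenience entry point so this file can be run directly
+# (e.g. `python test_region_search.py`); pytest ignores this block.
+if __name__ == "__main__":
+    unittest.main()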